model.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import torch
  16. import numpy as np
  17. import threading
  18. import time
  19. from torch.nn import functional as F
  20. from contextlib import nullcontext
  21. import uuid
  22. from cosyvoice.utils.common import fade_in_out
  23. from cosyvoice.utils.file_utils import convert_onnx_to_trt
  24. class CosyVoiceModel:
  25. def __init__(self,
  26. llm: torch.nn.Module,
  27. flow: torch.nn.Module,
  28. hift: torch.nn.Module,
  29. fp16: bool):
  30. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  31. self.llm = llm
  32. self.flow = flow
  33. self.hift = hift
  34. self.fp16 = fp16
  35. self.llm.fp16 = fp16
  36. self.flow.fp16 = fp16
  37. if self.fp16 is True:
  38. self.llm.half()
  39. self.flow.half()
  40. self.token_min_hop_len = 2 * self.flow.input_frame_rate
  41. self.token_max_hop_len = 4 * self.flow.input_frame_rate
  42. self.token_overlap_len = 20
  43. # here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
  44. self.flow.decoder.estimator.static_chunk_size = 0
  45. # mel fade in out
  46. self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
  47. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  48. # hift cache
  49. self.mel_cache_len = 20
  50. self.source_cache_len = int(self.mel_cache_len * 256)
  51. # speech fade in out
  52. self.speech_window = np.hamming(2 * self.source_cache_len)
  53. # rtf and decoding related
  54. self.stream_scale_factor = 1
  55. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  56. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  57. self.lock = threading.Lock()
  58. # dict used to store session related variable
  59. self.tts_speech_token_dict = {}
  60. self.llm_end_dict = {}
  61. self.mel_overlap_dict = {}
  62. self.flow_cache_dict = {}
  63. self.hift_cache_dict = {}
  64. def load(self, llm_model, flow_model, hift_model):
  65. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
  66. self.llm.to(self.device).eval()
  67. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
  68. self.flow.to(self.device).eval()
  69. # in case hift_model is a hifigan model
  70. hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
  71. self.hift.load_state_dict(hift_state_dict, strict=True)
  72. self.hift.to(self.device).eval()
  73. def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
  74. llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
  75. self.llm.text_encoder = llm_text_encoder
  76. llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
  77. self.llm.llm = llm_llm
  78. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  79. self.flow.encoder = flow_encoder
  80. def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
  81. assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
  82. if not os.path.exists(flow_decoder_estimator_model):
  83. convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
  84. del self.flow.decoder.estimator
  85. import tensorrt as trt
  86. with open(flow_decoder_estimator_model, 'rb') as f:
  87. self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
  88. if self.flow.decoder.estimator_engine is None:
  89. raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
  90. self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
  91. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  92. with self.llm_context:
  93. for i in self.llm.inference(text=text.to(self.device),
  94. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  95. prompt_text=prompt_text.to(self.device),
  96. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  97. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  98. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  99. embedding=llm_embedding.to(self.device)):
  100. self.tts_speech_token_dict[uuid].append(i)
  101. self.llm_end_dict[uuid] = True
  102. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
  103. tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
  104. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  105. prompt_token=prompt_token.to(self.device),
  106. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  107. prompt_feat=prompt_feat.to(self.device),
  108. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  109. embedding=embedding.to(self.device),
  110. flow_cache=self.flow_cache_dict[uuid])
  111. self.flow_cache_dict[uuid] = flow_cache
  112. # mel overlap fade in out
  113. if self.mel_overlap_dict[uuid].shape[2] != 0:
  114. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  115. # append hift cache
  116. if self.hift_cache_dict[uuid] is not None:
  117. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  118. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  119. else:
  120. hift_cache_source = torch.zeros(1, 1, 0)
  121. # keep overlap mel and hift cache
  122. if finalize is False:
  123. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  124. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  125. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  126. if self.hift_cache_dict[uuid] is not None:
  127. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  128. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  129. 'source': tts_source[:, :, -self.source_cache_len:],
  130. 'speech': tts_speech[:, -self.source_cache_len:]}
  131. tts_speech = tts_speech[:, :-self.source_cache_len]
  132. else:
  133. if speed != 1.0:
  134. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  135. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  136. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  137. if self.hift_cache_dict[uuid] is not None:
  138. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  139. return tts_speech
  140. def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  141. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  142. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  143. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  144. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
  145. # this_uuid is used to track variables related to this inference thread
  146. this_uuid = str(uuid.uuid1())
  147. with self.lock:
  148. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  149. self.hift_cache_dict[this_uuid] = None
  150. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  151. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  152. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  153. p.start()
  154. if stream is True:
  155. token_hop_len = self.token_min_hop_len
  156. while True:
  157. time.sleep(0.1)
  158. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  159. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  160. .unsqueeze(dim=0)
  161. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  162. prompt_token=flow_prompt_speech_token,
  163. prompt_feat=prompt_speech_feat,
  164. embedding=flow_embedding,
  165. uuid=this_uuid,
  166. finalize=False)
  167. yield {'tts_speech': this_tts_speech.cpu()}
  168. with self.lock:
  169. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  170. # increase token_hop_len for better speech quality
  171. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  172. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  173. break
  174. p.join()
  175. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  176. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  177. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  178. prompt_token=flow_prompt_speech_token,
  179. prompt_feat=prompt_speech_feat,
  180. embedding=flow_embedding,
  181. uuid=this_uuid,
  182. finalize=True)
  183. yield {'tts_speech': this_tts_speech.cpu()}
  184. else:
  185. # deal with all tokens
  186. p.join()
  187. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  188. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  189. prompt_token=flow_prompt_speech_token,
  190. prompt_feat=prompt_speech_feat,
  191. embedding=flow_embedding,
  192. uuid=this_uuid,
  193. finalize=True,
  194. speed=speed)
  195. yield {'tts_speech': this_tts_speech.cpu()}
  196. with self.lock:
  197. self.tts_speech_token_dict.pop(this_uuid)
  198. self.llm_end_dict.pop(this_uuid)
  199. self.mel_overlap_dict.pop(this_uuid)
  200. self.hift_cache_dict.pop(this_uuid)
  201. self.flow_cache_dict.pop(this_uuid)
  202. torch.cuda.empty_cache()
  203. def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
  204. # this_uuid is used to track variables related to this inference thread
  205. this_uuid = str(uuid.uuid1())
  206. with self.lock:
  207. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
  208. self.hift_cache_dict[this_uuid] = None
  209. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  210. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  211. if stream is True:
  212. token_hop_len = self.token_min_hop_len
  213. while True:
  214. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  215. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  216. .unsqueeze(dim=0)
  217. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  218. prompt_token=flow_prompt_speech_token,
  219. prompt_feat=prompt_speech_feat,
  220. embedding=flow_embedding,
  221. uuid=this_uuid,
  222. finalize=False)
  223. yield {'tts_speech': this_tts_speech.cpu()}
  224. with self.lock:
  225. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  226. # increase token_hop_len for better speech quality
  227. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  228. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  229. break
  230. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  231. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  232. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  233. prompt_token=flow_prompt_speech_token,
  234. prompt_feat=prompt_speech_feat,
  235. embedding=flow_embedding,
  236. uuid=this_uuid,
  237. finalize=True)
  238. yield {'tts_speech': this_tts_speech.cpu()}
  239. else:
  240. # deal with all tokens
  241. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  242. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  243. prompt_token=flow_prompt_speech_token,
  244. prompt_feat=prompt_speech_feat,
  245. embedding=flow_embedding,
  246. uuid=this_uuid,
  247. finalize=True,
  248. speed=speed)
  249. yield {'tts_speech': this_tts_speech.cpu()}
  250. with self.lock:
  251. self.tts_speech_token_dict.pop(this_uuid)
  252. self.llm_end_dict.pop(this_uuid)
  253. self.mel_overlap_dict.pop(this_uuid)
  254. self.hift_cache_dict.pop(this_uuid)
  255. torch.cuda.empty_cache()
  256. class CosyVoice2Model(CosyVoiceModel):
  257. def __init__(self,
  258. llm: torch.nn.Module,
  259. flow: torch.nn.Module,
  260. hift: torch.nn.Module,
  261. fp16: bool):
  262. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  263. self.llm = llm
  264. self.flow = flow
  265. self.hift = hift
  266. self.fp16 = fp16
  267. self.llm.fp16 = fp16
  268. self.flow.fp16 = fp16
  269. if self.fp16 is True:
  270. self.llm.half()
  271. self.flow.half()
  272. self.token_hop_len = 2 * self.flow.input_frame_rate
  273. # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
  274. self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
  275. self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
  276. # hift cache
  277. self.mel_cache_len = 8
  278. self.source_cache_len = int(self.mel_cache_len * 480)
  279. # speech fade in out
  280. self.speech_window = np.hamming(2 * self.source_cache_len)
  281. # rtf and decoding related
  282. self.stream_scale_factor = 1
  283. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  284. self.lock = threading.Lock()
  285. # dict used to store session related variable
  286. self.tts_speech_token_dict = {}
  287. self.llm_end_dict = {}
  288. self.hift_cache_dict = {}
  289. def load_jit(self, flow_encoder_model):
  290. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  291. self.flow.encoder = flow_encoder
  292. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
  293. tts_mel, _ = self.flow.inference(token=token.to(self.device),
  294. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  295. prompt_token=prompt_token.to(self.device),
  296. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  297. prompt_feat=prompt_feat.to(self.device),
  298. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  299. embedding=embedding.to(self.device),
  300. finalize=finalize)
  301. tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
  302. # append hift cache
  303. if self.hift_cache_dict[uuid] is not None:
  304. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  305. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  306. else:
  307. hift_cache_source = torch.zeros(1, 1, 0)
  308. # keep overlap mel and hift cache
  309. if finalize is False:
  310. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  311. if self.hift_cache_dict[uuid] is not None:
  312. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  313. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  314. 'source': tts_source[:, :, -self.source_cache_len:],
  315. 'speech': tts_speech[:, -self.source_cache_len:]}
  316. tts_speech = tts_speech[:, :-self.source_cache_len]
  317. else:
  318. if speed != 1.0:
  319. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  320. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  321. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  322. if self.hift_cache_dict[uuid] is not None:
  323. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  324. return tts_speech
  325. def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  326. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  327. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  328. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  329. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
  330. # this_uuid is used to track variables related to this inference thread
  331. this_uuid = str(uuid.uuid1())
  332. with self.lock:
  333. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  334. self.hift_cache_dict[this_uuid] = None
  335. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  336. p.start()
  337. if stream is True:
  338. token_offset = 0
  339. while True:
  340. time.sleep(0.1)
  341. if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
  342. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
  343. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  344. prompt_token=flow_prompt_speech_token,
  345. prompt_feat=prompt_speech_feat,
  346. embedding=flow_embedding,
  347. uuid=this_uuid,
  348. token_offset=token_offset,
  349. finalize=False)
  350. token_offset += self.token_hop_len
  351. yield {'tts_speech': this_tts_speech.cpu()}
  352. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
  353. break
  354. p.join()
  355. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  356. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  357. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  358. prompt_token=flow_prompt_speech_token,
  359. prompt_feat=prompt_speech_feat,
  360. embedding=flow_embedding,
  361. uuid=this_uuid,
  362. token_offset=token_offset,
  363. finalize=True)
  364. yield {'tts_speech': this_tts_speech.cpu()}
  365. else:
  366. # deal with all tokens
  367. p.join()
  368. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  369. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  370. prompt_token=flow_prompt_speech_token,
  371. prompt_feat=prompt_speech_feat,
  372. embedding=flow_embedding,
  373. uuid=this_uuid,
  374. token_offset=0,
  375. finalize=True,
  376. speed=speed)
  377. yield {'tts_speech': this_tts_speech.cpu()}
  378. with self.lock:
  379. self.tts_speech_token_dict.pop(this_uuid)
  380. self.llm_end_dict.pop(this_uuid)
  381. torch.cuda.empty_cache()