model.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. from typing import Generator
  17. import queue
  18. import torch
  19. import numpy as np
  20. import threading
  21. import time
  22. from torch.nn import functional as F
  23. from contextlib import nullcontext
  24. import uuid
  25. from cosyvoice.utils.common import fade_in_out
  26. from cosyvoice.utils.file_utils import convert_onnx_to_trt
  27. from cosyvoice.utils.common import TrtContextWrapper
  28. class CosyVoiceModel:
  29. def __init__(self,
  30. llm: torch.nn.Module,
  31. flow: torch.nn.Module,
  32. hift: torch.nn.Module,
  33. fp16: bool = False,
  34. trt_concurrent: int = 1):
  35. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  36. self.llm = llm
  37. self.flow = flow
  38. self.hift = hift
  39. self.fp16 = fp16
  40. self.trt_concurrent = trt_concurrent
  41. if self.fp16 is True:
  42. self.llm.half()
  43. self.flow.half()
  44. self.token_min_hop_len = 2 * self.flow.input_frame_rate
  45. self.token_max_hop_len = 4 * self.flow.input_frame_rate
  46. self.token_overlap_len = 20
  47. # mel fade in out
  48. self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
  49. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  50. # hift cache
  51. self.mel_cache_len = 20
  52. self.source_cache_len = int(self.mel_cache_len * 256)
  53. # speech fade in out
  54. self.speech_window = np.hamming(2 * self.source_cache_len)
  55. # rtf and decoding related
  56. self.stream_scale_factor = 1
  57. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  58. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  59. self.lock = threading.Lock()
  60. # dict used to store session related variable
  61. self.tts_speech_token_dict = {}
  62. self.llm_end_dict = {}
  63. self.mel_overlap_dict = {}
  64. self.flow_cache_dict = {}
  65. self.hift_cache_dict = {}
  66. def load(self, llm_model, flow_model, hift_model):
  67. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
  68. self.llm.to(self.device).eval()
  69. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
  70. self.flow.to(self.device).eval()
  71. # in case hift_model is a hifigan model
  72. hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
  73. self.hift.load_state_dict(hift_state_dict, strict=True)
  74. self.hift.to(self.device).eval()
  75. def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
  76. llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
  77. self.llm.text_encoder = llm_text_encoder
  78. llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
  79. self.llm.llm = llm_llm
  80. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  81. self.flow.encoder = flow_encoder
  82. def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
  83. assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
  84. if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
  85. convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
  86. del self.flow.decoder.estimator
  87. import tensorrt as trt
  88. with open(flow_decoder_estimator_model, 'rb') as f:
  89. estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
  90. assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
  91. self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
  92. def get_trt_kwargs(self):
  93. min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
  94. opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
  95. max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
  96. input_names = ["x", "mask", "mu", "cond"]
  97. return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
  98. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  99. with self.llm_context, torch.cuda.amp.autocast(self.fp16):
  100. if isinstance(text, Generator):
  101. assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
  102. for i in self.llm.inference_bistream(text=text,
  103. prompt_text=prompt_text.to(self.device),
  104. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  105. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  106. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  107. embedding=llm_embedding.to(self.device)):
  108. self.tts_speech_token_dict[uuid].append(i)
  109. else:
  110. for i in self.llm.inference(text=text.to(self.device),
  111. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  112. prompt_text=prompt_text.to(self.device),
  113. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  114. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  115. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  116. embedding=llm_embedding.to(self.device)):
  117. self.tts_speech_token_dict[uuid].append(i)
  118. self.llm_end_dict[uuid] = True
  119. def vc_job(self, source_speech_token, uuid):
  120. self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
  121. self.llm_end_dict[uuid] = True
  122. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
  123. with torch.cuda.amp.autocast(self.fp16):
  124. tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
  125. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  126. prompt_token=prompt_token.to(self.device),
  127. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  128. prompt_feat=prompt_feat.to(self.device),
  129. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  130. embedding=embedding.to(self.device),
  131. flow_cache=self.flow_cache_dict[uuid])
  132. # mel overlap fade in out
  133. if self.mel_overlap_dict[uuid].shape[2] != 0:
  134. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  135. # append hift cache
  136. if self.hift_cache_dict[uuid] is not None:
  137. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  138. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  139. else:
  140. hift_cache_source = torch.zeros(1, 1, 0)
  141. # keep overlap mel and hift cache
  142. if finalize is False:
  143. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  144. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  145. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  146. if self.hift_cache_dict[uuid] is not None:
  147. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  148. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  149. 'source': tts_source[:, :, -self.source_cache_len:],
  150. 'speech': tts_speech[:, -self.source_cache_len:]}
  151. tts_speech = tts_speech[:, :-self.source_cache_len]
  152. else:
  153. if speed != 1.0:
  154. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  155. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  156. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  157. if self.hift_cache_dict[uuid] is not None:
  158. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  159. return tts_speech
  160. def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
  161. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  162. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  163. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  164. prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
  165. # this_uuid is used to track variables related to this inference thread
  166. this_uuid = str(uuid.uuid1())
  167. with self.lock:
  168. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  169. self.hift_cache_dict[this_uuid] = None
  170. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  171. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  172. if source_speech_token.shape[1] == 0:
  173. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  174. else:
  175. p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
  176. p.start()
  177. if stream is True:
  178. token_hop_len = self.token_min_hop_len
  179. while True:
  180. time.sleep(0.1)
  181. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  182. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  183. .unsqueeze(dim=0)
  184. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  185. prompt_token=flow_prompt_speech_token,
  186. prompt_feat=prompt_speech_feat,
  187. embedding=flow_embedding,
  188. uuid=this_uuid,
  189. finalize=False)
  190. yield {'tts_speech': this_tts_speech.cpu()}
  191. with self.lock:
  192. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  193. # increase token_hop_len for better speech quality
  194. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  195. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  196. break
  197. p.join()
  198. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  199. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  200. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  201. prompt_token=flow_prompt_speech_token,
  202. prompt_feat=prompt_speech_feat,
  203. embedding=flow_embedding,
  204. uuid=this_uuid,
  205. finalize=True)
  206. yield {'tts_speech': this_tts_speech.cpu()}
  207. else:
  208. # deal with all tokens
  209. p.join()
  210. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  211. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  212. prompt_token=flow_prompt_speech_token,
  213. prompt_feat=prompt_speech_feat,
  214. embedding=flow_embedding,
  215. uuid=this_uuid,
  216. finalize=True,
  217. speed=speed)
  218. yield {'tts_speech': this_tts_speech.cpu()}
  219. with self.lock:
  220. self.tts_speech_token_dict.pop(this_uuid)
  221. self.llm_end_dict.pop(this_uuid)
  222. self.mel_overlap_dict.pop(this_uuid)
  223. self.hift_cache_dict.pop(this_uuid)
  224. self.flow_cache_dict.pop(this_uuid)
  225. if torch.cuda.is_available():
  226. torch.cuda.empty_cache()
  227. torch.cuda.current_stream().synchronize()
  228. class CosyVoice2Model(CosyVoiceModel):
  229. def __init__(self,
  230. llm: torch.nn.Module,
  231. flow: torch.nn.Module,
  232. hift: torch.nn.Module,
  233. fp16: bool = False,
  234. trt_concurrent: int = 1):
  235. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  236. self.llm = llm
  237. self.flow = flow
  238. # NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
  239. self.flow.encoder.streaming = True
  240. self.flow.decoder.estimator.streaming = True
  241. self.hift = hift
  242. self.fp16 = fp16
  243. self.trt_concurrent = trt_concurrent
  244. if self.fp16 is True:
  245. self.llm.half()
  246. self.flow.half()
  247. # NOTE must matching training static_chunk_size
  248. self.token_hop_len = 25
  249. # hift cache
  250. self.mel_cache_len = 8
  251. self.source_cache_len = int(self.mel_cache_len * 480)
  252. # speech fade in out
  253. self.speech_window = np.hamming(2 * self.source_cache_len)
  254. # rtf and decoding related
  255. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  256. self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
  257. for _ in range(trt_concurrent):
  258. self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
  259. self.lock = threading.Lock()
  260. # dict used to store session related variable
  261. self.tts_speech_token_dict = {}
  262. self.llm_end_dict = {}
  263. self.hift_cache_dict = {}
  264. self.trt_context_dict = {}
  265. def load_jit(self, flow_encoder_model):
  266. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  267. self.flow.encoder = flow_encoder
  268. def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
  269. with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
  270. tts_mel, _ = self.flow.inference(token=token.to(self.device),
  271. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  272. prompt_token=prompt_token.to(self.device),
  273. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  274. prompt_feat=prompt_feat.to(self.device),
  275. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  276. embedding=embedding.to(self.device),
  277. finalize=finalize)
  278. tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
  279. # append hift cache
  280. if self.hift_cache_dict[uuid] is not None:
  281. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  282. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  283. else:
  284. hift_cache_source = torch.zeros(1, 1, 0)
  285. # keep overlap mel and hift cache
  286. if finalize is False:
  287. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  288. if self.hift_cache_dict[uuid] is not None:
  289. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  290. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  291. 'source': tts_source[:, :, -self.source_cache_len:],
  292. 'speech': tts_speech[:, -self.source_cache_len:]}
  293. tts_speech = tts_speech[:, :-self.source_cache_len]
  294. else:
  295. if speed != 1.0:
  296. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  297. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  298. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  299. if self.hift_cache_dict[uuid] is not None:
  300. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  301. return tts_speech
  302. def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
  303. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  304. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  305. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  306. prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
  307. # this_uuid is used to track variables related to this inference thread
  308. this_uuid = str(uuid.uuid1())
  309. with self.lock:
  310. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  311. self.hift_cache_dict[this_uuid] = None
  312. self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
  313. if source_speech_token.shape[1] == 0:
  314. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  315. else:
  316. p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
  317. p.start()
  318. if stream is True:
  319. token_offset = 0
  320. prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
  321. while True:
  322. time.sleep(0.1)
  323. this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
  324. if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
  325. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
  326. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  327. prompt_token=flow_prompt_speech_token,
  328. prompt_feat=prompt_speech_feat,
  329. embedding=flow_embedding,
  330. token_offset=token_offset,
  331. uuid=this_uuid,
  332. finalize=False)
  333. token_offset += this_token_hop_len
  334. yield {'tts_speech': this_tts_speech.cpu()}
  335. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
  336. break
  337. p.join()
  338. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  339. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  340. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  341. prompt_token=flow_prompt_speech_token,
  342. prompt_feat=prompt_speech_feat,
  343. embedding=flow_embedding,
  344. token_offset=token_offset,
  345. uuid=this_uuid,
  346. finalize=True)
  347. yield {'tts_speech': this_tts_speech.cpu()}
  348. else:
  349. # deal with all tokens
  350. p.join()
  351. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  352. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  353. prompt_token=flow_prompt_speech_token,
  354. prompt_feat=prompt_speech_feat,
  355. embedding=flow_embedding,
  356. token_offset=0,
  357. uuid=this_uuid,
  358. finalize=True,
  359. speed=speed)
  360. yield {'tts_speech': this_tts_speech.cpu()}
  361. with self.lock:
  362. self.tts_speech_token_dict.pop(this_uuid)
  363. self.llm_end_dict.pop(this_uuid)
  364. self.hift_cache_dict.pop(this_uuid)
  365. self.trt_context_pool.put(self.trt_context_dict[this_uuid])
  366. self.trt_context_dict.pop(this_uuid)
  367. if torch.cuda.is_available():
  368. torch.cuda.empty_cache()
  369. torch.cuda.current_stream().synchronize()