model.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. from typing import Generator
  17. import torch
  18. import numpy as np
  19. import threading
  20. import time
  21. from torch.nn import functional as F
  22. from contextlib import nullcontext
  23. import uuid
  24. from cosyvoice.utils.common import fade_in_out
  25. from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
  26. from cosyvoice.utils.common import TrtContextWrapper
  27. class CosyVoiceModel:
  28. def __init__(self,
  29. llm: torch.nn.Module,
  30. flow: torch.nn.Module,
  31. hift: torch.nn.Module,
  32. fp16: bool = False):
  33. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  34. self.llm = llm
  35. self.flow = flow
  36. self.hift = hift
  37. self.fp16 = fp16
  38. if self.fp16 is True:
  39. self.llm.half()
  40. self.flow.half()
  41. self.token_min_hop_len = 2 * self.flow.input_frame_rate
  42. self.token_max_hop_len = 4 * self.flow.input_frame_rate
  43. self.token_overlap_len = 20
  44. # mel fade in out
  45. self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
  46. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  47. # hift cache
  48. self.mel_cache_len = 20
  49. self.source_cache_len = int(self.mel_cache_len * 256)
  50. # speech fade in out
  51. self.speech_window = np.hamming(2 * self.source_cache_len)
  52. # rtf and decoding related
  53. self.stream_scale_factor = 1
  54. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  55. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  56. self.lock = threading.Lock()
  57. # dict used to store session related variable
  58. self.tts_speech_token_dict = {}
  59. self.llm_end_dict = {}
  60. self.mel_overlap_dict = {}
  61. self.flow_cache_dict = {}
  62. self.hift_cache_dict = {}
  63. def load(self, llm_model, flow_model, hift_model):
  64. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
  65. self.llm.to(self.device).eval()
  66. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
  67. self.flow.to(self.device).eval()
  68. # in case hift_model is a hifigan model
  69. hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
  70. self.hift.load_state_dict(hift_state_dict, strict=True)
  71. self.hift.to(self.device).eval()
  72. def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
  73. llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
  74. self.llm.text_encoder = llm_text_encoder
  75. llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
  76. self.llm.llm = llm_llm
  77. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  78. self.flow.encoder = flow_encoder
  79. def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
  80. assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
  81. if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
  82. convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
  83. del self.flow.decoder.estimator
  84. import tensorrt as trt
  85. with open(flow_decoder_estimator_model, 'rb') as f:
  86. estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
  87. assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
  88. self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
  89. def get_trt_kwargs(self):
  90. min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
  91. opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
  92. max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
  93. input_names = ["x", "mask", "mu", "cond"]
  94. return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
  95. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  96. with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
  97. if isinstance(text, Generator):
  98. assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
  99. for i in self.llm.inference_bistream(text=text,
  100. prompt_text=prompt_text.to(self.device),
  101. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  102. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  103. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  104. embedding=llm_embedding.to(self.device)):
  105. self.tts_speech_token_dict[uuid].append(i)
  106. else:
  107. for i in self.llm.inference(text=text.to(self.device),
  108. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  109. prompt_text=prompt_text.to(self.device),
  110. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  111. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  112. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  113. embedding=llm_embedding.to(self.device),
  114. uuid=uuid):
  115. self.tts_speech_token_dict[uuid].append(i)
  116. self.llm_end_dict[uuid] = True
  117. def vc_job(self, source_speech_token, uuid):
  118. self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
  119. self.llm_end_dict[uuid] = True
  120. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
  121. with torch.cuda.amp.autocast(self.fp16):
  122. tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
  123. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  124. prompt_token=prompt_token.to(self.device),
  125. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  126. prompt_feat=prompt_feat.to(self.device),
  127. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  128. embedding=embedding.to(self.device),
  129. flow_cache=self.flow_cache_dict[uuid])
  130. # mel overlap fade in out
  131. if self.mel_overlap_dict[uuid].shape[2] != 0:
  132. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  133. # append hift cache
  134. if self.hift_cache_dict[uuid] is not None:
  135. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  136. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  137. else:
  138. hift_cache_source = torch.zeros(1, 1, 0)
  139. # keep overlap mel and hift cache
  140. if finalize is False:
  141. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  142. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  143. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  144. if self.hift_cache_dict[uuid] is not None:
  145. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  146. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  147. 'source': tts_source[:, :, -self.source_cache_len:],
  148. 'speech': tts_speech[:, -self.source_cache_len:]}
  149. tts_speech = tts_speech[:, :-self.source_cache_len]
  150. else:
  151. if speed != 1.0:
  152. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  153. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  154. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  155. if self.hift_cache_dict[uuid] is not None:
  156. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  157. return tts_speech
  158. def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
  159. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  160. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  161. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  162. prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
  163. # this_uuid is used to track variables related to this inference thread
  164. this_uuid = str(uuid.uuid1())
  165. with self.lock:
  166. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  167. self.hift_cache_dict[this_uuid] = None
  168. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  169. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  170. if source_speech_token.shape[1] == 0:
  171. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  172. else:
  173. p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
  174. p.start()
  175. if stream is True:
  176. token_hop_len = self.token_min_hop_len
  177. while True:
  178. time.sleep(0.1)
  179. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  180. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  181. .unsqueeze(dim=0)
  182. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  183. prompt_token=flow_prompt_speech_token,
  184. prompt_feat=prompt_speech_feat,
  185. embedding=flow_embedding,
  186. uuid=this_uuid,
  187. finalize=False)
  188. yield {'tts_speech': this_tts_speech.cpu()}
  189. with self.lock:
  190. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  191. # increase token_hop_len for better speech quality
  192. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  193. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  194. break
  195. p.join()
  196. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  197. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  198. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  199. prompt_token=flow_prompt_speech_token,
  200. prompt_feat=prompt_speech_feat,
  201. embedding=flow_embedding,
  202. uuid=this_uuid,
  203. finalize=True)
  204. yield {'tts_speech': this_tts_speech.cpu()}
  205. else:
  206. # deal with all tokens
  207. p.join()
  208. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  209. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  210. prompt_token=flow_prompt_speech_token,
  211. prompt_feat=prompt_speech_feat,
  212. embedding=flow_embedding,
  213. uuid=this_uuid,
  214. finalize=True,
  215. speed=speed)
  216. yield {'tts_speech': this_tts_speech.cpu()}
  217. with self.lock:
  218. self.tts_speech_token_dict.pop(this_uuid)
  219. self.llm_end_dict.pop(this_uuid)
  220. self.mel_overlap_dict.pop(this_uuid)
  221. self.hift_cache_dict.pop(this_uuid)
  222. self.flow_cache_dict.pop(this_uuid)
  223. if torch.cuda.is_available():
  224. torch.cuda.empty_cache()
  225. torch.cuda.current_stream().synchronize()
  226. class CosyVoice2Model(CosyVoiceModel):
  227. def __init__(self,
  228. llm: torch.nn.Module,
  229. flow: torch.nn.Module,
  230. hift: torch.nn.Module,
  231. fp16: bool = False):
  232. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  233. self.llm = llm
  234. self.flow = flow
  235. self.hift = hift
  236. self.fp16 = fp16
  237. if self.fp16 is True:
  238. self.llm.half()
  239. self.flow.half()
  240. # NOTE must matching training static_chunk_size
  241. self.token_hop_len = 25
  242. # hift cache
  243. self.mel_cache_len = 8
  244. self.source_cache_len = int(self.mel_cache_len * 480)
  245. # speech fade in out
  246. self.speech_window = np.hamming(2 * self.source_cache_len)
  247. # rtf and decoding related
  248. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  249. self.lock = threading.Lock()
  250. # dict used to store session related variable
  251. self.tts_speech_token_dict = {}
  252. self.llm_end_dict = {}
  253. self.hift_cache_dict = {}
  254. def load_jit(self, flow_encoder_model):
  255. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  256. self.flow.encoder = flow_encoder
  257. def load_vllm(self, model_dir):
  258. export_cosyvoice2_vllm(self.llm, model_dir, self.device)
  259. from vllm import EngineArgs, LLMEngine
  260. engine_args = EngineArgs(model=model_dir,
  261. skip_tokenizer_init=True,
  262. enable_prompt_embeds=True,
  263. gpu_memory_utilization=0.2)
  264. self.llm.vllm = LLMEngine.from_engine_args(engine_args)
  265. self.llm.lock = threading.Lock()
  266. del self.llm.llm.model.model.layers
  267. def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
  268. with torch.cuda.amp.autocast(self.fp16):
  269. tts_mel, _ = self.flow.inference(token=token.to(self.device),
  270. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  271. prompt_token=prompt_token.to(self.device),
  272. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  273. prompt_feat=prompt_feat.to(self.device),
  274. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  275. embedding=embedding.to(self.device),
  276. streaming=stream,
  277. finalize=finalize)
  278. tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
  279. # append hift cache
  280. if self.hift_cache_dict[uuid] is not None:
  281. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  282. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  283. else:
  284. hift_cache_source = torch.zeros(1, 1, 0)
  285. # keep overlap mel and hift cache
  286. if finalize is False:
  287. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  288. if self.hift_cache_dict[uuid] is not None:
  289. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  290. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  291. 'source': tts_source[:, :, -self.source_cache_len:],
  292. 'speech': tts_speech[:, -self.source_cache_len:]}
  293. tts_speech = tts_speech[:, :-self.source_cache_len]
  294. else:
  295. if speed != 1.0:
  296. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  297. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  298. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  299. if self.hift_cache_dict[uuid] is not None:
  300. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  301. return tts_speech
  302. def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
  303. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  304. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  305. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  306. prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
  307. # this_uuid is used to track variables related to this inference thread
  308. this_uuid = str(uuid.uuid1())
  309. with self.lock:
  310. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  311. self.hift_cache_dict[this_uuid] = None
  312. if source_speech_token.shape[1] == 0:
  313. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  314. else:
  315. p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
  316. p.start()
  317. if stream is True:
  318. token_offset = 0
  319. prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
  320. while True:
  321. time.sleep(0.1)
  322. this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
  323. if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
  324. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
  325. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  326. prompt_token=flow_prompt_speech_token,
  327. prompt_feat=prompt_speech_feat,
  328. embedding=flow_embedding,
  329. token_offset=token_offset,
  330. uuid=this_uuid,
  331. stream=stream,
  332. finalize=False)
  333. token_offset += this_token_hop_len
  334. yield {'tts_speech': this_tts_speech.cpu()}
  335. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
  336. break
  337. p.join()
  338. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  339. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  340. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  341. prompt_token=flow_prompt_speech_token,
  342. prompt_feat=prompt_speech_feat,
  343. embedding=flow_embedding,
  344. token_offset=token_offset,
  345. uuid=this_uuid,
  346. finalize=True)
  347. yield {'tts_speech': this_tts_speech.cpu()}
  348. else:
  349. # deal with all tokens
  350. p.join()
  351. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  352. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  353. prompt_token=flow_prompt_speech_token,
  354. prompt_feat=prompt_speech_feat,
  355. embedding=flow_embedding,
  356. token_offset=0,
  357. uuid=this_uuid,
  358. finalize=True,
  359. speed=speed)
  360. yield {'tts_speech': this_tts_speech.cpu()}
  361. with self.lock:
  362. self.tts_speech_token_dict.pop(this_uuid)
  363. self.llm_end_dict.pop(this_uuid)
  364. self.hift_cache_dict.pop(this_uuid)
  365. if torch.cuda.is_available():
  366. torch.cuda.empty_cache()
  367. torch.cuda.current_stream().synchronize()