# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Generator
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
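

# The three model classes below share the same pipeline: an autoregressive LLM
# produces speech tokens in a background thread, a flow model turns those tokens
# into mel spectrograms, and a HiFT vocoder turns the mel into waveform. They
# differ mainly in how streaming chunks are scheduled and cached.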


class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # mel fade in out
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session-related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}
        self.silent_tokens = []
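        # The hop lengths above are defined in tokens: token_min_hop_len covers
        # 2 s of audio and token_max_hop_len 4 s, whatever the token rate is.
        # A worked example of the cache sizes, assuming (for illustration only)
        # input_frame_rate = 50 tokens/s with the 22.05 kHz / hop-256 mel above:
        #   token_overlap_len = 20 tokens -> 20 / 50 = 0.4 s of overlap audio
        #   mel_overlap_len   = int(0.4 * 22050 / 256) = 34 mel frames
        #   source_cache_len  = 20 * 256 = 5120 waveform samples cached for HiFT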

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
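
    # The TensorRT profile above fixes the first two axes ((2, 80) for x/mu/cond,
    # (2, 1) for mask) and lets the time axis range from 4 to 3000 frames, with
    # 500 as the shape the engine is optimized for; convert_onnx_to_trt consumes
    # these kwargs when the engine is (re)built in load_trt.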

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        cur_silent_token_num, max_silent_token_num = 0, 5
        with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
            if isinstance(text, Generator):
                assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), \
                    'streaming input text is only implemented for CosyVoice2/3 and does not support vllm!'
                token_generator = self.llm.inference_bistream(text=text,
                                                              prompt_text=prompt_text.to(self.device),
                                                              prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                              prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                              embedding=llm_embedding.to(self.device))
            else:
                token_generator = self.llm.inference(text=text.to(self.device),
                                                     text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_text=prompt_text.to(self.device),
                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                     embedding=llm_embedding.to(self.device),
                                                     uuid=uuid)
            for i in token_generator:
                # drop silent tokens once more than max_silent_token_num appear in a row
                if i in self.silent_tokens:
                    cur_silent_token_num += 1
                    if cur_silent_token_num > max_silent_token_num:
                        continue
                else:
                    cur_silent_token_num = 0
                self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def vc_job(self, source_speech_token, uuid):
        self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
        self.llm_end_dict[uuid] = True
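
    # token2wav below turns a chunk of speech tokens into waveform. In streaming
    # mode the last mel_overlap_len mel frames are withheld and cross-faded into
    # the next chunk with mel_window, while the HiFT vocoder keeps mel_cache_len
    # frames plus its source signal (cache_source) so consecutive chunks stay
    # continuous; the final source_cache_len waveform samples are likewise
    # withheld and cross-faded with speech_window.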
    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      flow_cache=self.flow_cache_dict[uuid])
        # mel overlap fade in out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
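
    # tts() below is a generator that yields audio chunks as they become ready.
    # A minimal, hypothetical sketch of driving it directly with pre-extracted
    # tensors (variable names and checkpoint paths are placeholders, shapes
    # follow the defaults of the signature below):
    #
    #   model = CosyVoiceModel(llm, flow, hift, fp16=False)
    #   model.load('llm.pt', 'flow.pt', 'hift.pt')
    #   for chunk in model.tts(text=text_tokens,                        # (1, T_text) int32
    #                          flow_embedding=spk_embedding,            # (1, 192) speaker embedding
    #                          llm_embedding=spk_embedding,
    #                          prompt_text=prompt_text_tokens,          # (1, T_prompt_text) int32
    #                          llm_prompt_speech_token=prompt_speech_tokens,
    #                          flow_prompt_speech_token=prompt_speech_tokens,
    #                          prompt_speech_feat=prompt_mel,           # (1, T_mel, 80)
    #                          stream=True):
    #       audio = chunk['tts_speech']                                 # (1, num_samples), on cpu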
    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        else:
            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with remaining tokens, make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.current_stream().synchronize()


class CosyVoice2Model(CosyVoiceModel):

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        # NOTE must match training static_chunk_size
        self.token_hop_len = 25
        # hift cache
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session-related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.hift_cache_dict = {}
        self.silent_tokens = []
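        # Unlike CosyVoiceModel, no mel-overlap or flow-cache state is kept per
        # session: the CosyVoice2 flow decoder is re-run on the full token prefix
        # each chunk (see token2wav below), so only the HiFT cache is needed.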

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_vllm(self, model_dir):
        export_cosyvoice2_vllm(self.llm, model_dir, self.device)
        from vllm import EngineArgs, LLMEngine
        engine_args = EngineArgs(model=model_dir,
                                 skip_tokenizer_init=True,
                                 enable_prompt_embeds=True,
                                 gpu_memory_utilization=0.2)
        self.llm.vllm = LLMEngine.from_engine_args(engine_args)
        self.llm.lock = threading.Lock()
        del self.llm.llm.model.model.layers
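
    # In this variant token2wav receives the full token prefix on every call;
    # token_offset counts the tokens that were already vocoded, and the first
    # token_offset * self.flow.token_mel_ratio mel frames are sliced off so only
    # the newly generated frames reach the HiFT vocoder.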
    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_token=prompt_token.to(self.device),
                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_feat=prompt_feat.to(self.device),
                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=embedding.to(self.device),
                                             streaming=stream,
                                             finalize=finalize)
        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
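
    # Streaming schedule for tts() below: the first chunk is enlarged by
    # prompt_token_pad so that prompt-plus-chunk boundaries land on multiples of
    # token_hop_len (the training static_chunk_size), and each chunk additionally
    # waits for self.flow.pre_lookahead_len lookahead tokens before being decoded.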
    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        else:
            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
        p.start()
        if stream is True:
            token_offset = 0
            prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
            while True:
                time.sleep(0.1)
                this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     token_offset=token_offset,
                                                     uuid=this_uuid,
                                                     stream=stream,
                                                     finalize=False)
                    token_offset += this_token_hop_len
                    yield {'tts_speech': this_tts_speech.cpu()}
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                    break
            p.join()
            # deal with remaining tokens, make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             token_offset=token_offset,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             token_offset=0,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.current_stream().synchronize()


class CosyVoice3Model(CosyVoice2Model):

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        # NOTE must match training static_chunk_size
        self.token_hop_len = 25
        # rtf and decoding related
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session-related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.hift_cache_dict = {}
        # FSQ silent token
        self.silent_tokens = [28, 29]
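
    # CosyVoice3 keeps the entire accumulated mel in hift_cache_dict and re-runs
    # HiFT on it every chunk, using 'speech_offset' to return only the samples
    # that have not been yielded yet; no waveform cross-fade is applied.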
    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_token=prompt_token.to(self.device),
                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_feat=prompt_feat.to(self.device),
                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=embedding.to(self.device),
                                             streaming=stream,
                                             finalize=finalize)
        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
        # append mel cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel = self.hift_cache_dict[uuid]['mel']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
            self.hift_cache_dict[uuid]['mel'] = tts_mel
        else:
            self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
        if speed != 1.0:
            assert token_offset == 0 and finalize is True, 'speed change only supports non-stream inference mode'
            tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
        tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
        tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
        self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
        return tts_speech