# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Generator
import queue
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out, TrtContextWrapper
from cosyvoice.utils.file_utils import convert_onnx_to_trt
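
# Synthesis pipeline: the LLM maps text to discrete speech tokens, the flow
# model maps tokens to mel spectrograms, and the HiFT vocoder maps mel to
# waveform. Streaming decodes in chunks and crossfades them (fade_in_out).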


class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # mel fade in out
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}
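        # NOTE all of the *_dict members above are keyed by a per-request uuid,
        # so concurrent tts() calls running on different threads never share
        # mutable state (see the lock usage in tts()).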

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder
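
    # load_trt() below replaces the flow decoder's ODE estimator with a
    # TensorRT engine, building it from the ONNX export first if the plan
    # file does not exist yet.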

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model):
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        if os.path.getsize(flow_decoder_estimator_model) == 0:
            raise ValueError('{} is an empty file, delete it and export again!'.format(flow_decoder_estimator_model))
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        if isinstance(self, CosyVoice2Model):
            self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
        else:
            self.flow.decoder.estimator = estimator_engine.create_execution_context()

    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
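
    # llm_job() below runs on a producer thread started by tts(): it streams
    # speech tokens into tts_speech_token_dict[uuid] and marks end-of-stream
    # via llm_end_dict[uuid]. vc_job() is the voice-conversion counterpart
    # that copies precomputed source speech tokens instead.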

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
            if isinstance(text, Generator):
                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                for i in self.llm.inference_bistream(text=text,
                                                     prompt_text=prompt_text.to(self.device),
                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                     embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
            else:
                for i in self.llm.inference(text=text.to(self.device),
                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_text=prompt_text.to(self.device),
                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                            embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def vc_job(self, source_speech_token, uuid):
        self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
        self.llm_end_dict[uuid] = True
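
    # token2wav() below turns a chunk of speech tokens into waveform. With
    # finalize=False it keeps mel/source tails in hift_cache_dict and
    # crossfades against the previous chunk so chunk seams stay inaudible.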

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      flow_cache=self.flow_cache_dict[uuid])
            # mel overlap fade in out
            if self.mel_overlap_dict[uuid].shape[2] != 0:
                tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
            # append hift cache
            if self.hift_cache_dict[uuid] is not None:
                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
            else:
                hift_cache_source = torch.zeros(1, 1, 0)
            # keep overlap mel and hift cache
            if finalize is False:
                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
                if self.hift_cache_dict[uuid] is not None:
                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
                self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                              'source': tts_source[:, :, -self.source_cache_len:],
                                              'speech': tts_speech[:, -self.source_cache_len:]}
                tts_speech = tts_speech[:, :-self.source_cache_len]
            else:
                if speed != 1.0:
                    assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                    tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
                if self.hift_cache_dict[uuid] is not None:
                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
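
    # tts() below is a generator: a producer thread (llm_job or vc_job) fills
    # the shared token buffer while this thread drains it. With stream=True it
    # yields audio chunk by chunk, growing token_hop_len up to token_max_hop_len.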

    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        else:
            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.current_stream().synchronize()
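

# A minimal usage sketch (hypothetical checkpoint names and placeholder
# handle(); a text frontend, not shown here, must produce the token and
# embedding tensors):
#
#   model = CosyVoiceModel(llm, flow, hift, fp16=True)
#   model.load('llm.pt', 'flow.pt', 'hift.pt')
#   for chunk in model.tts(text=text_tokens, flow_embedding=spk_embedding,
#                          llm_embedding=spk_embedding, stream=True):
#       handle(chunk['tts_speech'])  # waveform tensor for this chunk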


class CosyVoice2Model(CosyVoiceModel):

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False,
                 use_flow_cache: bool = False,
                 trt_concurrent: int = 1):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.use_flow_cache = use_flow_cache
        self.trt_concurrent = trt_concurrent
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
        self.token_hop_len = 25
        self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
        # hift cache
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        for _ in range(trt_concurrent):
            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
        self.lock = threading.Lock()
        # dicts used to store session related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}
        self.trt_context_dict = {}
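
    # init_flow_cache() below builds fresh per-request caches for the streaming
    # flow encoder/decoder: conv states for the lookahead/upsample layers plus
    # kv caches sized by flow_decoder_required_cache_size.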

    def init_flow_cache(self):
        encoder_cache = {'offset': 0,
                         'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
                         'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
                         'upsample_offset': 0,
                         'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
                         'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
        decoder_cache = {'offset': 0,
                         'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
                         'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
                         'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
                         'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
        if self.fp16 is True:
            for cache in [encoder_cache, decoder_cache]:
                for k, v in cache.items():
                    if isinstance(v, torch.Tensor):
                        cache[k] = v.half()
        cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
        return cache

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
        max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
        input_names = ["x", "mask", "mu", "cond", "down_blocks_kv_cache", "mid_blocks_kv_cache", "up_blocks_kv_cache"]
        assert self.use_flow_cache is True, 'get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, set a higher max_shape'
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
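
    # Unlike the base class, token2wav() below carries chunk-to-chunk context
    # in the flow decoder's own caches (see init_flow_cache), so there is no
    # mel-overlap crossfade here; only the HiFT cache fade remains.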

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      cache=self.flow_cache_dict[uuid],
                                                                      finalize=finalize)
            # append hift cache
            if self.hift_cache_dict[uuid] is not None:
                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
            else:
                hift_cache_source = torch.zeros(1, 1, 0)
            # keep overlap mel and hift cache
            if finalize is False:
                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
                if self.hift_cache_dict[uuid] is not None:
                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
                self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                              'source': tts_source[:, :, -self.source_cache_len:],
                                              'speech': tts_speech[:, -self.source_cache_len:]}
                tts_speech = tts_speech[:, :-self.source_cache_len]
            else:
                if speed != 1.0:
                    assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                    tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
                if self.hift_cache_dict[uuid] is not None:
                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.flow_cache_dict[this_uuid] = self.init_flow_cache()
            self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        else:
            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
        p.start()
        if stream is True:
            assert self.use_flow_cache is True, 'set use_flow_cache=True if you want to use stream inference to avoid OOM'
            # NOTE in cache mode, trim flow_prompt to the same size as flow_decoder_required_cache_size
            flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in the first chunk
                    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
                    prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
                    break
            p.join()
            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            assert self.use_flow_cache is False, 'set use_flow_cache=False for non-stream inference'
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
            self.trt_context_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.current_stream().synchronize()
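

# A minimal streaming sketch for CosyVoice2Model (hypothetical checkpoint
# names and placeholder handle(); note the asserts in tts(): stream=True
# requires use_flow_cache=True, stream=False requires use_flow_cache=False):
#
#   model = CosyVoice2Model(llm, flow, hift, fp16=True, use_flow_cache=True)
#   model.load('llm.pt', 'flow.pt', 'hift.pt')
#   for chunk in model.tts(text=text_tokens, flow_embedding=spk_embedding,
#                          llm_embedding=spk_embedding, stream=True):
#       handle(chunk['tts_speech'])  # placeholder consumer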