# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Generator
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt


class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.llm.fp16 = fp16
        self.flow.fp16 = fp16
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # mel fade in out
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session-related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}
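
    # Load the trained checkpoints for the three sub-modules and put them in
    # eval mode on the target device.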
    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()
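
    # Swap the eager text encoder, LLM and flow encoder for TorchScript-compiled
    # versions exported ahead of time.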
    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder
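
    # Replace the flow decoder estimator with a TensorRT engine, exporting it
    # from ONNX first if no serialized engine exists on disk.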
    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model):
            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
        if os.path.getsize(flow_decoder_estimator_model) == 0:
            raise ValueError('{} is an empty file, delete it and export again!'.format(flow_decoder_estimator_model))
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        if self.flow.decoder.estimator_engine is None:
            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
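
    # Run LLM token generation in a background thread: generated speech tokens
    # are appended to this session's list, and llm_end_dict is flagged once the
    # generator is exhausted.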
    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context:
            if isinstance(text, Generator):
                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                for i in self.llm.inference_bistream(text=text,
                                                     prompt_text=prompt_text.to(self.device),
                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                     embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
            else:
                for i in self.llm.inference(text=text.to(self.device),
                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_text=prompt_text.to(self.device),
                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                            embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True
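
    # Convert speech tokens to a waveform: the flow model predicts the mel
    # spectrogram and hift vocodes it, with overlap fade-in/out plus mel/source
    # caches so that consecutive streaming chunks stay continuous.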
    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                  prompt_token=prompt_token.to(self.device),
                                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                  prompt_feat=prompt_feat.to(self.device),
                                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                  embedding=embedding.to(self.device),
                                                                  flow_cache=self.flow_cache_dict[uuid])
        # mel overlap fade in out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
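
    # Main entry point: kicks off llm_job in a thread, then either streams
    # chunks as soon as enough tokens are buffered (stream=True) or synthesizes
    # the whole utterance in one pass. Yields dicts of the form
    # {'tts_speech': tensor}.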
    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with remaining tokens, make sure the remaining token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        torch.cuda.empty_cache()
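
    # Voice conversion: the source speech tokens are used directly (no LLM run),
    # so llm_end_dict is set to True from the start; the rest mirrors tts().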
    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            # deal with remaining tokens, make sure the remaining token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        torch.cuda.empty_cache()
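

# CosyVoice2 replaces mel-overlap smoothing with a chunk-aware flow decoder
# that carries explicit conv/kv caches between chunks, and streams with a fixed
# token hop size plus the flow encoder's pre-lookahead window.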
class CosyVoice2Model(CosyVoiceModel):

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.llm.fp16 = fp16
        self.flow.fp16 = fp16
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_hop_len = 2 * self.flow.input_frame_rate
        # flow decoder required_cache_size
        self.flow_decoder_required_cache_size = self.flow.decoder.estimator.num_decoding_left_chunks * self.flow.input_frame_rate * self.flow.token_mel_ratio
        # hift cache
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session-related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}
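
    # Fresh per-session caches for the streaming flow model: conv and kv caches
    # for the encoder, upsampler, and the decoder's down/mid/up/final blocks.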
    def init_flow_cache(self):
        encoder_cache = {'offset': 0,
                         'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
                         'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
                         'upsample_offset': 0,
                         'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
                         'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
        decoder_cache = {'offset': 0,
                         'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
                         'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, 0, 512, 2).to(self.device),
                         'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
                         'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
        cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
        return cache
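
    # Keep only the most recent flow_decoder_required_cache_size frames of the
    # decoder kv caches so memory stays bounded over long sessions.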
    def trim_flow_cache(self, cache):
        if cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size:
            cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
            cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
            cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
        return cache

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder
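
    # Same role as CosyVoiceModel.token2wav, but chunk continuity comes from the
    # flow decoder caches (trimmed after each call) instead of mel overlap fades.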
    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                  prompt_token=prompt_token.to(self.device),
                                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                  prompt_feat=prompt_feat.to(self.device),
                                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                  embedding=embedding.to(self.device),
                                                                  cache=self.flow_cache_dict[uuid],
                                                                  finalize=finalize)
        self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid])
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
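
    # Streaming here uses a fixed token_hop_len plus the encoder's pre-lookahead
    # window; prompt token/feat are only passed with the first chunk, since
    # later chunks inherit that context through the flow caches.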
    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.flow_cache_dict[this_uuid] = self.init_flow_cache()
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in the first chunk
                    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
                    prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
                    break
            p.join()
            # deal with remaining tokens, make sure the remaining token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        torch.cuda.empty_cache()
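

# A minimal usage sketch (illustrative only, not part of the upstream file).
# It assumes `llm`, `flow` and `hift` have already been built from the YAML
# config by the surrounding CLI code, and that the checkpoint paths are valid;
# `text_tokens`, `spk_embedding` and `play_or_buffer` are hypothetical names:
#
#     model = CosyVoiceModel(llm, flow, hift, fp16=False)
#     model.load('llm.pt', 'flow.pt', 'hift.pt')
#     for chunk in model.tts(text=text_tokens,
#                            flow_embedding=spk_embedding,
#                            stream=True):
#         play_or_buffer(chunk['tts_speech'])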