model.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from typing import Generator
  16. import torch
  17. import numpy as np
  18. import threading
  19. import time
  20. from torch.nn import functional as F
  21. from contextlib import nullcontext
  22. import uuid
  23. from cosyvoice.utils.common import fade_in_out
  24. from cosyvoice.utils.file_utils import convert_onnx_to_trt
  25. class CosyVoiceModel:
  26. def __init__(self,
  27. llm: torch.nn.Module,
  28. flow: torch.nn.Module,
  29. hift: torch.nn.Module,
  30. fp16: bool):
  31. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  32. self.llm = llm
  33. self.flow = flow
  34. self.hift = hift
  35. self.fp16 = fp16
  36. self.llm.fp16 = fp16
  37. self.flow.fp16 = fp16
  38. if self.fp16 is True:
  39. self.llm.half()
  40. self.flow.half()
  41. self.token_min_hop_len = 2 * self.flow.input_frame_rate
  42. self.token_max_hop_len = 4 * self.flow.input_frame_rate
  43. self.token_overlap_len = 20
  44. # here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
  45. self.flow.decoder.estimator.static_chunk_size = 0
  46. # mel fade in out
  47. self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
  48. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  49. # hift cache
  50. self.mel_cache_len = 20
  51. self.source_cache_len = int(self.mel_cache_len * 256)
  52. # speech fade in out
  53. self.speech_window = np.hamming(2 * self.source_cache_len)
  54. # rtf and decoding related
  55. self.stream_scale_factor = 1
  56. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  57. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  58. self.lock = threading.Lock()
  59. # dict used to store session related variable
  60. self.tts_speech_token_dict = {}
  61. self.llm_end_dict = {}
  62. self.mel_overlap_dict = {}
  63. self.flow_cache_dict = {}
  64. self.hift_cache_dict = {}
  65. def load(self, llm_model, flow_model, hift_model):
  66. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
  67. self.llm.to(self.device).eval()
  68. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
  69. self.flow.to(self.device).eval()
  70. # in case hift_model is a hifigan model
  71. hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
  72. self.hift.load_state_dict(hift_state_dict, strict=True)
  73. self.hift.to(self.device).eval()
  74. def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
  75. llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
  76. self.llm.text_encoder = llm_text_encoder
  77. llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
  78. self.llm.llm = llm_llm
  79. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  80. self.flow.encoder = flow_encoder
  81. def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
  82. assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
  83. if not os.path.exists(flow_decoder_estimator_model):
  84. convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
  85. if os.path.getsize(flow_decoder_estimator_model) == 0:
  86. raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
  87. del self.flow.decoder.estimator
  88. import tensorrt as trt
  89. with open(flow_decoder_estimator_model, 'rb') as f:
  90. self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
  91. if self.flow.decoder.estimator_engine is None:
  92. raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
  93. self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
  94. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  95. with self.llm_context:
  96. if isinstance(text, Generator):
  97. assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
  98. for i in self.llm.inference_bistream(text=text,
  99. prompt_text=prompt_text.to(self.device),
  100. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  101. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  102. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  103. embedding=llm_embedding.to(self.device)):
  104. self.tts_speech_token_dict[uuid].append(i)
  105. else:
  106. for i in self.llm.inference(text=text.to(self.device),
  107. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  108. prompt_text=prompt_text.to(self.device),
  109. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  110. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  111. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  112. embedding=llm_embedding.to(self.device)):
  113. self.tts_speech_token_dict[uuid].append(i)
  114. self.llm_end_dict[uuid] = True
  115. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
  116. tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
  117. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  118. prompt_token=prompt_token.to(self.device),
  119. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  120. prompt_feat=prompt_feat.to(self.device),
  121. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  122. embedding=embedding.to(self.device),
  123. flow_cache=self.flow_cache_dict[uuid])
  124. self.flow_cache_dict[uuid] = flow_cache
  125. # mel overlap fade in out
  126. if self.mel_overlap_dict[uuid].shape[2] != 0:
  127. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  128. # append hift cache
  129. if self.hift_cache_dict[uuid] is not None:
  130. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  131. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  132. else:
  133. hift_cache_source = torch.zeros(1, 1, 0)
  134. # keep overlap mel and hift cache
  135. if finalize is False:
  136. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  137. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  138. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  139. if self.hift_cache_dict[uuid] is not None:
  140. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  141. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  142. 'source': tts_source[:, :, -self.source_cache_len:],
  143. 'speech': tts_speech[:, -self.source_cache_len:]}
  144. tts_speech = tts_speech[:, :-self.source_cache_len]
  145. else:
  146. if speed != 1.0:
  147. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  148. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  149. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  150. if self.hift_cache_dict[uuid] is not None:
  151. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  152. return tts_speech
  153. def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  154. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  155. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  156. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  157. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
  158. # this_uuid is used to track variables related to this inference thread
  159. this_uuid = str(uuid.uuid1())
  160. with self.lock:
  161. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  162. self.hift_cache_dict[this_uuid] = None
  163. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  164. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  165. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  166. p.start()
  167. if stream is True:
  168. token_hop_len = self.token_min_hop_len
  169. while True:
  170. time.sleep(0.1)
  171. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  172. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  173. .unsqueeze(dim=0)
  174. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  175. prompt_token=flow_prompt_speech_token,
  176. prompt_feat=prompt_speech_feat,
  177. embedding=flow_embedding,
  178. uuid=this_uuid,
  179. finalize=False)
  180. yield {'tts_speech': this_tts_speech.cpu()}
  181. with self.lock:
  182. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  183. # increase token_hop_len for better speech quality
  184. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  185. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  186. break
  187. p.join()
  188. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  189. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  190. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  191. prompt_token=flow_prompt_speech_token,
  192. prompt_feat=prompt_speech_feat,
  193. embedding=flow_embedding,
  194. uuid=this_uuid,
  195. finalize=True)
  196. yield {'tts_speech': this_tts_speech.cpu()}
  197. else:
  198. # deal with all tokens
  199. p.join()
  200. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  201. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  202. prompt_token=flow_prompt_speech_token,
  203. prompt_feat=prompt_speech_feat,
  204. embedding=flow_embedding,
  205. uuid=this_uuid,
  206. finalize=True,
  207. speed=speed)
  208. yield {'tts_speech': this_tts_speech.cpu()}
  209. with self.lock:
  210. self.tts_speech_token_dict.pop(this_uuid)
  211. self.llm_end_dict.pop(this_uuid)
  212. self.mel_overlap_dict.pop(this_uuid)
  213. self.hift_cache_dict.pop(this_uuid)
  214. self.flow_cache_dict.pop(this_uuid)
  215. torch.cuda.empty_cache()
  216. def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
  217. # this_uuid is used to track variables related to this inference thread
  218. this_uuid = str(uuid.uuid1())
  219. with self.lock:
  220. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
  221. self.hift_cache_dict[this_uuid] = None
  222. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  223. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  224. if stream is True:
  225. token_hop_len = self.token_min_hop_len
  226. while True:
  227. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  228. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  229. .unsqueeze(dim=0)
  230. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  231. prompt_token=flow_prompt_speech_token,
  232. prompt_feat=prompt_speech_feat,
  233. embedding=flow_embedding,
  234. uuid=this_uuid,
  235. finalize=False)
  236. yield {'tts_speech': this_tts_speech.cpu()}
  237. with self.lock:
  238. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  239. # increase token_hop_len for better speech quality
  240. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  241. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  242. break
  243. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  244. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  245. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  246. prompt_token=flow_prompt_speech_token,
  247. prompt_feat=prompt_speech_feat,
  248. embedding=flow_embedding,
  249. uuid=this_uuid,
  250. finalize=True)
  251. yield {'tts_speech': this_tts_speech.cpu()}
  252. else:
  253. # deal with all tokens
  254. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  255. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  256. prompt_token=flow_prompt_speech_token,
  257. prompt_feat=prompt_speech_feat,
  258. embedding=flow_embedding,
  259. uuid=this_uuid,
  260. finalize=True,
  261. speed=speed)
  262. yield {'tts_speech': this_tts_speech.cpu()}
  263. with self.lock:
  264. self.tts_speech_token_dict.pop(this_uuid)
  265. self.llm_end_dict.pop(this_uuid)
  266. self.mel_overlap_dict.pop(this_uuid)
  267. self.hift_cache_dict.pop(this_uuid)
  268. torch.cuda.empty_cache()
  269. class CosyVoice2Model(CosyVoiceModel):
  270. def __init__(self,
  271. llm: torch.nn.Module,
  272. flow: torch.nn.Module,
  273. hift: torch.nn.Module,
  274. fp16: bool):
  275. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  276. self.llm = llm
  277. self.flow = flow
  278. self.hift = hift
  279. self.fp16 = fp16
  280. self.llm.fp16 = fp16
  281. self.flow.fp16 = fp16
  282. if self.fp16 is True:
  283. self.llm.half()
  284. self.flow.half()
  285. self.token_hop_len = 2 * self.flow.input_frame_rate
  286. # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
  287. self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
  288. self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
  289. # hift cache
  290. self.mel_cache_len = 8
  291. self.source_cache_len = int(self.mel_cache_len * 480)
  292. # speech fade in out
  293. self.speech_window = np.hamming(2 * self.source_cache_len)
  294. # rtf and decoding related
  295. self.stream_scale_factor = 1
  296. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  297. self.lock = threading.Lock()
  298. # dict used to store session related variable
  299. self.tts_speech_token_dict = {}
  300. self.llm_end_dict = {}
  301. self.hift_cache_dict = {}
  302. def load_jit(self, flow_encoder_model):
  303. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  304. self.flow.encoder = flow_encoder
  305. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
  306. tts_mel, _ = self.flow.inference(token=token.to(self.device),
  307. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  308. prompt_token=prompt_token.to(self.device),
  309. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  310. prompt_feat=prompt_feat.to(self.device),
  311. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  312. embedding=embedding.to(self.device),
  313. finalize=finalize)
  314. tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
  315. # append hift cache
  316. if self.hift_cache_dict[uuid] is not None:
  317. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  318. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  319. else:
  320. hift_cache_source = torch.zeros(1, 1, 0)
  321. # keep overlap mel and hift cache
  322. if finalize is False:
  323. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  324. if self.hift_cache_dict[uuid] is not None:
  325. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  326. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  327. 'source': tts_source[:, :, -self.source_cache_len:],
  328. 'speech': tts_speech[:, -self.source_cache_len:]}
  329. tts_speech = tts_speech[:, :-self.source_cache_len]
  330. else:
  331. if speed != 1.0:
  332. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  333. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  334. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  335. if self.hift_cache_dict[uuid] is not None:
  336. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  337. return tts_speech
  338. def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  339. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  340. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  341. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  342. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
  343. # this_uuid is used to track variables related to this inference thread
  344. this_uuid = str(uuid.uuid1())
  345. with self.lock:
  346. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  347. self.hift_cache_dict[this_uuid] = None
  348. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  349. p.start()
  350. if stream is True:
  351. token_offset = 0
  352. while True:
  353. time.sleep(0.1)
  354. if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
  355. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
  356. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  357. prompt_token=flow_prompt_speech_token,
  358. prompt_feat=prompt_speech_feat,
  359. embedding=flow_embedding,
  360. uuid=this_uuid,
  361. token_offset=token_offset,
  362. finalize=False)
  363. token_offset += self.token_hop_len
  364. yield {'tts_speech': this_tts_speech.cpu()}
  365. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
  366. break
  367. p.join()
  368. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  369. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  370. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  371. prompt_token=flow_prompt_speech_token,
  372. prompt_feat=prompt_speech_feat,
  373. embedding=flow_embedding,
  374. uuid=this_uuid,
  375. token_offset=token_offset,
  376. finalize=True)
  377. yield {'tts_speech': this_tts_speech.cpu()}
  378. else:
  379. # deal with all tokens
  380. p.join()
  381. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  382. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  383. prompt_token=flow_prompt_speech_token,
  384. prompt_feat=prompt_speech_feat,
  385. embedding=flow_embedding,
  386. uuid=this_uuid,
  387. token_offset=0,
  388. finalize=True,
  389. speed=speed)
  390. yield {'tts_speech': this_tts_speech.cpu()}
  391. with self.lock:
  392. self.tts_speech_token_dict.pop(this_uuid)
  393. self.llm_end_dict.pop(this_uuid)
  394. torch.cuda.empty_cache()