model.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import queue
import threading
import time
import uuid
from contextlib import nullcontext
from typing import Generator

import numpy as np
import torch
from torch.nn import functional as F

from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.flow.flow_matching import EstimatorWrapper
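
# Inference pipeline implemented by the classes below: an LLM generates
# speech tokens, a flow-matching model turns tokens into mel spectrograms,
# and the HiFT vocoder renders waveforms. In streaming mode these stages
# overlap: the LLM runs in a background thread while token2wav() converts
# chunks of tokens as they arrive.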


class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.llm.fp16 = fp16
        self.flow.fp16 = fp16
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # here we fix flow.decoder.estimator.static_chunk_size = 0 for compatibility
        self.flow.decoder.estimator.static_chunk_size = 0
        # mel fade in out
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be at least 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session-related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}
        # pool of CUDA stream contexts so concurrent sessions do not serialize on the default stream
        self.stream_context_pool = queue.Queue()
        for _ in range(10):
            self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
        self.is_cuda_available = torch.cuda.is_available()
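
    # Sizing note (a sketch of the arithmetic, assuming the default 50 Hz
    # token rate of the v1 flow model): token_overlap_len = 20 tokens is
    # 20 / 50 = 0.4 s of audio; at the 22050 Hz sample rate and 256-sample
    # hop encoded in the constants above, that is int(0.4 * 22050 / 256) = 34
    # overlapping mel frames, cross-faded with a 68-point Hamming window.
    # The HiFT cache keeps mel_cache_len = 20 frames, i.e. 20 * 256 = 5120
    # waveform samples.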

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16, estimator_count=1):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model):
            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
        if os.path.getsize(flow_decoder_estimator_model) == 0:
            raise ValueError('{} is an empty file, delete it and export again!'.format(flow_decoder_estimator_model))
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        if self.flow.decoder.estimator_engine is None:
            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
        self.flow.decoder.estimator = EstimatorWrapper(self.flow.decoder.estimator_engine, estimator_count=estimator_count)
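
    # NOTE: load_trt() swaps the PyTorch flow-matching estimator for a
    # TensorRT engine. If no engine file exists yet, one is built from the
    # exported ONNX model via convert_onnx_to_trt(). The estimator_count
    # parameter lets EstimatorWrapper hold several execution contexts, so
    # concurrent requests need not contend for a single one.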

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context:
            if isinstance(text, Generator):
                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                for i in self.llm.inference_bistream(text=text,
                                                     prompt_text=prompt_text.to(self.device),
                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                     embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
            else:
                for i in self.llm.inference(text=text.to(self.device),
                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_text=prompt_text.to(self.device),
                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                            embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True
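
    # llm_job() is the producer half of a producer/consumer pair: tts() runs
    # it in a background thread, and each generated speech token is appended
    # to the per-session list in tts_speech_token_dict while the consumer
    # loop in tts() drains that list into token2wav(). llm_end_dict flags
    # when generation has finished so the consumer can emit the final chunk.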

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_token=prompt_token.to(self.device),
                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_feat=prompt_feat.to(self.device),
                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                  embedding=embedding.to(self.device),
                                                  flow_cache=self.flow_cache_dict[uuid])
        self.flow_cache_dict[uuid] = flow_cache
        # mel overlap fade in out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
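
    # token2wav() keeps three pieces of streaming state per session:
    # mel_overlap_dict holds the trailing mel frames of the previous chunk so
    # the next chunk can be cross-faded in, flow_cache_dict carries the flow
    # model's internal cache across calls, and hift_cache_dict stores the
    # trailing mel/source/waveform of the previous vocoder call so chunk
    # boundaries stay phase-continuous.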

    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        stream_context = self.stream_context_pool.get()
        with stream_context:
            this_uuid = str(uuid.uuid1())
            with self.lock:
                self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
                self.hift_cache_dict[this_uuid] = None
                self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
                self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
            p.start()
            if stream is True:
                token_hop_len = self.token_min_hop_len
                while True:
                    time.sleep(0.1)
                    if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                        this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                            .unsqueeze(dim=0)
                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                         prompt_token=flow_prompt_speech_token,
                                                         prompt_feat=prompt_speech_feat,
                                                         embedding=flow_embedding,
                                                         uuid=this_uuid,
                                                         finalize=False)
                        yield {'tts_speech': this_tts_speech.cpu()}
                        with self.lock:
                            self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                        # increase token_hop_len for better speech quality
                        token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                    if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                        break
                p.join()
                # deal with remaining tokens; make sure the remaining token len equals token_hop_len when cache_speech is not None
                this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 uuid=this_uuid,
                                                 finalize=True)
                yield {'tts_speech': this_tts_speech.cpu()}
            else:
                # deal with all tokens
                p.join()
                this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 uuid=this_uuid,
                                                 finalize=True,
                                                 speed=speed)
                yield {'tts_speech': this_tts_speech.cpu()}
            with self.lock:
                self.tts_speech_token_dict.pop(this_uuid)
                self.llm_end_dict.pop(this_uuid)
                self.mel_overlap_dict.pop(this_uuid)
                self.hift_cache_dict.pop(this_uuid)
                self.flow_cache_dict.pop(this_uuid)
        self.synchronize_stream()
        self.stream_context_pool.put(stream_context)
        torch.cuda.empty_cache()
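
    # Calling torch.cuda.empty_cache() after every utterance returns cached
    # allocator blocks to the driver; this bounds memory as sessions come and
    # go, at the cost of re-allocation on the next request.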

    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            # deal with remaining tokens; make sure the remaining token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            # also drop the flow cache, otherwise the session entry leaks
            self.flow_cache_dict.pop(this_uuid)
        torch.cuda.empty_cache()

    def synchronize_stream(self):
        if self.is_cuda_available:
            torch.cuda.current_stream().synchronize()
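
# CosyVoice2Model specializes the pipeline above for the chunk-aware causal
# flow model of CosyVoice 2: the token hop length is fixed, the flow
# encoder/decoder run with a static chunk size, and instead of cross-fading
# overlapping mel chunks it re-runs the flow on the full token history and
# slices off the frames already emitted (see token2wav below).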


class CosyVoice2Model(CosyVoiceModel):

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.llm.fp16 = fp16
        self.flow.fp16 = fp16
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_hop_len = 2 * self.flow.input_frame_rate
        # here we fix the flow encoder/decoder decoding_chunk_size; in the future we will pass it as an argument, or use a cache
        self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
        self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
        # hift cache
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session-related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.hift_cache_dict = {}
        self.stream_context_pool = queue.Queue()
        for _ in range(10):
            self.stream_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
        self.is_cuda_available = torch.cuda.is_available()

    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
        tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_token=prompt_token.to(self.device),
                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_feat=prompt_feat.to(self.device),
                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                         embedding=embedding.to(self.device),
                                         finalize=finalize)
        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
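
    # Each call receives the token sequence from position 0, so the flow
    # model recomputes mel for the whole history; the slice at
    # token_offset * token_mel_ratio above keeps only the frames not yet
    # emitted. The extra compute buys chunk-consistent mel without the
    # overlap fade used by the v1 model.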

    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        self.synchronize_stream()
        stream_context = self.stream_context_pool.get()
        # the pooled object is already a stream context (or nullcontext on cpu), so enter it directly
        with stream_context:
            this_uuid = str(uuid.uuid1())
            with self.lock:
                self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
                self.hift_cache_dict[this_uuid] = None
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
            p.start()
            if stream is True:
                token_offset = 0
                while True:
                    time.sleep(0.1)
                    if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
                        this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                         prompt_token=flow_prompt_speech_token,
                                                         prompt_feat=prompt_speech_feat,
                                                         embedding=flow_embedding,
                                                         uuid=this_uuid,
                                                         token_offset=token_offset,
                                                         finalize=False)
                        token_offset += self.token_hop_len
                        yield {'tts_speech': this_tts_speech.cpu()}
                    if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
                        break
                p.join()
                # deal with remaining tokens; make sure the remaining token len equals token_hop_len when cache_speech is not None
                this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 uuid=this_uuid,
                                                 token_offset=token_offset,
                                                 finalize=True)
                yield {'tts_speech': this_tts_speech.cpu()}
            else:
                # deal with all tokens
                p.join()
                this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 uuid=this_uuid,
                                                 token_offset=0,
                                                 finalize=True,
                                                 speed=speed)
                yield {'tts_speech': this_tts_speech.cpu()}
            with self.lock:
                self.tts_speech_token_dict.pop(this_uuid)
                self.llm_end_dict.pop(this_uuid)
        self.synchronize_stream()
        self.stream_context_pool.put(stream_context)
        torch.cuda.empty_cache()
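
# The variant below replaces the PyTorch LLM with a vLLM-served Qwen2 model;
# its load() only restores the flow and HiFT weights because vLLM manages
# the LLM checkpoint itself.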


class VllmCosyVoice2Model(CosyVoice2Model):

    def __init__(self,
                 model_dir: str,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        # requires the optional vllm dependency; an ImportError propagates if it is missing
        from cosyvoice.llm.llm_vllm import VllmQwen2LM
        llm = VllmQwen2LM(model_dir)
        super().__init__(llm, flow, hift, fp16)

    def load(self, llm_model, flow_model, hift_model):
        self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in
                           torch.load(hift_model, weights_only=True, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()
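

# A minimal, self-contained sketch of the cross-fade used by token2wav().
# It mirrors what fade_in_out() does with mel_window/speech_window, but on
# 1-D numpy signals; the lengths and signals here are illustrative only,
# not taken from the models above.
if __name__ == '__main__':
    overlap = 34                       # e.g. mel_overlap_len at a 50 Hz token rate
    window = np.hamming(2 * overlap)   # rising half fades in, falling half fades out
    t = np.arange(2048)
    prev_chunk = np.sin(2 * np.pi * 220 * t / 22050)                     # previous chunk
    next_chunk = np.sin(2 * np.pi * 220 * (t + 2048 - overlap) / 22050)  # overlaps by `overlap` samples
    # blend the overlapping region: the new head fades in while the old tail fades out
    head = next_chunk[:overlap] * window[:overlap] + prev_chunk[-overlap:] * window[overlap:]
    stitched = np.concatenate([prev_chunk[:-overlap], head, next_chunk[overlap:]])
    print('stitched length:', stitched.shape[0])  # 2 * 2048 - overlap = 4062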