model.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Generator
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.flow.flow_matching import EstimatorWrapper


class CosyVoiceModel:
    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.llm.fp16 = fp16
        self.flow.fp16 = fp16
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # here we set flow.decoder.estimator.static_chunk_size = 0 for compatibility
        self.flow.decoder.estimator.static_chunk_size = 0
        # mel fade in/out (assumes 22050 Hz output and a 256-sample mel hop)
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in/out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be no less than 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}
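
    # Note: per-session state lives in dicts keyed by a per-request uuid so one
    # model instance can serve concurrent requests; self.lock guards dict
    # insertion, update and removal, while the LLM producer thread and the
    # consumer generator hand off tokens via tts_speech_token_dict/llm_end_dict.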
    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16, estimator_count=1):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model):
            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
        if os.path.getsize(flow_decoder_estimator_model) == 0:
            raise ValueError('{} is an empty file, delete it and export again!'.format(flow_decoder_estimator_model))
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        if self.flow.decoder.estimator_engine is None:
            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
        self.flow.decoder.estimator = EstimatorWrapper(self.flow.decoder.estimator_engine, estimator_count=estimator_count)
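
    # Hedged usage sketch (the checkpoint paths below are hypothetical placeholders):
    #   model = CosyVoiceModel(llm, flow, hift, fp16=True)
    #   model.load('llm.pt', 'flow.pt', 'hift.pt')                      # load checkpoints first
    #   model.load_jit('text_enc.zip', 'llm.zip', 'flow_enc.zip')       # optional TorchScript override
    #   model.load_trt('estimator.plan', 'estimator.onnx', fp16=True)   # optional TensorRT estimator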
    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context:
            if isinstance(text, Generator):
                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                for i in self.llm.inference_bistream(text=text,
                                                     prompt_text=prompt_text.to(self.device),
                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                     embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
            else:
                for i in self.llm.inference(text=text.to(self.device),
                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_text=prompt_text.to(self.device),
                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                            embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_token=prompt_token.to(self.device),
                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_feat=prompt_feat.to(self.device),
                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                  embedding=embedding.to(self.device),
                                                  flow_cache=self.flow_cache_dict[uuid])
        self.flow_cache_dict[uuid] = flow_cache
        # mel overlap fade in/out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-streaming inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
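
    # Streaming strategy recap for token2wav above: while finalize is False the
    # last mel_overlap_len mel frames plus short mel/source/speech tails are
    # cached per uuid, then crossfaded into the next chunk via fade_in_out with
    # Hamming windows, so consecutive chunks join without audible seams.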
    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        torch.cuda.empty_cache()
    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
        torch.cuda.empty_cache()
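
# Illustrative sketch only, not the library function: fade_in_out imported from
# cosyvoice.utils.common is assumed to perform an overlap crossfade like the
# one below, where the ascending half of a Hamming window fades the new chunk
# in and the descending half fades the cached tail out.
def _crossfade_sketch(new_chunk: np.ndarray, old_tail: np.ndarray, window: np.ndarray) -> np.ndarray:
    # window has length 2 * overlap_len, e.g. np.hamming(2 * overlap_len)
    overlap_len = window.shape[0] // 2
    faded = new_chunk.copy()
    # blend the first overlap_len frames of the new chunk with the tail of the previous one
    faded[..., :overlap_len] = (new_chunk[..., :overlap_len] * window[:overlap_len]
                                + old_tail[..., -overlap_len:] * window[overlap_len:])
    return faded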

class CosyVoice2Model(CosyVoiceModel):

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.llm.fp16 = fp16
        self.flow.fp16 = fp16
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_hop_len = 2 * self.flow.input_frame_rate
        # here we fix the flow encoder/decoder decoding_chunk_size; in the future we will pass it as an argument, or use a cache
        self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
        self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
        # hift cache (480-sample mel hop)
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in/out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store session related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.hift_cache_dict = {}
    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
        tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                         token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_token=prompt_token.to(self.device),
                                         prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                         prompt_feat=prompt_feat.to(self.device),
                                         prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                         embedding=embedding.to(self.device),
                                         finalize=finalize)
        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
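        # note: each call receives the full token prefix, so this slice drops the
        # token_offset * token_mel_ratio mel frames already emitted in earlier chunks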
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-streaming inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
            token_offset = 0
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
                    this_tts_speech_token = torch.tensor(
                        self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     token_offset=token_offset,
                                                     finalize=False)
                    token_offset += self.token_hop_len
                    yield {'tts_speech': this_tts_speech.cpu()}
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
                    break
            p.join()
            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             token_offset=token_offset,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             token_offset=0,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
        torch.cuda.empty_cache()
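
# Hedged usage sketch of the streaming API (the model construction and the
# text_tokens / flow_embedding inputs are hypothetical placeholders; real
# callers prepare them via the cosyvoice frontend):
def _example_streaming_synthesis(model: CosyVoice2Model, text_tokens: torch.Tensor, flow_embedding: torch.Tensor) -> torch.Tensor:
    # tts() is a generator: each yielded dict carries one crossfaded audio chunk
    chunks = [out['tts_speech'] for out in model.tts(text_tokens, flow_embedding, stream=True)]
    # concatenate the chunks along the time axis to recover the full waveform
    return torch.concat(chunks, dim=1)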

class VllmCosyVoice2Model(CosyVoice2Model):

    def __init__(self,
                 model_dir: str,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        # vllm support is an optional dependency; let the ImportError propagate if it is missing
        from cosyvoice.llm.llm_vllm import VllmQwen2LM
        llm = VllmQwen2LM(model_dir)
        super().__init__(llm, flow, hift, fp16)

    def load(self, llm_model, flow_model, hift_model):
        self.flow.load_state_dict(torch.load(flow_model, weights_only=True, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in
                           torch.load(hift_model, weights_only=True, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()
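    # Note: unlike the parent class, load() above intentionally skips the llm
    # checkpoint; VllmQwen2LM is assumed to load the LLM weights from model_dir
    # through vLLM at construction time, so only flow and hift are loaded here.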