model.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import numpy as np
  16. import threading
  17. import time
  18. from torch.nn import functional as F
  19. from contextlib import nullcontext
  20. import uuid
  21. from cosyvoice.utils.common import fade_in_out
  22. from cosyvoice.trt.estimator_trt import EstimatorTRT
  23. class CosyVoiceModel:
  24. def __init__(self,
  25. llm: torch.nn.Module,
  26. flow: torch.nn.Module,
  27. hift: torch.nn.Module,
  28. fp16: bool):
  29. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  30. self.llm = llm
  31. self.flow = flow
  32. self.hift = hift
  33. self.fp16 = fp16
  34. self.llm.fp16 = fp16
  35. self.flow.fp16 = fp16
  36. self.token_min_hop_len = 2 * self.flow.input_frame_rate
  37. self.token_max_hop_len = 4 * self.flow.input_frame_rate
  38. self.token_overlap_len = 20
  39. # here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
  40. self.flow.decoder.estimator.static_chunk_size = 0
  41. # mel fade in out
  42. self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
  43. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  44. # hift cache
  45. self.mel_cache_len = 20
  46. self.source_cache_len = int(self.mel_cache_len * 256)
  47. # speech fade in out
  48. self.speech_window = np.hamming(2 * self.source_cache_len)
  49. # rtf and decoding related
  50. self.stream_scale_factor = 1
  51. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  52. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  53. self.lock = threading.Lock()
  54. # dict used to store session related variable
  55. self.tts_speech_token_dict = {}
  56. self.llm_end_dict = {}
  57. self.mel_overlap_dict = {}
  58. self.flow_cache_dict = {}
  59. self.hift_cache_dict = {}
  60. def load(self, llm_model, flow_model, hift_model):
  61. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
  62. self.llm.to(self.device).eval()
  63. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
  64. self.flow.to(self.device).eval()
  65. # in case hift_model is a hifigan model
  66. hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
  67. self.hift.load_state_dict(hift_state_dict, strict=True)
  68. self.hift.to(self.device).eval()
  69. if self.fp16 is True:
  70. self.llm.half()
  71. self.flow.half()
  72. def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
  73. llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
  74. self.llm.text_encoder = llm_text_encoder
  75. llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
  76. self.llm.llm = llm_llm
  77. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  78. self.flow.encoder = flow_encoder
  79. def load_trt(self, flow_decoder_estimator_model, fp16):
  80. del self.flow.decoder.estimator
  81. self.flow.decoder.estimator = EstimatorTRT(flow_decoder_estimator_model, self.device, fp16)
  82. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  83. with self.llm_context:
  84. for i in self.llm.inference(text=text.to(self.device),
  85. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  86. prompt_text=prompt_text.to(self.device),
  87. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  88. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  89. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  90. embedding=llm_embedding.to(self.device)):
  91. self.tts_speech_token_dict[uuid].append(i)
  92. self.llm_end_dict[uuid] = True
  93. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
  94. tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
  95. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  96. prompt_token=prompt_token.to(self.device),
  97. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  98. prompt_feat=prompt_feat.to(self.device),
  99. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  100. embedding=embedding.to(self.device),
  101. flow_cache=self.flow_cache_dict[uuid])
  102. self.flow_cache_dict[uuid] = flow_cache
  103. # mel overlap fade in out
  104. if self.mel_overlap_dict[uuid].shape[2] != 0:
  105. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  106. # append hift cache
  107. if self.hift_cache_dict[uuid] is not None:
  108. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  109. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  110. else:
  111. hift_cache_source = torch.zeros(1, 1, 0)
  112. # keep overlap mel and hift cache
  113. if finalize is False:
  114. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  115. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  116. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  117. if self.hift_cache_dict[uuid] is not None:
  118. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  119. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  120. 'source': tts_source[:, :, -self.source_cache_len:],
  121. 'speech': tts_speech[:, -self.source_cache_len:]}
  122. tts_speech = tts_speech[:, :-self.source_cache_len]
  123. else:
  124. if speed != 1.0:
  125. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  126. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  127. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  128. if self.hift_cache_dict[uuid] is not None:
  129. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  130. return tts_speech
  131. def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  132. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  133. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  134. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  135. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
  136. # this_uuid is used to track variables related to this inference thread
  137. this_uuid = str(uuid.uuid1())
  138. with self.lock:
  139. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  140. self.hift_cache_dict[this_uuid] = None
  141. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  142. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  143. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  144. p.start()
  145. if stream is True:
  146. token_hop_len = self.token_min_hop_len
  147. while True:
  148. time.sleep(0.1)
  149. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  150. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  151. .unsqueeze(dim=0)
  152. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  153. prompt_token=flow_prompt_speech_token,
  154. prompt_feat=prompt_speech_feat,
  155. embedding=flow_embedding,
  156. uuid=this_uuid,
  157. finalize=False)
  158. yield {'tts_speech': this_tts_speech.cpu()}
  159. with self.lock:
  160. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  161. # increase token_hop_len for better speech quality
  162. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  163. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  164. break
  165. p.join()
  166. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  167. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  168. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  169. prompt_token=flow_prompt_speech_token,
  170. prompt_feat=prompt_speech_feat,
  171. embedding=flow_embedding,
  172. uuid=this_uuid,
  173. finalize=True)
  174. yield {'tts_speech': this_tts_speech.cpu()}
  175. else:
  176. # deal with all tokens
  177. p.join()
  178. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  179. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  180. prompt_token=flow_prompt_speech_token,
  181. prompt_feat=prompt_speech_feat,
  182. embedding=flow_embedding,
  183. uuid=this_uuid,
  184. finalize=True,
  185. speed=speed)
  186. yield {'tts_speech': this_tts_speech.cpu()}
  187. with self.lock:
  188. self.tts_speech_token_dict.pop(this_uuid)
  189. self.llm_end_dict.pop(this_uuid)
  190. self.mel_overlap_dict.pop(this_uuid)
  191. self.hift_cache_dict.pop(this_uuid)
  192. self.flow_cache_dict.pop(this_uuid)
  193. def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
  194. # this_uuid is used to track variables related to this inference thread
  195. this_uuid = str(uuid.uuid1())
  196. with self.lock:
  197. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
  198. self.hift_cache_dict[this_uuid] = None
  199. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  200. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  201. if stream is True:
  202. token_hop_len = self.token_min_hop_len
  203. while True:
  204. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  205. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  206. .unsqueeze(dim=0)
  207. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  208. prompt_token=flow_prompt_speech_token,
  209. prompt_feat=prompt_speech_feat,
  210. embedding=flow_embedding,
  211. uuid=this_uuid,
  212. finalize=False)
  213. yield {'tts_speech': this_tts_speech.cpu()}
  214. with self.lock:
  215. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  216. # increase token_hop_len for better speech quality
  217. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  218. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  219. break
  220. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  221. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  222. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  223. prompt_token=flow_prompt_speech_token,
  224. prompt_feat=prompt_speech_feat,
  225. embedding=flow_embedding,
  226. uuid=this_uuid,
  227. finalize=True)
  228. yield {'tts_speech': this_tts_speech.cpu()}
  229. else:
  230. # deal with all tokens
  231. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  232. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  233. prompt_token=flow_prompt_speech_token,
  234. prompt_feat=prompt_speech_feat,
  235. embedding=flow_embedding,
  236. uuid=this_uuid,
  237. finalize=True,
  238. speed=speed)
  239. yield {'tts_speech': this_tts_speech.cpu()}
  240. with self.lock:
  241. self.tts_speech_token_dict.pop(this_uuid)
  242. self.llm_end_dict.pop(this_uuid)
  243. self.mel_overlap_dict.pop(this_uuid)
  244. self.hift_cache_dict.pop(this_uuid)
  245. class CosyVoice2Model(CosyVoiceModel):
  246. def __init__(self,
  247. llm: torch.nn.Module,
  248. flow: torch.nn.Module,
  249. hift: torch.nn.Module,
  250. fp16: bool):
  251. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  252. self.llm = llm
  253. self.flow = flow
  254. self.hift = hift
  255. self.fp16 = fp16
  256. self.llm.fp16 = fp16
  257. self.flow.fp16 = fp16
  258. self.token_hop_len = 2 * self.flow.input_frame_rate
  259. # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
  260. self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
  261. self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
  262. # hift cache
  263. self.mel_cache_len = 8
  264. self.source_cache_len = int(self.mel_cache_len * 480)
  265. # speech fade in out
  266. self.speech_window = np.hamming(2 * self.source_cache_len)
  267. # rtf and decoding related
  268. self.stream_scale_factor = 1
  269. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  270. self.lock = threading.Lock()
  271. # dict used to store session related variable
  272. self.tts_speech_token_dict = {}
  273. self.llm_end_dict = {}
  274. self.hift_cache_dict = {}
  275. def load_jit(self, flow_encoder_model):
  276. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  277. self.flow.encoder = flow_encoder
  278. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
  279. tts_mel, _ = self.flow.inference(token=token.to(self.device),
  280. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  281. prompt_token=prompt_token.to(self.device),
  282. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  283. prompt_feat=prompt_feat.to(self.device),
  284. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  285. embedding=embedding.to(self.device),
  286. finalize=finalize)
  287. tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
  288. # append hift cache
  289. if self.hift_cache_dict[uuid] is not None:
  290. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  291. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  292. else:
  293. hift_cache_source = torch.zeros(1, 1, 0)
  294. # keep overlap mel and hift cache
  295. if finalize is False:
  296. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  297. if self.hift_cache_dict[uuid] is not None:
  298. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  299. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  300. 'source': tts_source[:, :, -self.source_cache_len:],
  301. 'speech': tts_speech[:, -self.source_cache_len:]}
  302. tts_speech = tts_speech[:, :-self.source_cache_len]
  303. else:
  304. if speed != 1.0:
  305. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  306. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  307. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  308. if self.hift_cache_dict[uuid] is not None:
  309. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  310. return tts_speech
  311. def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  312. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  313. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  314. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  315. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
  316. # this_uuid is used to track variables related to this inference thread
  317. this_uuid = str(uuid.uuid1())
  318. with self.lock:
  319. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  320. self.hift_cache_dict[this_uuid] = None
  321. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  322. p.start()
  323. if stream is True:
  324. token_offset = 0
  325. while True:
  326. time.sleep(0.1)
  327. if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
  328. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
  329. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  330. prompt_token=flow_prompt_speech_token,
  331. prompt_feat=prompt_speech_feat,
  332. embedding=flow_embedding,
  333. uuid=this_uuid,
  334. token_offset=token_offset,
  335. finalize=False)
  336. token_offset += self.token_hop_len
  337. yield {'tts_speech': this_tts_speech.cpu()}
  338. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
  339. break
  340. p.join()
  341. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  342. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  343. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  344. prompt_token=flow_prompt_speech_token,
  345. prompt_feat=prompt_speech_feat,
  346. embedding=flow_embedding,
  347. uuid=this_uuid,
  348. token_offset=token_offset,
  349. finalize=True)
  350. yield {'tts_speech': this_tts_speech.cpu()}
  351. else:
  352. # deal with all tokens
  353. p.join()
  354. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  355. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  356. prompt_token=flow_prompt_speech_token,
  357. prompt_feat=prompt_speech_feat,
  358. embedding=flow_embedding,
  359. uuid=this_uuid,
  360. token_offset=0,
  361. finalize=True,
  362. speed=speed)
  363. yield {'tts_speech': this_tts_speech.cpu()}
  364. with self.lock:
  365. self.tts_speech_token_dict.pop(this_uuid)
  366. self.llm_end_dict.pop(this_uuid)