model.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import numpy as np
  16. import threading
  17. import time
  18. from torch.nn import functional as F
  19. from contextlib import nullcontext
  20. import uuid
  21. from cosyvoice.utils.common import fade_in_out
  22. from cosyvoice.utils.common import is_only_punctuation
  23. class CosyVoiceModel:
  24. def __init__(self,
  25. llm: torch.nn.Module,
  26. flow: torch.nn.Module,
  27. hift: torch.nn.Module,
  28. fp16: bool):
  29. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  30. self.llm = llm
  31. self.flow = flow
  32. self.hift = hift
  33. self.fp16 = fp16
  34. self.token_min_hop_len = 2 * self.flow.input_frame_rate
  35. self.token_max_hop_len = 4 * self.flow.input_frame_rate
  36. self.token_overlap_len = 20
  37. # mel fade in out
  38. self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
  39. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  40. # hift cache
  41. self.mel_cache_len = 20
  42. self.source_cache_len = int(self.mel_cache_len * 256)
  43. # speech fade in out
  44. self.speech_window = np.hamming(2 * self.source_cache_len)
  45. # rtf and decoding related
  46. self.stream_scale_factor = 1
  47. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  48. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  49. self.lock = threading.Lock()
  50. # dict used to store session related variable
  51. self.tts_speech_token_dict = {}
  52. self.llm_end_dict = {}
  53. self.mel_overlap_dict = {}
  54. self.flow_cache_dict = {}
  55. self.hift_cache_dict = {}
  56. def load(self, llm_model, flow_model, hift_model):
  57. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
  58. self.llm.to(self.device).eval()
  59. if self.fp16 is True:
  60. self.llm.half()
  61. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
  62. self.flow.to(self.device).eval()
  63. # in case hift_model is a hifigan model
  64. hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
  65. self.hift.load_state_dict(hift_state_dict, strict=False)
  66. self.hift.to(self.device).eval()
  67. def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
  68. assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
  69. llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
  70. self.llm.text_encoder = llm_text_encoder
  71. llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
  72. self.llm.llm = llm_llm
  73. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  74. self.flow.encoder = flow_encoder
  75. def load_onnx(self, flow_decoder_estimator_model):
  76. import onnxruntime
  77. option = onnxruntime.SessionOptions()
  78. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  79. option.intra_op_num_threads = 1
  80. providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
  81. del self.flow.decoder.estimator
  82. self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
  83. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  84. if self.fp16 is True:
  85. llm_embedding = llm_embedding.half()
  86. with self.llm_context:
  87. for i in self.llm.inference(text=text.to(self.device),
  88. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  89. prompt_text=prompt_text.to(self.device),
  90. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  91. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  92. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  93. embedding=llm_embedding.to(self.device)):
  94. self.tts_speech_token_dict[uuid].append(i)
  95. self.llm_end_dict[uuid] = True
  96. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
  97. tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
  98. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  99. prompt_token=prompt_token.to(self.device),
  100. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  101. prompt_feat=prompt_feat.to(self.device),
  102. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  103. embedding=embedding.to(self.device),
  104. flow_cache=self.flow_cache_dict[uuid])
  105. self.flow_cache_dict[uuid] = flow_cache
  106. # mel overlap fade in out
  107. if self.mel_overlap_dict[uuid].shape[2] != 0:
  108. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  109. # append hift cache
  110. if self.hift_cache_dict[uuid] is not None:
  111. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  112. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  113. else:
  114. hift_cache_source = torch.zeros(1, 1, 0)
  115. # keep overlap mel and hift cache
  116. if finalize is False:
  117. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  118. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  119. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  120. if self.hift_cache_dict[uuid] is not None:
  121. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  122. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  123. 'source': tts_source[:, :, -self.source_cache_len:],
  124. 'speech': tts_speech[:, -self.source_cache_len:]}
  125. tts_speech = tts_speech[:, :-self.source_cache_len]
  126. else:
  127. if speed != 1.0:
  128. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  129. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  130. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  131. if self.hift_cache_dict[uuid] is not None:
  132. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  133. return tts_speech
  def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
          prompt_text=torch.zeros(1, 0, dtype=torch.int32),
          llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
          flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
          prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
      """Synthesize speech for ``text``, yielding ``{'tts_speech': tensor}`` chunks.

      A background thread runs ``llm_job`` to produce speech tokens for a fresh
      session key, while this generator converts them to audio with ``token2wav``.

      Args:
          text: LLM text input (``llm_job`` uses ``.to()`` and ``.shape[1]`` on it).
          flow_embedding: speaker embedding passed to ``token2wav``.
          llm_embedding, prompt_text, llm_prompt_speech_token: LLM conditioning,
              forwarded to ``llm_job``.
          flow_prompt_speech_token, prompt_speech_feat: flow prompt, forwarded to
              ``token2wav``.
          stream: if True, emit audio chunk by chunk while tokens arrive; otherwise
              wait for all tokens and emit one chunk.
          speed: speed factor; only forwarded in non-streaming mode.
          **kwargs: accepted and ignored.

      Yields:
          dicts whose ``'tts_speech'`` entry is a CPU tensor.

      Per-session state is always removed when the generator finishes, raises, or
      is closed early; the LLM thread is joined first, so an early close waits for
      it to finish.
      """
      # Text made only of punctuation marks or whitespace: emit 10 ms of silence so
      # callers still receive a chunk.
      # NOTE(review): llm_job treats `text` as a tensor; confirm that
      # is_only_punctuation accepts this type.
      if is_only_punctuation(text):
          # This method is a generator: `return value` would end iteration without
          # producing anything, so the silence must be yielded.
          yield {'tts_speech': torch.zeros(1, int(0.01 * 22050))}
          return
      # this_uuid is used to track variables related to this inference thread
      this_uuid = str(uuid.uuid1())
      with self.lock:
          self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
          self.hift_cache_dict[this_uuid] = None
          self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
          self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
      p = threading.Thread(target=self.llm_job,
                           args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
      p.start()
      try:
          if stream is True:
              token_hop_len = self.token_min_hop_len
              while True:
                  time.sleep(0.1)
                  if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                      this_tts_speech_token = torch.tensor(
                          self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]).unsqueeze(dim=0)
                      this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                       prompt_token=flow_prompt_speech_token,
                                                       prompt_feat=prompt_speech_feat,
                                                       embedding=flow_embedding,
                                                       uuid=this_uuid,
                                                       finalize=False)
                      yield {'tts_speech': this_tts_speech.cpu()}
                      with self.lock:
                          # Trim in place: llm_job appends to this same list from another
                          # thread, and rebinding the key to a new slice could drop a token
                          # appended between taking the slice and storing it.
                          del self.tts_speech_token_dict[this_uuid][:token_hop_len]
                      # increase token_hop_len for better speech quality
                      token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                  # No more tokens arrive once llm_job flagged completion or its thread has
                  # exited (e.g. because inference raised before setting the flag).
                  llm_finished = self.llm_end_dict[this_uuid] is True or not p.is_alive()
                  if llm_finished and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                      break
              p.join()
              # deal with remaining tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
              this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
              this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                               prompt_token=flow_prompt_speech_token,
                                               prompt_feat=prompt_speech_feat,
                                               embedding=flow_embedding,
                                               uuid=this_uuid,
                                               finalize=True)
              yield {'tts_speech': this_tts_speech.cpu()}
          else:
              # deal with all tokens
              p.join()
              this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
              this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                               prompt_token=flow_prompt_speech_token,
                                               prompt_feat=prompt_speech_feat,
                                               embedding=flow_embedding,
                                               uuid=this_uuid,
                                               finalize=True,
                                               speed=speed)
              yield {'tts_speech': this_tts_speech.cpu()}
      finally:
          # Runs on completion, on exceptions and on early generator close. Join first so
          # llm_job is no longer writing to these entries when they are removed.
          p.join()
          with self.lock:
              for session_dict in (self.tts_speech_token_dict, self.llm_end_dict, self.mel_overlap_dict,
                                   self.hift_cache_dict, self.flow_cache_dict):
                  session_dict.pop(this_uuid, None)
  200. def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
  201. # this_uuid is used to track variables related to this inference thread
  202. this_uuid = str(uuid.uuid1())
  203. with self.lock:
  204. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
  205. self.hift_cache_dict[this_uuid] = None
  206. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  207. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  208. if stream is True:
  209. token_hop_len = self.token_min_hop_len
  210. while True:
  211. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  212. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  213. .unsqueeze(dim=0)
  214. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  215. prompt_token=flow_prompt_speech_token,
  216. prompt_feat=prompt_speech_feat,
  217. embedding=flow_embedding,
  218. uuid=this_uuid,
  219. finalize=False)
  220. yield {'tts_speech': this_tts_speech.cpu()}
  221. with self.lock:
  222. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  223. # increase token_hop_len for better speech quality
  224. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  225. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  226. break
  227. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  228. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  229. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  230. prompt_token=flow_prompt_speech_token,
  231. prompt_feat=prompt_speech_feat,
  232. embedding=flow_embedding,
  233. uuid=this_uuid,
  234. finalize=True)
  235. yield {'tts_speech': this_tts_speech.cpu()}
  236. else:
  237. # deal with all tokens
  238. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  239. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  240. prompt_token=flow_prompt_speech_token,
  241. prompt_feat=prompt_speech_feat,
  242. embedding=flow_embedding,
  243. uuid=this_uuid,
  244. finalize=True,
  245. speed=speed)
  246. yield {'tts_speech': this_tts_speech.cpu()}
  247. with self.lock:
  248. self.tts_speech_token_dict.pop(this_uuid)
  249. self.llm_end_dict.pop(this_uuid)
  250. self.mel_overlap_dict.pop(this_uuid)
  251. self.hift_cache_dict.pop(this_uuid)