1
0

model.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import numpy as np
  16. import threading
  17. import time
  18. from torch.nn import functional as F
  19. from contextlib import nullcontext
  20. import uuid
  21. from cosyvoice.utils.common import fade_in_out
  22. class CosyVoiceModel:
  23. def __init__(self,
  24. llm: torch.nn.Module,
  25. flow: torch.nn.Module,
  26. hift: torch.nn.Module,
  27. fp16: bool):
  28. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  29. self.llm = llm
  30. self.flow = flow
  31. self.hift = hift
  32. self.fp16 = fp16
  33. self.token_min_hop_len = 2 * self.flow.input_frame_rate
  34. self.token_max_hop_len = 4 * self.flow.input_frame_rate
  35. self.token_overlap_len = 20
  36. # mel fade in out
  37. self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
  38. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  39. # hift cache
  40. self.mel_cache_len = 20
  41. self.source_cache_len = int(self.mel_cache_len * 256)
  42. # speech fade in out
  43. self.speech_window = np.hamming(2 * self.source_cache_len)
  44. # rtf and decoding related
  45. self.stream_scale_factor = 1
  46. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  47. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  48. self.lock = threading.Lock()
  49. # dict used to store session related variable
  50. self.tts_speech_token_dict = {}
  51. self.llm_end_dict = {}
  52. self.mel_overlap_dict = {}
  53. self.flow_cache_dict = {}
  54. self.hift_cache_dict = {}
  55. def load(self, llm_model, flow_model, hift_model):
  56. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
  57. self.llm.to(self.device).eval()
  58. if self.fp16 is True:
  59. self.llm.half()
  60. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
  61. self.flow.to(self.device).eval()
  62. # in case hift_model is a hifigan model
  63. hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
  64. self.hift.load_state_dict(hift_state_dict, strict=False)
  65. self.hift.to(self.device).eval()
  66. def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
  67. assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
  68. llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
  69. self.llm.text_encoder = llm_text_encoder
  70. llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
  71. self.llm.llm = llm_llm
  72. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  73. self.flow.encoder = flow_encoder
  74. def load_onnx(self, flow_decoder_estimator_model):
  75. import onnxruntime
  76. option = onnxruntime.SessionOptions()
  77. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  78. option.intra_op_num_threads = 1
  79. providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
  80. del self.flow.decoder.estimator
  81. self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
  82. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  83. if self.fp16 is True:
  84. llm_embedding = llm_embedding.half()
  85. with self.llm_context:
  86. for i in self.llm.inference(text=text.to(self.device),
  87. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  88. prompt_text=prompt_text.to(self.device),
  89. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  90. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  91. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  92. embedding=llm_embedding.to(self.device)):
  93. self.tts_speech_token_dict[uuid].append(i)
  94. self.llm_end_dict[uuid] = True
  95. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
  96. tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
  97. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  98. prompt_token=prompt_token.to(self.device),
  99. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  100. prompt_feat=prompt_feat.to(self.device),
  101. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  102. embedding=embedding.to(self.device),
  103. flow_cache=self.flow_cache_dict[uuid])
  104. self.flow_cache_dict[uuid] = flow_cache
  105. # mel overlap fade in out
  106. if self.mel_overlap_dict[uuid].shape[2] != 0:
  107. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  108. # append hift cache
  109. if self.hift_cache_dict[uuid] is not None:
  110. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  111. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  112. else:
  113. hift_cache_source = torch.zeros(1, 1, 0)
  114. # keep overlap mel and hift cache
  115. if finalize is False:
  116. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  117. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  118. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  119. if self.hift_cache_dict[uuid] is not None:
  120. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  121. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  122. 'source': tts_source[:, :, -self.source_cache_len:],
  123. 'speech': tts_speech[:, -self.source_cache_len:]}
  124. tts_speech = tts_speech[:, :-self.source_cache_len]
  125. else:
  126. if speed != 1.0:
  127. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  128. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  129. tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
  130. if self.hift_cache_dict[uuid] is not None:
  131. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  132. return tts_speech
  133. def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  134. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  135. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  136. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  137. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
  138. # this_uuid is used to track variables related to this inference thread
  139. this_uuid = str(uuid.uuid1())
  140. with self.lock:
  141. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  142. self.hift_cache_dict[this_uuid] = None
  143. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  144. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  145. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  146. p.start()
  147. if stream is True:
  148. token_hop_len = self.token_min_hop_len
  149. while True:
  150. time.sleep(0.1)
  151. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  152. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  153. .unsqueeze(dim=0)
  154. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  155. prompt_token=flow_prompt_speech_token,
  156. prompt_feat=prompt_speech_feat,
  157. embedding=flow_embedding,
  158. uuid=this_uuid,
  159. finalize=False)
  160. yield {'tts_speech': this_tts_speech.cpu()}
  161. with self.lock:
  162. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  163. # increase token_hop_len for better speech quality
  164. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  165. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  166. break
  167. p.join()
  168. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  169. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  170. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  171. prompt_token=flow_prompt_speech_token,
  172. prompt_feat=prompt_speech_feat,
  173. embedding=flow_embedding,
  174. uuid=this_uuid,
  175. finalize=True)
  176. yield {'tts_speech': this_tts_speech.cpu()}
  177. else:
  178. # deal with all tokens
  179. p.join()
  180. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  181. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  182. prompt_token=flow_prompt_speech_token,
  183. prompt_feat=prompt_speech_feat,
  184. embedding=flow_embedding,
  185. uuid=this_uuid,
  186. finalize=True,
  187. speed=speed)
  188. yield {'tts_speech': this_tts_speech.cpu()}
  189. with self.lock:
  190. self.tts_speech_token_dict.pop(this_uuid)
  191. self.llm_end_dict.pop(this_uuid)
  192. self.mel_overlap_dict.pop(this_uuid)
  193. self.hift_cache_dict.pop(this_uuid)
  194. def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
  195. # this_uuid is used to track variables related to this inference thread
  196. this_uuid = str(uuid.uuid1())
  197. with self.lock:
  198. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
  199. self.hift_cache_dict[this_uuid] = None
  200. self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
  201. self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
  202. if stream is True:
  203. token_hop_len = self.token_min_hop_len
  204. while True:
  205. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  206. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  207. .unsqueeze(dim=0)
  208. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  209. prompt_token=flow_prompt_speech_token,
  210. prompt_feat=prompt_speech_feat,
  211. embedding=flow_embedding,
  212. uuid=this_uuid,
  213. finalize=False)
  214. yield {'tts_speech': this_tts_speech.cpu()}
  215. with self.lock:
  216. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  217. # increase token_hop_len for better speech quality
  218. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  219. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  220. break
  221. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  222. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid], dim=1).unsqueeze(dim=0)
  223. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  224. prompt_token=flow_prompt_speech_token,
  225. prompt_feat=prompt_speech_feat,
  226. embedding=flow_embedding,
  227. uuid=this_uuid,
  228. finalize=True)
  229. yield {'tts_speech': this_tts_speech.cpu()}
  230. else:
  231. # deal with all tokens
  232. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  233. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  234. prompt_token=flow_prompt_speech_token,
  235. prompt_feat=prompt_speech_feat,
  236. embedding=flow_embedding,
  237. uuid=this_uuid,
  238. finalize=True,
  239. speed=speed)
  240. yield {'tts_speech': this_tts_speech.cpu()}
  241. with self.lock:
  242. self.tts_speech_token_dict.pop(this_uuid)
  243. self.llm_end_dict.pop(this_uuid)
  244. self.mel_overlap_dict.pop(this_uuid)
  245. self.hift_cache_dict.pop(this_uuid)