model.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import torch
  15. import numpy as np
  16. import threading
  17. import time
  18. from torch.nn import functional as F
  19. from contextlib import nullcontext
  20. import uuid
  21. from cosyvoice.utils.common import fade_in_out
  22. class CosyVoiceModel:
  23. def __init__(self,
  24. llm: torch.nn.Module,
  25. flow: torch.nn.Module,
  26. hift: torch.nn.Module):
  27. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  28. self.llm = llm
  29. self.flow = flow
  30. self.hift = hift
  31. self.token_min_hop_len = 2 * self.flow.input_frame_rate
  32. self.token_max_hop_len = 4 * self.flow.input_frame_rate
  33. self.token_overlap_len = 20
  34. # mel fade in out
  35. self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
  36. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  37. # hift cache
  38. self.mel_cache_len = 20
  39. self.source_cache_len = int(self.mel_cache_len * 256)
  40. # speech fade in out
  41. self.speech_window = np.hamming(2 * self.source_cache_len)
  42. # rtf and decoding related
  43. self.stream_scale_factor = 1
  44. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  45. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  46. self.lock = threading.Lock()
  47. # dict used to store session related variable
  48. self.tts_speech_token_dict = {}
  49. self.llm_end_dict = {}
  50. self.mel_overlap_dict = {}
  51. self.hift_cache_dict = {}
  52. def load(self, llm_model, flow_model, hift_model):
  53. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
  54. self.llm.to(self.device).eval()
  55. self.llm.half()
  56. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
  57. self.flow.to(self.device).eval()
  58. self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
  59. self.hift.to(self.device).eval()
  60. def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
  61. llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
  62. self.llm.text_encoder = llm_text_encoder
  63. llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
  64. self.llm.llm = llm_llm
  65. flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
  66. self.flow.encoder = flow_encoder
  67. def load_onnx(self, flow_decoder_estimator_model):
  68. import onnxruntime
  69. option = onnxruntime.SessionOptions()
  70. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  71. option.intra_op_num_threads = 1
  72. providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
  73. del self.flow.decoder.estimator
  74. self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
  75. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  76. with self.llm_context:
  77. for i in self.llm.inference(text=text.to(self.device),
  78. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  79. prompt_text=prompt_text.to(self.device),
  80. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  81. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  82. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  83. embedding=llm_embedding.to(self.device).half()):
  84. self.tts_speech_token_dict[uuid].append(i)
  85. self.llm_end_dict[uuid] = True
  86. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
  87. tts_mel = self.flow.inference(token=token.to(self.device),
  88. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  89. prompt_token=prompt_token.to(self.device),
  90. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  91. prompt_feat=prompt_feat.to(self.device),
  92. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  93. embedding=embedding.to(self.device))
  94. # mel overlap fade in out
  95. if self.mel_overlap_dict[uuid] is not None:
  96. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  97. # append hift cache
  98. if self.hift_cache_dict[uuid] is not None:
  99. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  100. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  101. else:
  102. hift_cache_source = torch.zeros(1, 1, 0)
  103. # keep overlap mel and hift cache
  104. if finalize is False:
  105. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  106. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  107. tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
  108. if self.hift_cache_dict[uuid] is not None:
  109. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  110. self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
  111. 'source': tts_source[:, :, -self.source_cache_len:],
  112. 'speech': tts_speech[:, -self.source_cache_len:]}
  113. tts_speech = tts_speech[:, :-self.source_cache_len]
  114. else:
  115. if speed != 1.0:
  116. assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
  117. tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
  118. tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
  119. if self.hift_cache_dict[uuid] is not None:
  120. tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
  121. return tts_speech
  122. def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  123. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  124. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  125. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  126. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
  127. # this_uuid is used to track variables related to this inference thread
  128. this_uuid = str(uuid.uuid1())
  129. with self.lock:
  130. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
  131. self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
  132. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  133. p.start()
  134. if stream is True:
  135. token_hop_len = self.token_min_hop_len
  136. while True:
  137. time.sleep(0.1)
  138. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  139. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  140. .unsqueeze(dim=0)
  141. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  142. prompt_token=flow_prompt_speech_token,
  143. prompt_feat=prompt_speech_feat,
  144. embedding=flow_embedding,
  145. uuid=this_uuid,
  146. finalize=False)
  147. yield {'tts_speech': this_tts_speech.cpu()}
  148. with self.lock:
  149. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  150. # increase token_hop_len for better speech quality
  151. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  152. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  153. break
  154. p.join()
  155. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  156. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  157. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  158. prompt_token=flow_prompt_speech_token,
  159. prompt_feat=prompt_speech_feat,
  160. embedding=flow_embedding,
  161. uuid=this_uuid,
  162. finalize=True)
  163. yield {'tts_speech': this_tts_speech.cpu()}
  164. else:
  165. # deal with all tokens
  166. p.join()
  167. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  168. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  169. prompt_token=flow_prompt_speech_token,
  170. prompt_feat=prompt_speech_feat,
  171. embedding=flow_embedding,
  172. uuid=this_uuid,
  173. finalize=True,
  174. speed=speed)
  175. yield {'tts_speech': this_tts_speech.cpu()}
  176. with self.lock:
  177. self.tts_speech_token_dict.pop(this_uuid)
  178. self.llm_end_dict.pop(this_uuid)
  179. self.mel_overlap_dict.pop(this_uuid)
  180. self.hift_cache_dict.pop(this_uuid)
  181. def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
  182. # this_uuid is used to track variables related to this inference thread
  183. this_uuid = str(uuid.uuid1())
  184. with self.lock:
  185. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
  186. self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
  187. if stream is True:
  188. token_hop_len = self.token_min_hop_len
  189. while True:
  190. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  191. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
  192. .unsqueeze(dim=0)
  193. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  194. prompt_token=flow_prompt_speech_token,
  195. prompt_feat=prompt_speech_feat,
  196. embedding=flow_embedding,
  197. uuid=this_uuid,
  198. finalize=False)
  199. yield {'tts_speech': this_tts_speech.cpu()}
  200. with self.lock:
  201. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  202. # increase token_hop_len for better speech quality
  203. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  204. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  205. break
  206. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  207. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid], dim=1).unsqueeze(dim=0)
  208. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  209. prompt_token=flow_prompt_speech_token,
  210. prompt_feat=prompt_speech_feat,
  211. embedding=flow_embedding,
  212. uuid=this_uuid,
  213. finalize=True)
  214. yield {'tts_speech': this_tts_speech.cpu()}
  215. else:
  216. # deal with all tokens
  217. this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
  218. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  219. prompt_token=flow_prompt_speech_token,
  220. prompt_feat=prompt_speech_feat,
  221. embedding=flow_embedding,
  222. uuid=this_uuid,
  223. finalize=True,
  224. speed=speed)
  225. yield {'tts_speech': this_tts_speech.cpu()}
  226. with self.lock:
  227. self.tts_speech_token_dict.pop(this_uuid)
  228. self.llm_end_dict.pop(this_uuid)
  229. self.mel_overlap_dict.pop(this_uuid)
  230. self.hift_cache_dict.pop(this_uuid)