# model.py
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import threading
import time
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
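

# CosyVoiceModel chains three stages at inference time:
#   llm  : autoregressive model that turns text tokens into speech tokens (llm_job)
#   flow : flow-matching model that turns speech tokens into mel spectrograms (token2wav)
#   hift : HiFT vocoder that turns mel spectrograms (plus a source cache) into waveforms (token2wav)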
class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.token_min_hop_len = 100
        self.token_max_hop_len = 200
        self.token_overlap_len = 20
        # mel fade in out
        self.mel_overlap_len = 34
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dicts used to store per-session (per-uuid) variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.hift_cache_dict = {}
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
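
    # Note: load() puts the LLM in fp16 via .half(); llm_job() casts the llm
    # embedding to half to match, while flow and hift stay in fp32.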
    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
        self.llm.to(self.device).eval()
        self.llm.half()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
        self.flow.to(self.device).eval()
        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
        self.hift.to(self.device).eval()
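
    # Optionally swap in TorchScript-exported submodules loaded with
    # torch.jit.load for the text encoder, LLM, and flow encoder.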
    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        llm_text_encoder = torch.jit.load(llm_text_encoder_model)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model)
        self.flow.encoder = flow_encoder
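
    # Optionally replace the flow decoder's estimator with an ONNX Runtime
    # InferenceSession (CUDA execution provider when available, else CPU).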
    def load_onnx(self, flow_decoder_estimator_model):
        import onnxruntime
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
        del self.flow.decoder.estimator
        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
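
    # Producer side of the streaming pipeline: runs on a background thread and
    # appends speech tokens to this session's list as the LLM generates them,
    # then flags completion via llm_end_dict.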
    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context:
            for i in self.llm.inference(text=text.to(self.device),
                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_text=prompt_text.to(self.device),
                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                        embedding=llm_embedding.to(self.device).half(),
                                        sampling=25,
                                        max_token_text_ratio=30,
                                        min_token_text_ratio=3):
                self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True
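
    # Consumer side: convert a chunk of speech tokens to a waveform. For
    # non-final chunks the last mel_overlap_len mel frames are held back and
    # cross-faded into the next chunk, and the hift mel/source/speech caches
    # are carried over so chunk boundaries stay smooth.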
    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
        tts_mel = self.flow.inference(token=token.to(self.device),
                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                      prompt_token=prompt_token.to(self.device),
                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                      prompt_feat=prompt_feat.to(self.device),
                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                      embedding=embedding.to(self.device))
        # mel overlap fade in out
        if self.mel_overlap_dict[uuid] is not None:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
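
    # End-to-end synthesis for one request: llm_job fills
    # tts_speech_token_dict[this_uuid] on a background thread while this
    # generator drains it, chunk by chunk (stream=True) or all at once.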
    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
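        # streaming mode: poll the token buffer every 100ms and emit a chunk
        # whenever token_hop_len + token_overlap_len tokens are available; the
        # hop is scaled by stream_scale_factor each round, capped at token_max_hop_len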
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with the remaining tokens; make sure the remaining token
            # length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens at once
            p.join()
            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
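

# Usage sketch (illustrative; `build_modules`, the checkpoint paths, and the
# input tensors below are assumptions, not part of this file):
#
#   llm, flow, hift = build_modules(config)  # hypothetical factory
#   model = CosyVoiceModel(llm, flow, hift)
#   model.load('llm.pt', 'flow.pt', 'hift.pt')
#   for chunk in model.inference(text=text_tokens, flow_embedding=spk_embedding, stream=True):
#       handle(chunk['tts_speech'])  # (1, num_samples) float tensor on CPU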