# model.py
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
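"""Inference wrapper for CosyVoice text-to-speech.

CosyVoiceModel ties together the three checkpointed sub-modules: an
autoregressive LLM that generates speech tokens, a flow model that decodes
those tokens into mel spectrograms, and the HiFT vocoder that renders the
final waveform. For streaming synthesis, consecutive chunks are cross-faded
using the mel overlap window and the per-uuid caches kept below.
"""
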
import os
import threading
import time
import uuid
from contextlib import nullcontext

import numpy as np
import torch
import onnxruntime as ort

from cosyvoice.utils.common import fade_in_out
# try:
#     import tensorrt as trt
# except ImportError:
#     ...


class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.token_min_hop_len = 100
        self.token_max_hop_len = 200
        self.token_overlap_len = 20
        # mel fade in out
        self.mel_overlap_len = 34
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache (256 waveform samples are kept per cached mel frame)
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dict used to store session related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.hift_cache_dict = {}

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
        self.llm.to(self.device).eval()
        self.llm.half()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
        self.flow.to(self.device).eval()
        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model):
        llm_text_encoder = torch.jit.load(llm_text_encoder_model)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model)
        self.llm.llm = llm_llm

    # def load_trt(self, model_dir, use_fp16):
    #     trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
    #     trt_file_path = os.path.join(model_dir, trt_file_name)
    #     if not os.path.isfile(trt_file_path):
    #         raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
    #     trt.init_libnvinfer_plugins(None, "")
    #     logger = trt.Logger(trt.Logger.WARNING)
    #     runtime = trt.Runtime(logger)
    #     with open(trt_file_path, 'rb') as f:
    #         serialized_engine = f.read()
    #     engine = runtime.deserialize_cuda_engine(serialized_engine)
    #     self.flow.decoder.estimator_context = engine.create_execution_context()
    #     self.flow.decoder.estimator = None

    def load_onnx(self, model_dir, use_fp16):
        onnx_file_name = 'estimator_fp16.onnx' if use_fp16 else 'estimator_fp32.onnx'
        onnx_file_path = os.path.join(model_dir, onnx_file_name)
        if not os.path.isfile(onnx_file_path):
            raise FileNotFoundError(f"{onnx_file_path} does not exist. Please use bin/export_trt.py to generate the .onnx file")
        sess_options = ort.SessionOptions()
        providers = ['CUDAExecutionProvider']
        # Load the ONNX model
        self.flow.decoder.session = ort.InferenceSession(onnx_file_path, sess_options=sess_options, providers=providers)
        # self.flow.decoder.estimator_context = None
        self.flow.decoder.estimator = None

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context:
            for i in self.llm.inference(text=text.to(self.device),
                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_text=prompt_text.to(self.device),
                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                        embedding=llm_embedding.to(self.device).half(),
                                        sampling=25,
                                        max_token_text_ratio=30,
                                        min_token_text_ratio=3):
                self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
        with self.flow_hift_context:
            tts_mel = self.flow.inference(token=token.to(self.device),
                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                          prompt_token=prompt_token.to(self.device),
                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                          prompt_feat=prompt_feat.to(self.device),
                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                          embedding=embedding.to(self.device))
            # mel overlap fade in out
            if self.mel_overlap_dict[uuid] is not None:
                tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
            # append hift cache
            if self.hift_cache_dict[uuid] is not None:
                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
            else:
                hift_cache_source = torch.zeros(1, 1, 0)
            # keep overlap mel and hift cache
            if finalize is False:
                self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
                tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
                self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
                tts_speech = tts_speech[:, :-self.source_cache_len]
            else:
                tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
        return tts_speech

    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        # block until the llm job has produced all speech tokens
        p.join()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                    with self.flow_hift_context:
                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                         prompt_token=flow_prompt_speech_token,
                                                         prompt_feat=prompt_speech_feat,
                                                         embedding=flow_embedding,
                                                         uuid=this_uuid,
                                                         finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            # p.join()
            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
            with self.flow_hift_context:
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 uuid=this_uuid,
                                                 finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            # p.join()
            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
            with self.flow_hift_context:
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 uuid=this_uuid,
                                                 finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
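

# Usage sketch (illustrative, not part of the upstream module). The checkpoint
# paths, the `process` callback and the input tensors below are hypothetical;
# real text token ids, prompt features and speaker embeddings come from the
# CosyVoice frontend, and llm/flow/hift are the torch modules built from the
# repo's model config before being handed to CosyVoiceModel.
#
# model = CosyVoiceModel(llm, flow, hift)
# model.load('pretrained_models/CosyVoice-300M/llm.pt',      # hypothetical checkpoint paths
#            'pretrained_models/CosyVoice-300M/flow.pt',
#            'pretrained_models/CosyVoice-300M/hift.pt')
# for chunk in model.inference(text=text_tokens,              # (1, T) int32 token ids
#                              flow_embedding=spk_embedding,  # speaker embedding tensor
#                              stream=True):
#     process(chunk['tts_speech'])  # (1, num_samples) waveform chunk on CPU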