model.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import torch
  16. import numpy as np
  17. import threading
  18. import time
  19. from contextlib import nullcontext
  20. import uuid
  21. from cosyvoice.utils.common import fade_in_out
  22. try:
  23. import tensorrt as trt
  24. except ImportError:
  25. ...
  class CosyVoiceModel:
      """Text-to-speech pipeline: LLM (text -> speech tokens), flow (tokens -> mel), HiFT (mel -> waveform).

      The three sub-models are injected by the caller; ``load`` / ``load_jit`` / ``load_trt``
      attach weights or replacement sub-modules. ``inference`` is a generator that yields
      ``{'tts_speech': tensor}`` dicts. Per-call state is kept in dicts keyed by a uuid string
      created inside ``inference``.
      """

      def __init__(self,
                   llm: torch.nn.Module,
                   flow: torch.nn.Module,
                   hift: torch.nn.Module):
          self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          self.llm = llm
          self.flow = flow
          self.hift = hift
          # streaming chunk sizes, counted in speech tokens
          self.token_min_hop_len = 100
          self.token_max_hop_len = 200
          self.token_overlap_len = 20
          # mel fade in out: frames held back per chunk and the window passed to fade_in_out
          self.mel_overlap_len = 34
          self.mel_window = np.hamming(2 * self.mel_overlap_len)
          # hift cache: trailing mel frames / source samples carried into the next chunk
          self.mel_cache_len = 20
          # NOTE(review): 256 presumably = HiFT output samples per mel frame; confirm against the vocoder config
          self.source_cache_len = int(self.mel_cache_len * 256)
          # rtf and decoding related: token_hop_len is multiplied by this after every streamed chunk
          self.stream_scale_factor = 1
          assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
          # separate CUDA streams for the LLM thread and the flow/HiFT stages; no-op contexts on CPU
          self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
          self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
          self.lock = threading.Lock()
          # dict used to store session related variable, keyed by the per-call uuid
          self.tts_speech_token_dict = {}
          self.llm_end_dict = {}
          self.mel_overlap_dict = {}
          self.hift_cache_dict = {}

      def load(self, llm_model, flow_model, hift_model):
          """Load state dicts from the given checkpoint paths and move the models to ``self.device`` in eval mode.

          The LLM is additionally cast to half precision; ``llm_job`` feeds it a half-precision embedding.
          """
          # NOTE: torch.load unpickles its input; only pass trusted checkpoint files
          self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
          self.llm.to(self.device).eval()
          self.llm.half()
          self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
          self.flow.to(self.device).eval()
          self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
          self.hift.to(self.device).eval()

      def load_jit(self, llm_text_encoder_model, llm_llm_model):
          """Replace ``llm.text_encoder`` and ``llm.llm`` with TorchScript modules loaded from the given paths."""
          llm_text_encoder = torch.jit.load(llm_text_encoder_model)
          self.llm.text_encoder = llm_text_encoder
          llm_llm = torch.jit.load(llm_llm_model)
          self.llm.llm = llm_llm

      def load_trt(self, model_dir, use_fp16):
          """Attach a serialized TensorRT engine as the flow decoder's estimator.

          Reads ``estimator_fp16.plan`` (``use_fp16``) or ``estimator_fp32.plan`` from ``model_dir``.

          Raises:
              FileNotFoundError: if the .plan file does not exist.
              ImportError: if the ``tensorrt`` package is not installed.
          """
          trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
          trt_file_path = os.path.join(model_dir, trt_file_name)
          if not os.path.isfile(trt_file_path):
              # raising a plain str is a TypeError in Python 3; raise a real exception instead
              raise FileNotFoundError(f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file")
          # the module-level tensorrt import is optional (ImportError is swallowed there);
          # importing here yields a clear ImportError instead of a NameError on `trt`
          import tensorrt as trt
          trt.init_libnvinfer_plugins(None, "")
          logger = trt.Logger(trt.Logger.WARNING)
          runtime = trt.Runtime(logger)
          with open(trt_file_path, 'rb') as f:
              serialized_engine = f.read()
          engine = runtime.deserialize_cuda_engine(serialized_engine)
          self.flow.decoder.estimator_context = engine.create_execution_context()
          self.flow.decoder.estimator_engine = engine

      def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
          """Thread target: run the LLM and append each yielded speech token to ``tts_speech_token_dict[uuid]``.

          Sets ``llm_end_dict[uuid] = True`` only after the LLM generator is exhausted; if
          ``self.llm.inference`` raises, the flag stays False.
          """
          with self.llm_context:
              for i in self.llm.inference(text=text.to(self.device),
                                          text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                          prompt_text=prompt_text.to(self.device),
                                          prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                          prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                          prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                          embedding=llm_embedding.to(self.device).half(),
                                          sampling=25,
                                          max_token_text_ratio=30,
                                          min_token_text_ratio=3):
                  self.tts_speech_token_dict[uuid].append(i)
          self.llm_end_dict[uuid] = True

      def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
          """Convert one chunk of speech tokens to a waveform for session ``uuid``.

          With ``finalize=False`` the tail of the chunk is held back for the next call: the last
          ``mel_overlap_len`` mel frames go to ``mel_overlap_dict[uuid]`` (cross-faded via
          ``fade_in_out`` next time), the last ``mel_cache_len`` frames / ``source_cache_len``
          source samples go to ``hift_cache_dict[uuid]``, and the last ``source_cache_len``
          output samples are dropped. With ``finalize=True`` nothing is held back.
          """
          with self.flow_hift_context:
              tts_mel = self.flow.inference(token=token.to(self.device),
                                            token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_token=prompt_token.to(self.device),
                                            prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_feat=prompt_feat.to(self.device),
                                            prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                            embedding=embedding.to(self.device))
              # mel overlap fade in out with the tail held back by the previous chunk
              if self.mel_overlap_dict[uuid] is not None:
                  tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
              # append hift cache
              if self.hift_cache_dict[uuid] is not None:
                  hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
                  tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
              else:
                  hift_cache_source = torch.zeros(1, 1, 0)
              if not finalize:
                  # keep overlap mel for the next chunk's fade in
                  self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
                  tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
              tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
              if not finalize:
                  # keep hift cache and drop the samples the next chunk will regenerate from it
                  self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
                  tts_speech = tts_speech[:, :-self.source_cache_len]
          return tts_speech

      def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
                    prompt_text=torch.zeros(1, 0, dtype=torch.int32),
                    llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                    flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                    prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
          """Synthesize speech for ``text``; generator yielding ``{'tts_speech': cpu_tensor}`` dicts.

          With ``stream=True`` tokens are converted in hops of ``token_hop_len`` tokens plus
          ``token_overlap_len`` overlap, one dict per hop, followed by one final dict for the
          remaining tokens. Otherwise a single dict covers all tokens. The tensor defaults are
          shared across calls but only read (``.to`` returns new tensors), never mutated.
          ``**kwargs`` is accepted and ignored.

          Raises:
              RuntimeError: if the LLM thread exits without finishing token generation.
          """
          # this_uuid is used to track variables related to this inference thread
          this_uuid = str(uuid.uuid1())
          with self.lock:
              self.tts_speech_token_dict[this_uuid] = []
              self.llm_end_dict[this_uuid] = False
              self.mel_overlap_dict[this_uuid] = None
              self.hift_cache_dict[this_uuid] = None

          def synthesize(tokens, finalize):
              # token2wav enters flow_hift_context itself; wrapping the call in the same
              # StreamContext again overwrote its saved previous stream and left this
              # thread on the flow stream after the block exited
              return self.token2wav(token=tokens,
                                    prompt_token=flow_prompt_speech_token,
                                    prompt_feat=prompt_speech_feat,
                                    embedding=flow_embedding,
                                    uuid=this_uuid,
                                    finalize=finalize)

          try:
              p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
              p.start()
              # NOTE(review): joining here lets the LLM finish all tokens before any audio is
              # produced, so stream=True chunks the output but does not reduce first-chunk
              # latency. Kept as in the original; confirm this is intended.
              p.join()
              if self.llm_end_dict[this_uuid] is not True:
                  # llm_job sets the flag only after its token loop completes, so a missing flag
                  # means the thread died with an exception; without this check the stream loop
                  # below would wait for the end flag forever
                  raise RuntimeError('llm_job exited before finishing token generation')
              if stream is True:
                  token_hop_len = self.token_min_hop_len
                  while True:
                      if p.is_alive():
                          # only poll for new tokens while the LLM thread can still produce them
                          time.sleep(0.1)
                      if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                          this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                          this_tts_speech = synthesize(this_tts_speech_token, finalize=False)
                          yield {'tts_speech': this_tts_speech.cpu()}
                          with self.lock:
                              self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                          # increase token_hop_len for better speech quality
                          token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                      if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                          break
                  # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
                  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
                  this_tts_speech = synthesize(this_tts_speech_token, finalize=True)
                  yield {'tts_speech': this_tts_speech.cpu()}
              else:
                  # deal with all tokens
                  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
                  this_tts_speech = synthesize(this_tts_speech_token, finalize=True)
                  yield {'tts_speech': this_tts_speech.cpu()}
          finally:
              # release per-call state even when the consumer stops iterating early or an error is raised
              with self.lock:
                  for session_dict in (self.tts_speech_token_dict, self.llm_end_dict, self.mel_overlap_dict, self.hift_cache_dict):
                      session_dict.pop(this_uuid, None)
          # torch.cuda.synchronize() raises without CUDA, while __init__ explicitly supports CPU
          if torch.cuda.is_available():
              torch.cuda.synchronize()