model.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import torch
  16. import numpy as np
  17. import threading
  18. import time
  19. from contextlib import nullcontext
  20. import uuid
  21. from cosyvoice.utils.common import fade_in_out
  22. try:
  23. import tensorrt as trt
  24. except ImportError:
  25. ...
  26. class CosyVoiceModel:
  27. def __init__(self,
  28. llm: torch.nn.Module,
  29. flow: torch.nn.Module,
  30. hift: torch.nn.Module):
  31. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  32. self.llm = llm
  33. self.flow = flow
  34. self.hift = hift
  35. self.token_min_hop_len = 100
  36. self.token_max_hop_len = 200
  37. self.token_overlap_len = 20
  38. # mel fade in out
  39. self.mel_overlap_len = 34
  40. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  41. # hift cache
  42. self.mel_cache_len = 20
  43. self.source_cache_len = int(self.mel_cache_len * 256)
  44. # rtf and decoding related
  45. self.stream_scale_factor = 1
  46. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  47. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  48. self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  49. self.lock = threading.Lock()
  50. # dict used to store session related variable
  51. self.tts_speech_token_dict = {}
  52. self.llm_end_dict = {}
  53. self.mel_overlap_dict = {}
  54. self.hift_cache_dict = {}
  55. def load(self, llm_model, flow_model, hift_model):
  56. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
  57. self.llm.to(self.device).eval()
  58. self.llm.half()
  59. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
  60. self.flow.to(self.device).eval()
  61. self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
  62. self.hift.to(self.device).eval()
  63. def load_jit(self, llm_text_encoder_model, llm_llm_model):
  64. llm_text_encoder = torch.jit.load(llm_text_encoder_model)
  65. self.llm.text_encoder = llm_text_encoder
  66. llm_llm = torch.jit.load(llm_llm_model)
  67. self.llm.llm = llm_llm
  68. def load_trt(self, model_dir, use_fp16):
  69. trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
  70. trt_file_path = os.path.join(model_dir, trt_file_name)
  71. if not os.path.isfile(trt_file_path):
  72. raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
  73. trt.init_libnvinfer_plugins(None, "")
  74. logger = trt.Logger(trt.Logger.WARNING)
  75. runtime = trt.Runtime(logger)
  76. with open(trt_file_path, 'rb') as f:
  77. serialized_engine = f.read()
  78. engine = runtime.deserialize_cuda_engine(serialized_engine)
  79. self.flow.decoder.estimator = engine.create_execution_context()
  80. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  81. with self.llm_context:
  82. for i in self.llm.inference(text=text.to(self.device),
  83. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  84. prompt_text=prompt_text.to(self.device),
  85. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  86. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  87. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  88. embedding=llm_embedding.to(self.device).half(),
  89. sampling=25,
  90. max_token_text_ratio=30,
  91. min_token_text_ratio=3):
  92. self.tts_speech_token_dict[uuid].append(i)
  93. self.llm_end_dict[uuid] = True
  94. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
  95. with self.flow_hift_context:
  96. tts_mel = self.flow.inference(token=token.to(self.device),
  97. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  98. prompt_token=prompt_token.to(self.device),
  99. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  100. prompt_feat=prompt_feat.to(self.device),
  101. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  102. embedding=embedding.to(self.device))
  103. # mel overlap fade in out
  104. if self.mel_overlap_dict[uuid] is not None:
  105. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  106. # append hift cache
  107. if self.hift_cache_dict[uuid] is not None:
  108. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  109. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  110. else:
  111. hift_cache_source = torch.zeros(1, 1, 0)
  112. # keep overlap mel and hift cache
  113. if finalize is False:
  114. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  115. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  116. tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
  117. self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
  118. tts_speech = tts_speech[:, :-self.source_cache_len]
  119. else:
  120. tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
  121. return tts_speech
  122. def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  123. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  124. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  125. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  126. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
  127. # this_uuid is used to track variables related to this inference thread
  128. this_uuid = str(uuid.uuid1())
  129. with self.lock:
  130. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
  131. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  132. p.start()
  133. p.join()
  134. if stream is True:
  135. token_hop_len = self.token_min_hop_len
  136. while True:
  137. time.sleep(0.1)
  138. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  139. this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
  140. with self.flow_hift_context:
  141. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  142. prompt_token=flow_prompt_speech_token,
  143. prompt_feat=prompt_speech_feat,
  144. embedding=flow_embedding,
  145. uuid=this_uuid,
  146. finalize=False)
  147. yield {'tts_speech': this_tts_speech.cpu()}
  148. with self.lock:
  149. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  150. # increase token_hop_len for better speech quality
  151. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  152. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  153. break
  154. # p.join()
  155. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  156. this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
  157. with self.flow_hift_context:
  158. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  159. prompt_token=flow_prompt_speech_token,
  160. prompt_feat=prompt_speech_feat,
  161. embedding=flow_embedding,
  162. uuid=this_uuid,
  163. finalize=True)
  164. yield {'tts_speech': this_tts_speech.cpu()}
  165. else:
  166. # deal with all tokens
  167. # p.join()
  168. this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
  169. with self.flow_hift_context:
  170. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  171. prompt_token=flow_prompt_speech_token,
  172. prompt_feat=prompt_speech_feat,
  173. embedding=flow_embedding,
  174. uuid=this_uuid,
  175. finalize=True)
  176. yield {'tts_speech': this_tts_speech.cpu()}
  177. with self.lock:
  178. self.tts_speech_token_dict.pop(this_uuid)
  179. self.llm_end_dict.pop(this_uuid)
  180. self.mel_overlap_dict.pop(this_uuid)
  181. self.hift_cache_dict.pop(this_uuid)
  182. torch.cuda.synchronize()