model.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import torch
  16. import numpy as np
  17. import threading
  18. import time
  19. from contextlib import nullcontext
  20. import uuid
  21. from cosyvoice.utils.common import fade_in_out
  22. class CosyVoiceModel:
  23. def __init__(self,
  24. llm: torch.nn.Module,
  25. flow: torch.nn.Module,
  26. hift: torch.nn.Module):
  27. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  28. self.llm = llm
  29. self.flow = flow
  30. self.hift = hift
  31. self.token_min_hop_len = 100
  32. self.token_max_hop_len = 200
  33. self.token_overlap_len = 20
  34. # mel fade in out
  35. self.mel_overlap_len = 34
  36. self.mel_window = np.hamming(2 * self.mel_overlap_len)
  37. # hift cache
  38. self.mel_cache_len = 20
  39. self.source_cache_len = int(self.mel_cache_len * 256)
  40. # rtf and decoding related
  41. self.stream_scale_factor = 1
  42. assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
  43. self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  44. self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
  45. self.lock = threading.Lock()
  46. # dict used to store session related variable
  47. self.tts_speech_token_dict = {}
  48. self.llm_end_dict = {}
  49. self.mel_overlap_dict = {}
  50. self.hift_cache_dict = {}
  51. def load(self, llm_model, flow_model, hift_model):
  52. self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
  53. self.llm.to(self.device).eval()
  54. self.llm.half()
  55. self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
  56. self.flow.to(self.device).eval()
  57. self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
  58. self.hift.to(self.device).eval()
  59. def load_jit(self, llm_text_encoder_model, llm_llm_model):
  60. llm_text_encoder = torch.jit.load(llm_text_encoder_model)
  61. self.llm.text_encoder = llm_text_encoder
  62. llm_llm = torch.jit.load(llm_llm_model)
  63. self.llm.llm = llm_llm
  64. def load_trt(self, model_dir, use_fp16):
  65. import tensorrt as trt
  66. trt_file_name = 'estimator_fp16.plan' if use_fp16 else 'estimator_fp32.plan'
  67. trt_file_path = os.path.join(model_dir, trt_file_name)
  68. if not os.path.isfile(trt_file_path):
  69. raise f"{trt_file_path} does not exist. Please use bin/export_trt.py to generate .plan file"
  70. trt.init_libnvinfer_plugins(None, "")
  71. logger = trt.Logger(trt.Logger.WARNING)
  72. runtime = trt.Runtime(logger)
  73. with open(trt_file_path, 'rb') as f:
  74. serialized_engine = f.read()
  75. engine = runtime.deserialize_cuda_engine(serialized_engine)
  76. self.flow.decoder.estimator_context = engine.create_execution_context()
  77. self.flow.decoder.estimator = None
  78. def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
  79. with self.llm_context:
  80. for i in self.llm.inference(text=text.to(self.device),
  81. text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
  82. prompt_text=prompt_text.to(self.device),
  83. prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
  84. prompt_speech_token=llm_prompt_speech_token.to(self.device),
  85. prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
  86. embedding=llm_embedding.to(self.device).half(),
  87. sampling=25,
  88. max_token_text_ratio=30,
  89. min_token_text_ratio=3):
  90. self.tts_speech_token_dict[uuid].append(i)
  91. self.llm_end_dict[uuid] = True
  92. def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
  93. with self.flow_hift_context:
  94. tts_mel = self.flow.inference(token=token.to(self.device),
  95. token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
  96. prompt_token=prompt_token.to(self.device),
  97. prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
  98. prompt_feat=prompt_feat.to(self.device),
  99. prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
  100. embedding=embedding.to(self.device))
  101. # mel overlap fade in out
  102. if self.mel_overlap_dict[uuid] is not None:
  103. tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
  104. # append hift cache
  105. if self.hift_cache_dict[uuid] is not None:
  106. hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
  107. tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
  108. else:
  109. hift_cache_source = torch.zeros(1, 1, 0)
  110. # keep overlap mel and hift cache
  111. if finalize is False:
  112. self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
  113. tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
  114. tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
  115. self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
  116. tts_speech = tts_speech[:, :-self.source_cache_len]
  117. else:
  118. tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
  119. return tts_speech
  120. def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
  121. prompt_text=torch.zeros(1, 0, dtype=torch.int32),
  122. llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  123. flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
  124. prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
  125. # this_uuid is used to track variables related to this inference thread
  126. this_uuid = str(uuid.uuid1())
  127. with self.lock:
  128. self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
  129. p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
  130. p.start()
  131. p.join()
  132. if stream is True:
  133. token_hop_len = self.token_min_hop_len
  134. while True:
  135. time.sleep(0.1)
  136. if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
  137. this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
  138. with self.flow_hift_context:
  139. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  140. prompt_token=flow_prompt_speech_token,
  141. prompt_feat=prompt_speech_feat,
  142. embedding=flow_embedding,
  143. uuid=this_uuid,
  144. finalize=False)
  145. yield {'tts_speech': this_tts_speech.cpu()}
  146. with self.lock:
  147. self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
  148. # increase token_hop_len for better speech quality
  149. token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
  150. if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
  151. break
  152. # p.join()
  153. # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
  154. this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
  155. with self.flow_hift_context:
  156. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  157. prompt_token=flow_prompt_speech_token,
  158. prompt_feat=prompt_speech_feat,
  159. embedding=flow_embedding,
  160. uuid=this_uuid,
  161. finalize=True)
  162. yield {'tts_speech': this_tts_speech.cpu()}
  163. else:
  164. # deal with all tokens
  165. # p.join()
  166. this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
  167. with self.flow_hift_context:
  168. this_tts_speech = self.token2wav(token=this_tts_speech_token,
  169. prompt_token=flow_prompt_speech_token,
  170. prompt_feat=prompt_speech_feat,
  171. embedding=flow_embedding,
  172. uuid=this_uuid,
  173. finalize=True)
  174. yield {'tts_speech': this_tts_speech.cpu()}
  175. with self.lock:
  176. self.tts_speech_token_dict.pop(this_uuid)
  177. self.llm_end_dict.pop(this_uuid)
  178. self.mel_overlap_dict.pop(this_uuid)
  179. self.hift_cache_dict.pop(this_uuid)
  180. torch.cuda.synchronize()