model.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import numpy as np
  16. import threading
  17. import time
  18. from contextlib import nullcontext
  19. import uuid
  20. from cosyvoice.utils.common import fade_in_out
  class CosyVoiceModel:
      """Text-to-speech pipeline chaining three sub-models: ``llm`` -> ``flow`` -> ``hift``.

      ``llm.inference`` yields speech tokens for the input text (run in a background
      thread by ``llm_job``); ``flow.inference`` maps those tokens to a mel-like
      feature (``tts_mel``); ``hift.inference`` maps that feature to a waveform.
      ``inference`` is a generator yielding ``{'tts_speech': tensor}`` dicts, either
      once for the whole utterance or chunk by chunk when ``stream=True``.
      """

      def __init__(self,
                   llm: torch.nn.Module,
                   flow: torch.nn.Module,
                   hift: torch.nn.Module):
          """Store the sub-models and set up streaming parameters and per-request state.

          The sub-models are not moved to ``self.device`` here; ``load`` does that.
          """
          self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          self.llm = llm
          self.flow = flow
          self.hift = hift
          # Streaming chunk sizes, counted in speech tokens. The hop starts at the
          # minimum and is multiplied by stream_scale_factor (capped at the maximum)
          # after every emitted chunk.
          self.token_min_hop_len = 100
          self.token_max_hop_len = 400
          # Extra tokens synthesized past each hop so consecutive chunks overlap.
          self.token_overlap_len = 20
          # Length (along dim 1 of the synthesized speech) of the overlap kept
          # between consecutive chunks; the window spans twice that length.
          self.speech_overlap_len = 34 * 256
          self.window = np.hamming(2 * self.speech_overlap_len)
          self.stream_scale_factor = 1
          assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
          # Separate CUDA streams for the LLM and for flow+HiFT (no-op contexts on
          # CPU), presumably so token generation and synthesis can overlap on GPU.
          self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
          self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
          # Guards the per-request dicts below, which are shared with llm_job threads.
          self.lock = threading.Lock()
          # Per-request state keyed by the uuid string created in ``inference``:
          # buffered speech tokens, and whether the LLM thread has finished.
          self.tts_speech_token = {}
          self.llm_end = {}

      def load(self, llm_model, flow_model, hift_model):
          """Load state dicts from the given paths, move models to ``self.device``, set eval mode.

          The LLM is additionally cast to half precision; ``llm_job`` casts the LLM
          embedding to half to match.
          """
          # NOTE(review): torch.load unpickles its input; only load trusted checkpoints.
          self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
          self.llm.to(self.device).eval()
          self.llm.half()
          self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
          self.flow.to(self.device).eval()
          self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
          self.hift.to(self.device).eval()

      def load_jit(self, llm_text_encoder_model, llm_llm_model):
          """Replace ``llm.text_encoder`` and ``llm.llm`` with TorchScript modules loaded from disk."""
          llm_text_encoder = torch.jit.load(llm_text_encoder_model)
          self.llm.text_encoder = llm_text_encoder
          llm_llm = torch.jit.load(llm_llm_model)
          self.llm.llm = llm_llm

      def load_trt(self):
          """Placeholder for TensorRT setup of the flow decoder.

          Raises:
              NotImplementedError: always, until the TensorRT preparation is written.
          """
          # TODO: prepare TensorRT inference here, i.e. build the engine and assign
          # self.flow.decoder.estimator and self.flow.decoder.session.
          raise NotImplementedError('load_trt is a placeholder: build the TensorRT engine and assign '
                                    'self.flow.decoder.estimator and self.flow.decoder.session')

      def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
          """Run ``llm.inference`` and append each yielded token to ``self.tts_speech_token[this_uuid]``.

          Intended to run in a background thread started by ``inference``.
          ``self.llm_end[this_uuid]`` is set to ``True`` when generation finishes,
          including when ``llm.inference`` raises (the exception still propagates).
          """
          try:
              with self.llm_context:
                  for i in self.llm.inference(text=text.to(self.device),
                                              text_len=text_len.to(self.device),
                                              prompt_text=prompt_text.to(self.device),
                                              prompt_text_len=prompt_text_len.to(self.device),
                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                              prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
                                              embedding=llm_embedding.to(self.device).half(),
                                              sampling=25,
                                              max_token_text_ratio=30,
                                              min_token_text_ratio=3):
                      # inference() replaces this list with a slice under self.lock;
                      # appending under the same lock guarantees no token lands in
                      # the old list after it has been sliced (and is lost).
                      with self.lock:
                          self.tts_speech_token[this_uuid].append(i)
          finally:
              # Always signal completion so the streaming loop cannot poll forever.
              self.llm_end[this_uuid] = True

      def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding):
          """Synthesize speech for ``token`` with the flow model then HiFT; return it on CPU.

          ``token_len`` is built as ``[token.size(1)]``, so ``token`` is expected to hold
          the token axis at dim 1 with a single item in the batch (TODO confirm).
          All tensor arguments are moved to ``self.device`` here.
          """
          with self.flow_hift_context:
              tts_mel = self.flow.inference(token=token.to(self.device),
                                            token_len=torch.tensor([token.size(1)], dtype=torch.int32).to(self.device),
                                            prompt_token=prompt_token.to(self.device),
                                            prompt_token_len=prompt_token_len.to(self.device),
                                            prompt_feat=prompt_feat.to(self.device),
                                            prompt_feat_len=prompt_feat_len.to(self.device),
                                            embedding=embedding.to(self.device))
              tts_speech = self.hift.inference(mel=tts_mel).cpu()
          return tts_speech

      def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
                    prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
                    llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                    flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                    prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
          """Synthesize speech for ``text``; a generator of ``{'tts_speech': tensor}`` dicts.

          Speech tokens are produced by ``llm_job`` in a background thread while this
          generator consumes them and synthesizes audio via ``token2wav``.

          Args:
              text, text_len: input text (presumably token ids) and its length, for the LLM.
              flow_embedding: embedding passed to ``flow.inference``.
              llm_embedding: embedding passed, cast to half, to ``llm.inference``.
              prompt_text, prompt_text_len: optional prompt text for the LLM.
              llm_prompt_speech_token, llm_prompt_speech_token_len: optional prompt
                  speech tokens for the LLM.
              flow_prompt_speech_token, flow_prompt_speech_token_len: optional prompt
                  speech tokens for the flow model.
              prompt_speech_feat, prompt_speech_feat_len: optional prompt features for
                  the flow model.
              stream: if ``True``, yield audio in chunks as tokens arrive; otherwise
                  wait for all tokens and yield a single chunk.

          Yields:
              dict: ``{'tts_speech': tensor}``, derived from the HiFT output on CPU.

          Per-request state is removed in ``finally`` (after the LLM thread is joined),
          so it is released even if synthesis raises or the consumer stops early.
          """
          # this_uuid keys this call's entries in the dicts shared with llm_job.
          this_uuid = str(uuid.uuid1())
          with self.lock:
              self.tts_speech_token[this_uuid], self.llm_end[this_uuid] = [], False
          p = threading.Thread(target=self.llm_job,
                               args=(text.to(self.device), text_len.to(self.device),
                                     prompt_text.to(self.device), prompt_text_len.to(self.device),
                                     llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device),
                                     llm_embedding.to(self.device), this_uuid))
          p.start()

          def synthesize(token):
              # Flow + HiFT conditioned on this call's prompt. token2wav already
              # enters self.flow_hift_context and moves tensors to self.device, so
              # neither is repeated here: re-entering the same CUDA StreamContext
              # overwrites its saved stream and leaves this thread on the
              # flow/HiFT stream after the outer block exits.
              return self.token2wav(token=token,
                                    prompt_token=flow_prompt_speech_token,
                                    prompt_token_len=flow_prompt_speech_token_len,
                                    prompt_feat=prompt_speech_feat,
                                    prompt_feat_len=prompt_speech_feat_len,
                                    embedding=flow_embedding)

          try:
              if stream is True:
                  cache_speech, cache_token, token_hop_len = None, None, self.token_min_hop_len
                  while True:
                      time.sleep(0.1)
                      if len(self.tts_speech_token[this_uuid]) >= token_hop_len + self.token_overlap_len:
                          this_tts_speech = synthesize(
                              torch.concat(self.tts_speech_token[this_uuid][:token_hop_len + self.token_overlap_len], dim=1))
                          # Blend with the previous chunk's tail (fade_in_out is a
                          # project helper, presumably a windowed cross-fade).
                          if cache_speech is not None:
                              this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
                          yield {'tts_speech': this_tts_speech[:, :-self.speech_overlap_len]}
                          cache_speech = this_tts_speech[:, -self.speech_overlap_len:]
                          with self.lock:
                              cache_token = self.tts_speech_token[this_uuid][:token_hop_len]
                              self.tts_speech_token[this_uuid] = self.tts_speech_token[this_uuid][token_hop_len:]
                          # increase token_hop_len for better speech quality
                          token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                      if self.llm_end[this_uuid] is True and len(self.tts_speech_token[this_uuid]) < token_hop_len + self.token_overlap_len:
                          break
                  p.join()
                  # Flush the remaining tokens. If fewer than token_min_hop_len +
                  # token_overlap_len remain after at least one chunk, prepend the tail
                  # of the previously consumed tokens as left context, then drop the
                  # audio that corresponds to that borrowed context.
                  this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
                  if this_tts_speech_token.shape[1] < self.token_min_hop_len + self.token_overlap_len and cache_token is not None:
                      cache_token_len = self.token_min_hop_len + self.token_overlap_len - this_tts_speech_token.shape[1]
                      this_tts_speech_token = torch.concat([torch.concat(cache_token[-cache_token_len:], dim=1), this_tts_speech_token], dim=1)
                  else:
                      cache_token_len = 0
                  this_tts_speech = synthesize(this_tts_speech_token)
                  # Trim proportionally, assuming audio length scales with token count.
                  this_tts_speech = this_tts_speech[:, int(cache_token_len / this_tts_speech_token.shape[1] * this_tts_speech.shape[1]):]
                  if cache_speech is not None:
                      this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
                  yield {'tts_speech': this_tts_speech}
              else:
                  # deal with all tokens at once
                  p.join()
                  this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
                  yield {'tts_speech': synthesize(this_tts_speech_token)}
          finally:
              # Wait for llm_job before removing the entries it writes to, so it
              # can never hit a missing key (no-op if already joined).
              p.join()
              with self.lock:
                  self.tts_speech_token.pop(this_uuid, None)
                  self.llm_end.pop(this_uuid, None)
              # synchronize() raises when CUDA is unavailable, so only call it on GPU.
              if torch.cuda.is_available():
                  torch.cuda.synchronize()