# model.py
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import threading
import time
from contextlib import nullcontext


class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        # streaming window/hop sizes, measured in speech tokens
        self.stream_win_len = 60 * 4
        self.stream_hop_len = 50 * 4
        # overlap between consecutive waveform chunks, in samples
        # (10 tokens equal 4395 samples, so 4395 * 4 covers 40 tokens)
        self.overlap = 4395 * 4
        # Hamming window whose rising/falling halves cross-fade chunk boundaries
        self.window = np.hamming(2 * self.overlap)
        # separate CUDA streams let the llm and flow/hift stages run concurrently
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        # guards self.tts_speech_token, which is shared with the llm thread
        self.lock = threading.Lock()

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
        self.flow.to(self.device).eval()
        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
        self.hift.to(self.device).eval()

    def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding):
        # producer: runs in a background thread and appends speech tokens to
        # self.tts_speech_token as the llm generates them
        with self.llm_context:
            for i in self.llm.inference(text=text.to(self.device),
                                        text_len=text_len.to(self.device),
                                        prompt_text=prompt_text.to(self.device),
                                        prompt_text_len=prompt_text_len.to(self.device),
                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                        prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
                                        embedding=llm_embedding.to(self.device),
                                        beam_size=1,
                                        sampling=25,
                                        max_token_text_ratio=30,
                                        min_token_text_ratio=3,
                                        stream=True):
                self.tts_speech_token.append(i)
        self.llm_end = True

    def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding):
        # speech tokens -> mel spectrogram (flow) -> waveform (hift)
        with self.flow_hift_context:
            tts_mel = self.flow.inference(token=token.to(self.device),
                                          token_len=torch.tensor([token.size(1)], dtype=torch.int32).to(self.device),
                                          prompt_token=prompt_token.to(self.device),
                                          prompt_token_len=prompt_token_len.to(self.device),
                                          prompt_feat=prompt_feat.to(self.device),
                                          prompt_feat_len=prompt_feat_len.to(self.device),
                                          embedding=embedding.to(self.device))
            tts_speech = self.hift.inference(mel=tts_mel).cpu()
        return tts_speech

    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
        if stream is True:
            self.tts_speech_token, self.llm_end, cache_speech = [], False, None
            # consumer: the llm thread produces tokens while this loop
            # synthesizes and yields overlapping waveform chunks
            p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
                                                            llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device)))
            p.start()
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token) >= self.stream_win_len:
                    this_tts_speech_token = torch.concat(self.tts_speech_token[:self.stream_win_len], dim=1)
                    with self.flow_hift_context:
                        this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                         prompt_token=flow_prompt_speech_token.to(self.device),
                                                         prompt_token_len=flow_prompt_speech_token_len.to(self.device),
                                                         prompt_feat=prompt_speech_feat.to(self.device),
                                                         prompt_feat_len=prompt_speech_feat_len.to(self.device),
                                                         embedding=flow_embedding.to(self.device))
                    # cross-fade with the cached tail of the previous chunk, if any
                    if cache_speech is not None:
                        this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
                    yield {'tts_speech': this_tts_speech[:, :-self.overlap]}
                    cache_speech = this_tts_speech[:, -self.overlap:]
                    with self.lock:
                        # drop the tokens just consumed; the hop is smaller than the
                        # window, so consecutive chunks overlap
                        self.tts_speech_token = self.tts_speech_token[self.stream_hop_len:]
                if self.llm_end is True:
                    break
            # deal with the remaining tokens once the llm has finished
            if cache_speech is None or len(self.tts_speech_token) > self.stream_win_len - self.stream_hop_len:
                this_tts_speech_token = torch.concat(self.tts_speech_token, dim=1)
                with self.flow_hift_context:
                    this_tts_mel = self.flow.inference(token=this_tts_speech_token,
                                                       token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
                                                       prompt_token=flow_prompt_speech_token.to(self.device),
                                                       prompt_token_len=flow_prompt_speech_token_len.to(self.device),
                                                       prompt_feat=prompt_speech_feat.to(self.device),
                                                       prompt_feat_len=prompt_speech_feat_len.to(self.device),
                                                       embedding=flow_embedding.to(self.device))
                    this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
                if cache_speech is not None:
                    this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
                yield {'tts_speech': this_tts_speech}
            else:
                # nothing new to synthesize; flush the cached tail as the final chunk
                assert len(self.tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token length should equal {}'.format(self.stream_win_len - self.stream_hop_len)
                yield {'tts_speech': cache_speech}
            p.join()
            # synchronize only when a GPU is actually in use; calling this on a
            # CPU-only machine raises a RuntimeError
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        else:
            # non-streaming: collect all speech tokens, then synthesize in one pass
            tts_speech_token = []
            for i in self.llm.inference(text=text.to(self.device),
                                        text_len=text_len.to(self.device),
                                        prompt_text=prompt_text.to(self.device),
                                        prompt_text_len=prompt_text_len.to(self.device),
                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                        prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
                                        embedding=llm_embedding.to(self.device),
                                        beam_size=1,
                                        sampling=25,
                                        max_token_text_ratio=30,
                                        min_token_text_ratio=3,
                                        stream=stream):
                tts_speech_token.append(i)
            assert len(tts_speech_token) == 1, 'tts_speech_token len should be 1 when stream is {}'.format(stream)
            tts_speech_token = torch.concat(tts_speech_token, dim=1)
            tts_mel = self.flow.inference(token=tts_speech_token,
                                          token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
                                          prompt_token=flow_prompt_speech_token.to(self.device),
                                          prompt_token_len=flow_prompt_speech_token_len.to(self.device),
                                          prompt_feat=prompt_speech_feat.to(self.device),
                                          prompt_feat_len=prompt_speech_feat_len.to(self.device),
                                          embedding=flow_embedding.to(self.device))
            tts_speech = self.hift.inference(mel=tts_mel).cpu()
            torch.cuda.empty_cache()
            yield {'tts_speech': tts_speech}
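

# A minimal, self-contained sketch of the Hamming-window cross-fade that
# `inference` applies at chunk boundaries above. The sizes here are toy
# values chosen for illustration, not the model's real configuration; only
# the arithmetic mirrors the streaming loop.
def _crossfade_sketch():
    overlap = 8
    window = np.hamming(2 * overlap)       # first half rises (fade-in), second half falls (fade-out)
    cache_speech = torch.ones(1, overlap)  # cached tail of the previous chunk
    chunk = torch.ones(1, 32)              # freshly synthesized chunk: (batch, samples)
    # the new chunk fades in while the cached tail fades out, smoothing the seam
    chunk[:, :overlap] = chunk[:, :overlap] * window[:overlap] + cache_speech * window[-overlap:]
    emit = chunk[:, :-overlap]             # portion yielded to the caller
    new_cache = chunk[:, -overlap:]        # held back to blend with the next chunk
    return emit, new_cache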
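

# Hypothetical end-to-end usage. `TransformerLM`, `MaskedDiffWithXvec` and
# `HiFTGenerator` stand in for whatever llm/flow/hift modules the surrounding
# repo provides; their constructors and the checkpoint paths below are
# assumptions, so treat this as a sketch of the calling convention rather
# than a verified recipe:
#
#     model = CosyVoiceModel(llm=TransformerLM(...), flow=MaskedDiffWithXvec(...), hift=HiFTGenerator(...))
#     model.load('llm.pt', 'flow.pt', 'hift.pt')
#     for chunk in model.inference(text, text_len, flow_embedding, stream=True):
#         handle(chunk['tts_speech'])  # each value is a (1, num_samples) waveform tensor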