frontend.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from functools import partial
  15. import json
  16. import onnxruntime
  17. import torch
  18. import numpy as np
  19. import whisper
  20. from typing import Callable
  21. import torchaudio.compliance.kaldi as kaldi
  22. import torchaudio
  23. import os
  24. import re
  25. import inflect
  26. try:
  27. import ttsfrd
  28. use_ttsfrd = True
  29. except ImportError:
  30. print("failed to import ttsfrd, use WeTextProcessing instead")
  31. from tn.chinese.normalizer import Normalizer as ZhNormalizer
  32. from tn.english.normalizer import Normalizer as EnNormalizer
  33. use_ttsfrd = False
  34. from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
  35. class CosyVoiceFrontEnd:
  36. def __init__(self,
  37. get_tokenizer: Callable,
  38. feat_extractor: Callable,
  39. campplus_model: str,
  40. speech_tokenizer_model: str,
  41. spk2info: str = '',
  42. instruct: bool = False,
  43. allowed_special: str = 'all'):
  44. self.tokenizer = get_tokenizer()
  45. self.feat_extractor = feat_extractor
  46. self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  47. option = onnxruntime.SessionOptions()
  48. option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
  49. option.intra_op_num_threads = 1
  50. self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
  51. self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
  52. providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
  53. "CPUExecutionProvider"])
  54. if os.path.exists(spk2info):
  55. self.spk2info = torch.load(spk2info, map_location=self.device)
  56. else:
  57. self.spk2info = {}
  58. self.instruct = instruct
  59. self.allowed_special = allowed_special
  60. self.inflect_parser = inflect.engine()
  61. self.use_ttsfrd = use_ttsfrd
  62. if self.use_ttsfrd:
  63. self.frd = ttsfrd.TtsFrontendEngine()
  64. ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
  65. assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
  66. 'failed to initialize ttsfrd resource'
  67. self.frd.set_lang_type('pinyinvg')
  68. else:
  69. self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False,overwrite_cache=True)
  70. self.en_tn_model = EnNormalizer()
  71. def _extract_text_token(self, text):
  72. text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
  73. text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
  74. text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
  75. return text_token, text_token_len
  76. def _extract_speech_token(self, speech):
  77. assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
  78. feat = whisper.log_mel_spectrogram(speech, n_mels=128)
  79. speech_token = self.speech_tokenizer_session.run(None,
  80. {self.speech_tokenizer_session.get_inputs()[0].name:
  81. feat.detach().cpu().numpy(),
  82. self.speech_tokenizer_session.get_inputs()[1].name:
  83. np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
  84. speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
  85. speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
  86. return speech_token, speech_token_len
  87. def _extract_spk_embedding(self, speech):
  88. feat = kaldi.fbank(speech,
  89. num_mel_bins=80,
  90. dither=0,
  91. sample_frequency=16000)
  92. feat = feat - feat.mean(dim=0, keepdim=True)
  93. embedding = self.campplus_session.run(None,
  94. {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
  95. embedding = torch.tensor([embedding]).to(self.device)
  96. return embedding
  97. def _extract_speech_feat(self, speech):
  98. speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
  99. speech_feat = speech_feat.unsqueeze(dim=0)
  100. speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
  101. return speech_feat, speech_feat_len
  102. def text_normalize(self, text, split=True, text_frontend=True):
  103. if text_frontend is False:
  104. return [text] if split is True else text
  105. text = text.strip()
  106. if contains_chinese(text):
  107. if self.use_ttsfrd:
  108. texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
  109. text = ''.join(texts)
  110. else:
  111. text = self.zh_tn_model.normalize(text)
  112. text = text.replace("\n", "")
  113. text = replace_blank(text)
  114. text = replace_corner_mark(text)
  115. text = text.replace(".", "。")
  116. text = text.replace(" - ", ",")
  117. text = remove_bracket(text)
  118. text = re.sub(r'[,,、]+$', '。', text)
  119. texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
  120. token_min_n=60, merge_len=20, comma_split=False))
  121. else:
  122. if self.use_ttsfrd:
  123. texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
  124. text = ''.join(texts)
  125. else:
  126. text = self.en_tn_model.normalize(text)
  127. text = spell_out_number(text, self.inflect_parser)
  128. texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
  129. token_min_n=60, merge_len=20, comma_split=False))
  130. if split is False:
  131. return text
  132. return texts
  133. def frontend_sft(self, tts_text, spk_id):
  134. tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
  135. embedding = self.spk2info[spk_id]['embedding']
  136. model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
  137. return model_input
  138. def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
  139. tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
  140. prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
  141. prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
  142. speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
  143. speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
  144. if resample_rate == 24000:
  145. # cosyvoice2, force speech_feat % speech_token = 2
  146. token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
  147. speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
  148. speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
  149. embedding = self._extract_spk_embedding(prompt_speech_16k)
  150. model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
  151. 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
  152. 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
  153. 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
  154. 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
  155. 'llm_embedding': embedding, 'flow_embedding': embedding}
  156. return model_input
  157. def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
  158. model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
  159. # in cross lingual mode, we remove prompt in llm
  160. del model_input['prompt_text']
  161. del model_input['prompt_text_len']
  162. del model_input['llm_prompt_speech_token']
  163. del model_input['llm_prompt_speech_token_len']
  164. return model_input
  165. def frontend_instruct(self, tts_text, spk_id, instruct_text):
  166. model_input = self.frontend_sft(tts_text, spk_id)
  167. # in instruct mode, we remove spk_embedding in llm due to information leakage
  168. del model_input['llm_embedding']
  169. instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
  170. model_input['prompt_text'] = instruct_text_token
  171. model_input['prompt_text_len'] = instruct_text_token_len
  172. return model_input
  173. def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
  174. tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
  175. prompt_text_token, prompt_text_token_len = self._extract_text_token(instruct_text + '<|endofprompt|>')
  176. prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
  177. speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
  178. speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
  179. if resample_rate == 24000:
  180. # cosyvoice2, force speech_feat % speech_token = 2
  181. token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
  182. speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
  183. speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
  184. embedding = self._extract_spk_embedding(prompt_speech_16k)
  185. model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
  186. 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
  187. 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
  188. 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
  189. 'llm_embedding': embedding, 'flow_embedding': embedding}
  190. return model_input
  191. def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
  192. prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
  193. prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
  194. prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
  195. embedding = self._extract_spk_embedding(prompt_speech_16k)
  196. source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
  197. model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
  198. 'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
  199. 'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
  200. 'flow_embedding': embedding}
  201. return model_input