|
|
@@ -22,7 +22,7 @@ import random
|
|
|
import librosa
|
|
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
|
|
|
-from cosyvoice.cli.cosyvoice import CosyVoice
|
|
|
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
|
|
|
from cosyvoice.utils.file_utils import load_wav, logging
|
|
|
from cosyvoice.utils.common import set_all_random_seed
|
|
|
|
|
|
@@ -51,7 +51,7 @@ def postprocess(speech, top_db=60, hop_length=220, win_length=440):
|
|
|
)
|
|
|
if speech.abs().max() > max_val:
|
|
|
speech = speech / speech.abs().max() * max_val
|
|
|
- speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
|
|
|
+ speech = torch.concat([speech, torch.zeros(1, int(cosyvoice.sample_rate * 0.2))], dim=1)
|
|
|
return speech
|
|
|
|
|
|
|
|
|
@@ -71,31 +71,31 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
|
|
|
if mode_checkbox_group in ['自然语言控制']:
|
|
|
if cosyvoice.frontend.instruct is False:
|
|
|
gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
|
|
|
- yield (target_sr, default_data)
|
|
|
+ yield (cosyvoice.sample_rate, default_data)
|
|
|
if instruct_text == '':
|
|
|
gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
|
|
|
- yield (target_sr, default_data)
|
|
|
+ yield (cosyvoice.sample_rate, default_data)
|
|
|
if prompt_wav is not None or prompt_text != '':
|
|
|
gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
|
|
|
# if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
|
|
|
if mode_checkbox_group in ['跨语种复刻']:
|
|
|
if cosyvoice.frontend.instruct is True:
|
|
|
gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
|
|
|
- yield (target_sr, default_data)
|
|
|
+ yield (cosyvoice.sample_rate, default_data)
|
|
|
if instruct_text != '':
|
|
|
gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
|
|
|
if prompt_wav is None:
|
|
|
gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
|
|
|
- yield (target_sr, default_data)
|
|
|
+ yield (cosyvoice.sample_rate, default_data)
|
|
|
gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
|
|
|
# if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
|
|
|
if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
|
|
|
if prompt_wav is None:
|
|
|
gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
|
|
|
- yield (target_sr, default_data)
|
|
|
+ yield (cosyvoice.sample_rate, default_data)
|
|
|
if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
|
|
|
gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
|
|
|
- yield (target_sr, default_data)
|
|
|
+ yield (cosyvoice.sample_rate, default_data)
|
|
|
# sft mode only use sft_dropdown
|
|
|
if mode_checkbox_group in ['预训练音色']:
|
|
|
if instruct_text != '' or prompt_wav is not None or prompt_text != '':
|
|
|
@@ -104,7 +104,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
|
|
|
if mode_checkbox_group in ['3s极速复刻']:
|
|
|
if prompt_text == '':
|
|
|
gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
|
|
|
- yield (target_sr, default_data)
|
|
|
+ yield (cosyvoice.sample_rate, default_data)
|
|
|
if instruct_text != '':
|
|
|
gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
|
|
|
|
|
|
@@ -112,24 +112,24 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
|
|
|
logging.info('get sft inference request')
|
|
|
set_all_random_seed(seed)
|
|
|
for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed):
|
|
|
- yield (target_sr, i['tts_speech'].numpy().flatten())
|
|
|
+ yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
|
|
|
elif mode_checkbox_group == '3s极速复刻':
|
|
|
logging.info('get zero_shot inference request')
|
|
|
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
|
|
set_all_random_seed(seed)
|
|
|
for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
|
|
|
- yield (target_sr, i['tts_speech'].numpy().flatten())
|
|
|
+ yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
|
|
|
elif mode_checkbox_group == '跨语种复刻':
|
|
|
logging.info('get cross_lingual inference request')
|
|
|
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
|
|
set_all_random_seed(seed)
|
|
|
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
|
|
|
- yield (target_sr, i['tts_speech'].numpy().flatten())
|
|
|
+ yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
|
|
|
else:
|
|
|
logging.info('get instruct inference request')
|
|
|
set_all_random_seed(seed)
|
|
|
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
|
|
|
- yield (target_sr, i['tts_speech'].numpy().flatten())
|
|
|
+ yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
|
|
|
|
|
|
|
|
|
def main():
|
|
|
@@ -178,11 +178,11 @@ if __name__ == '__main__':
|
|
|
default=8000)
|
|
|
parser.add_argument('--model_dir',
|
|
|
type=str,
|
|
|
- default='pretrained_models/CosyVoice-300M',
|
|
|
+ default='pretrained_models/CosyVoice2-0.5B',
|
|
|
help='local path or modelscope repo id')
|
|
|
args = parser.parse_args()
|
|
|
- cosyvoice = CosyVoice(args.model_dir)
|
|
|
+ cosyvoice = CosyVoice2(args.model_dir) if 'CosyVoice2' in args.model_dir else CosyVoice(args.model_dir)
|
|
|
sft_spk = cosyvoice.list_avaliable_spks()
|
|
|
- prompt_sr, target_sr = 16000, 22050
|
|
|
- default_data = np.zeros(target_sr)
|
|
|
+ prompt_sr = 16000
|
|
|
+ default_data = np.zeros(cosyvoice.sample_rate)
|
|
|
main()
|