inference_deprecated.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
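# Quiet matplotlib's logger before any downstream import can pull matplotlib in.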
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.dataset.dataset import Dataset


def get_args():
    parser = argparse.ArgumentParser(description='inference with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--prompt_data', required=True, help='prompt data file')
    parser.add_argument('--prompt_utt2data', required=True, help='prompt utt2data file')
    parser.add_argument('--tts_text', required=True, help='tts input file')
    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
    parser.add_argument('--llm_model', required=True, help='llm model file')
    parser.add_argument('--flow_model', required=True, help='flow model file')
    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot'],
                        help='inference mode')
    parser.add_argument('--result_dir', required=True, help='directory for generated wavs and wav.scp')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
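    # Restrict visible GPUs before any CUDA context is created (torch
    # initializes CUDA lazily, so setting the env var here still takes effect).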
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Init cosyvoice models from configs
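    # Try the CosyVoice2 config first (it expects a qwen_pretrain_path
    # override); fall back to the original CosyVoice config on failure.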
    try:
        with open(args.config, 'r') as f:
            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
    except Exception:
        try:
            with open(args.config, 'r') as f:
                configs = load_hyperpyyaml(f)
            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
        except Exception:
            raise TypeError('no valid model_type!')
    model.load(args.llm_model, args.flow_model, args.hifigan_model)
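
    # Build the inference dataset; batch_size=None lets the Dataset control
    # batching itself (one utterance per batch, as asserted in the loop below).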
    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference',
                           shuffle=False, partition=False,
                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
    sample_rate = configs['sample_rate']
    del configs

    os.makedirs(args.result_dir, exist_ok=True)
    fn = os.path.join(args.result_dir, 'wav.scp')
    with open(fn, 'w') as f, torch.no_grad():
        for batch in tqdm(test_data_loader):
            utts = batch["utts"]
            assert len(utts) == 1, "inference mode only supports batch size 1"
            text_token = batch["text_token"].to(device)
            text_token_len = batch["text_token_len"].to(device)
            tts_index = batch["tts_index"]
            tts_text_token = batch["tts_text_token"].to(device)
            tts_text_token_len = batch["tts_text_token_len"].to(device)
            speech_token = batch["speech_token"].to(device)
            speech_token_len = batch["speech_token_len"].to(device)
            speech_feat = batch["speech_feat"].to(device)
            speech_feat_len = batch["speech_feat_len"].to(device)
            utt_embedding = batch["utt_embedding"].to(device)
            spk_embedding = batch["spk_embedding"].to(device)
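            # 'sft' conditions only on the speaker embedding; 'zero_shot' also
            # passes the prompt transcript, prompt speech tokens and features.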
            if args.mode == 'sft':
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
            else:
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
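            # model.tts is a generator; collect the streamed speech chunks and
            # concatenate them along the time axis.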
            tts_speeches = []
            for model_output in model.tts(**model_input):
                tts_speeches.append(model_output['tts_speech'])
            tts_speeches = torch.concat(tts_speeches, dim=1)
            tts_key = '{}_{}'.format(utts[0], tts_index[0])
            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
            torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
            f.write('{} {}\n'.format(tts_key, tts_fn))
            f.flush()
    logging.info('Result wav.scp saved in {}'.format(fn))


if __name__ == '__main__':
    logging.warning('this script is deprecated; please refer to the README for CosyVoice inference usage!')
    main()