lyuxiang.lx 2 달 전
부모
커밋
3047591fad
3개의 변경된 파일9개의 추가작업 그리고 5개의 파일을 삭제
  1. 2 1
      cosyvoice/llm/llm.py
  2. 3 2
      examples/libritts/cosyvoice/local/prepare_reject_sample.py
  3. 4 2
      vllm_example.py

+ 2 - 1
cosyvoice/llm/llm.py

@@ -401,7 +401,8 @@ class Qwen2LM(TransformerLM):
         speech_token_combined_emb = self.speech_embedding(speech_token_combined)
 
         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2), speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2), \
+                                                                         speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
         lm_target = lm_target.to(device)
 
         # 4. run lm forward

+ 3 - 2
examples/libritts/cosyvoice/local/prepare_reject_sample.py

@@ -2,7 +2,8 @@ import argparse
 import logging
 import os
 from tqdm import tqdm
-import torch, torchaudio
+import torch
+import torchaudio
 from cosyvoice.cli.cosyvoice import CosyVoice2
 from cosyvoice.utils.file_utils import load_wav
 
@@ -30,7 +31,7 @@ def main():
             if prompt_speech_16k.shape[1] >= 30 * 16000:
                 continue
             speech_list = []
-            for i, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
+            for _, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
                 speech_list.append(j['tts_speech'])
             negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
             torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')

+ 4 - 2
vllm_example.py

@@ -9,13 +9,15 @@ from cosyvoice.utils.file_utils import load_wav
 from cosyvoice.utils.common import set_all_random_seed
 from tqdm import tqdm
 
+
 def main():
     cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
     prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
     for i in tqdm(range(100)):
         set_all_random_seed(i)
-        for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+        for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
             continue
 
-if __name__=='__main__':
+
+if __name__ == '__main__':
     main()