yuekaiz 1 mês atrás
pai
commit
a224be6117

+ 1 - 1
examples/grpo/cosyvoice2/infer_dataset.py

@@ -53,7 +53,7 @@ except RuntimeError:
     pass
 
 
-TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
+TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}"  # noqa: E501
 
 
 def audio_decode_cosyvoice2(

+ 2 - 2
runtime/triton_trtllm/model_repo/token2wav_dit/1/token2wav_dit.py

@@ -464,7 +464,7 @@ class CosyVoice2_Token2Wav(torch.nn.Module):
 
 def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
         audio = torch.from_numpy(item['prompt_audio']['array']).float()
         prompt_audios_list.append(audio)
@@ -496,7 +496,7 @@ if __name__ == "__main__":
 
     data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
 
-    for epoch in range(args.warmup):
+    for _ in range(args.warmup):
         start_time = time.time()
         for batch in data_loader:
             ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = batch

+ 1 - 1
runtime/triton_trtllm/offline_inference.py

@@ -512,7 +512,7 @@ def main(args):
                         ))
                     else:
                         outputs = []
-                        for i, chat in enumerate(batch["chat_list"]):
+                        for chat in batch["chat_list"]:
                             payload = {
                                 "model": args.openai_model_name,
                                 "messages": chat,

+ 1 - 1
runtime/triton_trtllm/streaming_inference.py

@@ -13,7 +13,7 @@ import soundfile as sf
 def collate_fn(batch):
     ids, generated_speech_tokens_list, prompt_audios_list, prompt_audios_sample_rate = [], [], [], []
     prompt_speech_tokens_list, prompt_text_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         generated_speech_tokens_list.append(item['target_audio_cosy2_tokens'])
         audio = torch.from_numpy(item['prompt_audio']['array']).float()
         prompt_audios_list.append(audio)