|
@@ -180,7 +180,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
|
|
|
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
|
|
|
prompt_text_after_apply_template_list = []
|
|
|
mels, prompt_audio_cosy2tokens_list, full_text_list = [], [], []
|
|
|
- for i, item in enumerate(batch):
|
|
|
+ for _, item in enumerate(batch):
|
|
|
audio_processing_start_time = time.time()
|
|
|
prompt_text, target_text = (
|
|
|
item["prompt_text"],
|
|
@@ -402,7 +402,7 @@ def main(args):
|
|
|
)
|
|
|
torch.cuda.synchronize()
|
|
|
elif args.backend == "trtllm":
|
|
|
- batch_input_ids = [ids for ids in batch["input_ids"]]
|
|
|
+ batch_input_ids = list(batch["input_ids"])
|
|
|
input_lengths = [x.size(0) for x in batch_input_ids]
|
|
|
|
|
|
end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") if "<|eos1|>" in tokenizer.get_vocab() else tokenizer.eos_token_id
|