@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
 
+
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
     """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
     speech_id_str = ""
@@ -182,14 +183,13 @@ def get_args():
     return args
 
 
-
 def data_collator(batch, tokenizer, s3_tokenizer):
     """Simplified data collator for batch_size=1 processing"""
     target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
     device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     mels, prompt_audio_cosy2tokens_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         prompt_text, target_text = (
             item["prompt_text"],
             item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
         for i in range(len(codes)):
             prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
-    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
         prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
         # Create chat template for LLM generation
         chat = [
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         )
         input_ids_list.append(input_ids.squeeze(0))
 
-
     # For batch_size=1, no need to pad
     if len(input_ids_list) == 1:
         input_ids = input_ids_list[0].unsqueeze(0)
@@ -256,7 +255,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             for input_ids in input_ids_list
         ]
         input_ids = torch.stack(input_ids_list)
-
+
     ids = [item["id"] for item in batch]
 
     return {
@@ -287,7 +286,7 @@ def main():
     assert torch.cuda.is_available()
     world_size, local_rank, rank = init_distributed()
     device = torch.device(f"cuda:{local_rank}")
-
+
     # Load LLM model and tokenizer directly
     tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
     model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
@@ -329,7 +328,7 @@ def main():
     for batch in dataloader:
         with torch.no_grad():
             input_ids = batch["input_ids"].to(device)
-
+
             # Generate speech tokens using LLM
             outputs = model.generate(
                 input_ids,
@@ -339,31 +338,31 @@ def main():
                 temperature=args.temperature,
                 top_k=args.top_k,
             )
-
+
             # Process each sample in the batch
             for i in range(len(batch["ids"])):
                 # Extract generated tokens (excluding input)
                 input_length = input_ids[i].shape[0]
                 generated_ids = outputs[i][input_length:-1] # Remove last token if needed
                 speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
+
                 # Extract speech IDs from token strings like <|s_23456|>
                 speech_ids = extract_speech_ids(speech_tokens_str)
-
+
                 if len(speech_ids) == 0:
                     print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                     continue
-
+
                 # Convert to tensor for CosyVoice2
                 audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
-
+
                 if args.prompt_text is not None:
                     current_prompt_text = args.prompt_text
                     current_prompt_audio = prompt_speech_16k
                 else:
                     current_prompt_text = batch["prompt_text"][i]
                     current_prompt_audio = batch["prompt_audio_list"][i]
-
+
                 if current_prompt_audio is not None:
                     # Generate audio using CosyVoice2
                     audio_hat = audio_decode_cosyvoice2(
@@ -372,18 +371,17 @@ def main():
                         current_prompt_audio,
                         cosyvoice_codec,
                     )
-
+
                     # Convert to numpy and save
                     generated_wave = audio_hat.squeeze(0).cpu().numpy()
                     target_sample_rate = 24000
-
+
                     utt = batch["ids"][i]
                     sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
 
                     print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
                 else:
                     print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
-
 
         if rank == 0:
             progress_bar.update(world_size * len(batch["ids"]))