|
|
@@ -367,8 +367,11 @@ class Qwen2LM(TransformerLM):
|
|
|
"""
|
|
|
text_token = batch['text_token'].to(device)
|
|
|
text_token_len = batch['text_token_len'].to(device)
|
|
|
- speech_token = batch['speech_token'].to(device)
|
|
|
- speech_token_len = batch['speech_token_len'].to(device)
|
|
|
+ if 'speech_token' not in batch:
|
|
|
+ speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
|
|
+ else:
|
|
|
+ speech_token = batch['speech_token'].to(device)
|
|
|
+ speech_token_len = batch['speech_token_len'].to(device)
|
|
|
|
|
|
# 1. encode text_token
|
|
|
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
|
|
@@ -686,8 +689,12 @@ class CosyVoice3LM(Qwen2LM):
|
|
|
"""
|
|
|
text_token = batch['text_token'].to(device)
|
|
|
text_token_len = batch['text_token_len'].to(device)
|
|
|
- speech_token = batch['speech_token'].to(device)
|
|
|
- speech_token_len = batch['speech_token_len'].to(device)
|
|
|
+ if 'speech_token' not in batch:
|
|
|
+ speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
|
|
+ else:
|
|
|
+ speech_token = batch['speech_token'].to(device)
|
|
|
+ speech_token_len = batch['speech_token_len'].to(device)
|
|
|
+
|
|
|
# NOTE should append instruct_token to sequence, not implemented yet
|
|
|
instruct_token = batch['instruct_token'].to(device)
|
|
|
instruct_token_len = batch['instruct_token_len'].to(device)
|