|
|
@@ -381,7 +381,7 @@ class Qwen2LM(TransformerLM):
|
|
|
self,
|
|
|
batch: dict,
|
|
|
device: torch.device,
|
|
|
- ) -> Dict[str, Optional[torch.Tensor]]:
|
|
|
+ ) -> Dict[str, Optional[torch.Tensor]]:
|
|
|
text_token = batch['text_token'].to(device)
|
|
|
text_token_len = batch['text_token_len'].to(device)
|
|
|
speech_token = batch['speech_token'].to(device)
|