@@ -17,6 +17,7 @@ import random
 import time
 import threading
 from typing import Dict, Optional, Callable, List, Generator
+import numpy as np
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -216,7 +217,7 @@ class TransformerLM(torch.nn.Module):
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                  device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
             if top_ids == self.eos_token:
                 break
             # in stream mode, yield token one by one
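
Why dropping .item() here is safe: sampling_ids now hands back a one-element tensor, and PyTorch gives such tensors a truth value, so the eos check below the sampling call behaves exactly as it did with a plain int. A minimal sketch (illustrative values, not the model's real ids):

    import torch

    eos_token = 2                      # illustrative value, not the real eos id
    top_ids = torch.tensor([2])        # what sampling_ids yields once .item() is dropped
    if top_ids == eos_token:           # one-element tensor -> valid truth value in `if`
        print("eos reached, loop breaks as before")
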
@@ -544,7 +545,7 @@ class Qwen2LM(TransformerLM):
         cache = None
         # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
         text_cache = self.llm.model.model.embed_tokens(prompt_text)
-        next_fill_index = -1
+        next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
         for this_text in text:
             text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
             # prompt_speech_token_emb not empty, try append to lm_input
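
The new initialization schedules the first fill_token so that the prompt's speech tokens are padded out to a whole mix_ratio[1]-sized block, instead of starting from the -1 sentinel. A worked sketch of the arithmetic, assuming mix_ratio = [5, 15] as the 15/5 ratio in the NOTE above suggests:

    def first_fill_index(prompt_speech_len: int, speech_ratio: int = 15) -> int:
        # same arithmetic as the + line above, with // replacing int(x / y)
        return (prompt_speech_len // speech_ratio + 1) * speech_ratio - prompt_speech_len

    assert first_fill_index(63) == 12   # 63 = 4 full blocks + 3; 12 tokens finish block 5
    assert first_fill_index(75) == 15   # exact multiple: a whole new block before the first fill
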
@@ -582,7 +583,7 @@ class Qwen2LM(TransformerLM):
                 top_ids = self.fill_token
                 next_fill_index += (self.mix_ratio[1] + 1)
             else:
-                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
             if top_ids == self.fill_token:
                 next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                 logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
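
After the first fill_token, both branches above keep the schedule periodic: each fill lands mix_ratio[1] + 1 positions after the previous one (one speech block plus the fill_token itself). A small sketch of the resulting positions, under the same assumed mix_ratio = [5, 15]:

    mix_ratio = [5, 15]                       # assumed text/speech interleave ratio
    next_fill_index = 12                      # e.g. first_fill_index(63) from the sketch above
    schedule = []
    for _ in range(3):                        # positions of the first three fill tokens
        schedule.append(next_fill_index)
        next_fill_index += mix_ratio[1] + 1   # one speech block plus the fill_token itself
    assert schedule == [12, 28, 44]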