@@ -20,6 +20,7 @@ from torch.nn.utils.rnn import pad_sequence, unpad_sequence
 from cosyvoice.utils.common import IGNORE_ID
 from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
 from cosyvoice.utils.common import th_accuracy
+from cosyvoice.utils.file_utils import logging


 class TransformerLM(torch.nn.Module):
@@ -144,10 +145,14 @@ class TransformerLM(torch.nn.Module):
             sampling: int,
             ignore_eos: bool = True,
     ):
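+        # guard against an endless resampling loop when eos keeps being drawn while ignore_eos is True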
+        num_trials, max_trials = 0, 100
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
         return top_ids

     @torch.inference_mode()
@@ -239,7 +244,7 @@ class Qwen2Encoder(torch.nn.Module):
         return xs, new_cache


-class Qwen2LM(torch.nn.Module):
+class Qwen2LM(TransformerLM):
     def __init__(
             self,
             llm_input_size: int,
@@ -249,8 +254,9 @@ class Qwen2LM(torch.nn.Module):
             sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
+            mix_ratio: List[int] = [5, 15],
     ):
-        super().__init__()
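+        # initialize torch.nn.Module directly: TransformerLM.__init__ expects a different signature, and Qwen2LM only reuses inherited methods such as sampling_ids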
+        torch.nn.Module.__init__(self)
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
@@ -275,23 +281,7 @@ class Qwen2LM(torch.nn.Module):

         # 4. sampling method
         self.sampling = sampling
-
-    def sampling_ids(
-            self,
-            weighted_scores: torch.Tensor,
-            decoded_tokens: List,
-            sampling: int,
-            ignore_eos: bool = True,
-    ):
-        num_trials, max_trials = 0, 100
-        while True:
-            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (self.speech_token_size not in top_ids):
-                break
-            num_trials += 1
-            if num_trials > max_trials:
-                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
-        return top_ids
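+        # [text, speech] interleave ratio used by inference_bistream: mix_ratio[0] text tokens per mix_ratio[1] speech tokens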
+        self.mix_ratio = mix_ratio

     @torch.inference_mode()
     def inference(
@@ -312,9 +302,6 @@ class Qwen2LM(torch.nn.Module):
         text_len += prompt_text_len
         text = self.llm.model.model.embed_tokens(text)

-        # 2. encode embedding
-        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
-
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
@@ -322,7 +309,7 @@ class Qwen2LM(torch.nn.Module):
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)

         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -345,3 +332,100 @@ class Qwen2LM(torch.nn.Module):
             yield top_ids
             out_tokens.append(top_ids)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+    @torch.inference_mode()
+    def inference_bistream(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
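+        """Streaming ("bistream") inference: text token chunks arrive from a generator and
+        speech tokens are yielded as soon as enough text is buffered, interleaving
+        mix_ratio[0] text tokens with mix_ratio[1] speech tokens.
+        """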
+
+        device = prompt_text.device
+        # 1. prepare input
+        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
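+        # start from the sos embedding only; text and prompt speech embeddings are appended incrementally below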
+        lm_input = torch.concat([sos_eos_emb], dim=1)
+
+        # 2. iterate text
+        out_tokens = []
+        cache = None
+        # NOTE seed text_cache with prompt_text, since the prompt_speech_token / prompt_text ratio essentially never falls below 15 / 5 (mix_ratio)
+        text_cache = self.llm.model.model.embed_tokens(prompt_text)
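+        # next_fill_index: position in out_tokens at which a fill token (speech_token_size + 2) is forced to request more text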
+        next_fill_index = -1
+        for this_text in text:
+            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
+            # prompt_speech_token_emb not empty, try append to lm_input
+            while prompt_speech_token_emb.size(1) != 0:
+                if text_cache.size(1) >= self.mix_ratio[0]:
+                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
+                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
+                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
+                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
+                else:
+                    logging.info('not enough text token to decode, wait for more')
+                    break
+            # no prompt_speech_token_emb remain, can decode some speech token
+            if prompt_speech_token_emb.size(1) == 0:
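+                # feed more text only when generation is stalled: the last output was a fill token, or nothing but sos has been fed yet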
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                    logging.info('get fill token, need to append more text token')
+                    if text_cache.size(1) >= self.mix_ratio[0]:
+                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
+                        logging.info('append {} text token'.format(lm_input_text.size(1)))
+                        lm_input = torch.concat([lm_input, lm_input_text], dim=1)
+                        text_cache = text_cache[:, self.mix_ratio[0]:]
+                    else:
+                        logging.info('not enough text token to decode, wait for more')
+                        continue
+                while True:
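+                    # seq_len spans cached plus newly appended positions; cache[0][0] is assumed to hold the first layer's cached keys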
+                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+                    y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                              cache=cache)
+                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
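+                    # at scheduled positions, force a fill token instead of sampling so text and speech stay interleaved at mix_ratio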
+                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
+                        top_ids = self.speech_token_size + 2
+                        next_fill_index += (self.mix_ratio[1] + 1)
+                    else:
+                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
+                    if top_ids == self.speech_token_size + 2:
+                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
+                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
+                    out_tokens.append(top_ids)
+                    if top_ids >= self.speech_token_size:
+                        if top_ids == self.speech_token_size + 2:
+                            break
+                        else:
+                            raise ValueError('should not get token {}'.format(top_ids))
+                    yield top_ids
+                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+        # 3. final decode
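+        # flush the remaining buffered text, then append the task-id embedding to mark the end of text input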
+        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
+        logging.info('no more text token, decode until met eos')
+        while True:
+            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
+            y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                      cache=cache)
+            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
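+            # eos (speech_token_size) is allowed here, hence ignore_eos=False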
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
+            out_tokens.append(top_ids)
+            if top_ids >= self.speech_token_size:
+                if top_ids == self.speech_token_size:
+                    break
+                else:
+                    raise ValueError('should not get token {}'.format(top_ids))
+            # in stream mode, yield token one by one
+            yield top_ids
+            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)