# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask


class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )
        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)
        # 1. prepare llm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)
        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)
        # 4. eos and task_id
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)
        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)
        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
        return top_ids

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        # NOTE self.fp16 is not set in __init__ but may be set externally by the wrapping model;
        # default to False when it is absent to avoid an AttributeError
        if getattr(self, 'fp16', False) is True:
            embedding = embedding.half()
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)
        # 1. encode text
        text, text_len = self.encode(text, text_len)
        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            # force continue decode first token
            if i == 0:
                logp[:, self.speech_token_size] = -float('inf')
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)


class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache
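

# The helper below is an illustrative sketch, not part of the original file: it shows how
# Qwen2Encoder.forward_one_step is meant to be driven for incremental decoding, mirroring the
# mask/cache handling of Qwen2LM.inference_bistream further down. The names `encoder`,
# `prefix_emb` and `next_emb` are hypothetical; both embeddings have shape (1, T, llm_input_size).
def _forward_one_step_example(encoder: Qwen2Encoder, prefix_emb: torch.Tensor, next_emb: torch.Tensor):
    # first call: feed the whole prefix with a causal mask over the prefix length
    seq_len = prefix_emb.size(1)
    masks = torch.tril(torch.ones((1, seq_len, seq_len), device=prefix_emb.device)).to(torch.bool)
    hidden, cache = encoder.forward_one_step(prefix_emb, masks=masks, cache=None)
    # later calls: feed only the new embedding, but build the mask over the cached + new
    # length so the attention mask lines up with past_key_values (same trick as below)
    seq_len = next_emb.size(1) + cache[0][0].size(2)
    masks = torch.tril(torch.ones((1, seq_len, seq_len), device=next_emb.device)).to(torch.bool)
    hidden, cache = encoder.forward_one_step(next_emb, masks=masks, cache=cache)
    return hidden, cache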


class Qwen2LM(TransformerLM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
            dpo: bool = False,
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
        # 4. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 5. [Optional] set dpo
        self.dpo = dpo

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        if self.dpo:
            reject_speech_token = batch['reject_speech_token'].to(device)
            reject_speech_token_len = batch['reject_speech_token_len'].to(device)
        # 1. prepare llm_target
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                   [self.speech_token_size]) for i in range(text_token.size(0))]
        if self.dpo:
            reject_target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + reject_speech_token[i, :reject_speech_token_len[i]].tolist() +
                                              [self.speech_token_size]) for i in range(text_token.size(0))]
            target_ids.extend(reject_target_ids)
        target_ids = pad_sequence(target_ids, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. speech token projection
        speech_emb = self.speech_embedding(speech_token)
        if self.dpo:
            reject_speech_emb = self.speech_embedding(reject_speech_token)
        # 3. text token projection
        text_token_lst = unpad_sequence(text_token, text_token_len, batch_first=True)
        text_emb = [self.llm.model.model.embed_tokens(y) for y in text_token_lst]
        # 4. prepare llm_input
        speech_emb = unpad_sequence(speech_emb, speech_token_len.cpu(), batch_first=True)
        input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), speech_emb[i]], dim=0)
                     for i in range(len(text_emb))]
        if self.dpo:
            reject_speech_emb = unpad_sequence(reject_speech_emb, reject_speech_token_len.cpu(), batch_first=True)
            reject_input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), reject_speech_emb[i]], dim=0)
                                for i in range(len(text_emb))]
            input_emb.extend(reject_input_emb)
        input_emb_lengths = torch.tensor([i.size(0) for i in input_emb], dtype=torch.int32).to(device)
        input_emb = pad_sequence(input_emb, batch_first=True, padding_value=IGNORE_ID).to(device)
        attention_mask = ~make_pad_mask(input_emb_lengths)
        # 5. run lm forward (output_hidden_states=True so the last hidden state can be fed to llm_decoder)
        result = self.llm.model(
            inputs_embeds=input_emb,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        hidden_states = result.hidden_states
        logits = self.llm_decoder(hidden_states[-1])
        loss = self.criterion_ce(logits[: speech_token.shape[0]], target_ids[: speech_token.shape[0]])
        acc = th_accuracy(
            logits[: speech_token.shape[0]].view(-1, self.speech_token_size + 3),
            target_ids[: speech_token.shape[0]],
            ignore_label=IGNORE_ID,
        )
        if not self.dpo:
            return {
                "loss": loss,
                "acc": acc,
            }
        else:
            # NOTE prompt lengths must cover both the chosen and the rejected half of the batch,
            # otherwise zip() in get_batch_logps silently skips prompt masking for the rejected rows
            prompt_token_lens = torch.cat([text_token_len, text_token_len], dim=0)
            all_logps_sum, all_logps_mean = self.get_batch_logps(
                logits, target_ids, attention_mask, prompt_token_lens, average_log_prob=False, ignore_id=IGNORE_ID
            )
            chosen_logps = all_logps_sum[: speech_token.shape[0]]
            rejected_logps = all_logps_sum[speech_token.shape[0]:]
            return {
                "loss": loss,
                "acc": acc,
                "chosen_logps": chosen_logps,
                "rejected_logps": rejected_logps
            }

    def get_batch_logps(
            self,
            logits: torch.FloatTensor,
            labels: torch.LongTensor,
            attention_mask,
            prompt_token_lens,
            average_log_prob: bool = False,
            ignore_id: int = -1,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_id are ignored. Shape: (batch_size, sequence_length)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert average_log_prob is False
        assert logits.shape[:-1] == labels.shape
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        loss_masks = attention_mask.clone().bool()
        # mask prompts
        for mask, text_token_len in zip(loss_masks, prompt_token_lens):
            mask[:text_token_len + 1] = False
        loss_masks = loss_masks[:, 1:]
        labels[~loss_masks] = 0
        # dummy token; we'll ignore the losses on these tokens later
        ignore = labels == ignore_id
        labels = labels.masked_fill(ignore, 0)  # avoid -1 index
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)  # (bs, time)
        logprobs_sums = (per_token_logps * loss_masks).sum(-1)
        logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1)
        return logprobs_sums, logprobs_means

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        cache = None
        for i in range(max_len):
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            if top_ids > self.speech_token_size:
                continue
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device
        # 1. prepare input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb], dim=1)
        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE init text_cache with prompt_text, as prompt_speech_token/prompt_text is almost never shorter than mix_ratio (15/5)
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not empty, try to append it to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remains, so we can decode some speech tokens
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('get fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.speech_token_size + 2
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                    if top_ids == self.speech_token_size + 2:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.speech_token_size + 2:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until met eos')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.speech_token_size:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
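

# Illustrative sketch, not part of the original file: one way the `chosen_logps` / `rejected_logps`
# returned by Qwen2LM.forward could be combined with the outputs of a frozen reference model to form
# the standard DPO objective (Rafailov et al., 2023). The function name, the `beta` value and the
# reference-model handling are assumptions of this sketch, not the project's training recipe.
def dpo_loss_example(policy_chosen_logps: torch.Tensor,
                     policy_rejected_logps: torch.Tensor,
                     reference_chosen_logps: torch.Tensor,
                     reference_rejected_logps: torch.Tensor,
                     beta: float = 0.1) -> torch.Tensor:
    # log-ratio of policy vs. reference for the chosen and rejected continuations
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps
    # -log(sigmoid(beta * (chosen_logratio - rejected_logratio))), averaged over the batch
    return -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()
# Hypothetical usage: run the same batch through the trainable model (dpo=True) and a frozen copy,
# then feed the two pairs of logps to dpo_loss_example and combine the result with the CE loss.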