# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Dict, Optional, Callable, List, Generator

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from transformers import Qwen2ForCausalLM

from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask


  26. class TransformerLM(torch.nn.Module):
  27. def __init__(
  28. self,
  29. text_encoder_input_size: int,
  30. llm_input_size: int,
  31. llm_output_size: int,
  32. text_token_size: int,
  33. speech_token_size: int,
  34. text_encoder: torch.nn.Module,
  35. llm: torch.nn.Module,
  36. sampling: Callable,
  37. length_normalized_loss: bool = True,
  38. lsm_weight: float = 0.0,
  39. spk_embed_dim: int = 192,
  40. ):
  41. super().__init__()
  42. self.llm_input_size = llm_input_size
  43. self.speech_token_size = speech_token_size
  44. # 1. build text token inputs related modules
  45. self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
  46. self.text_encoder = text_encoder
  47. self.text_encoder_affine_layer = nn.Linear(
  48. self.text_encoder.output_size(),
  49. llm_input_size
  50. )
  51. # 2. build speech token language model related modules
  52. self.sos_eos = 0
  53. self.task_id = 1
  54. self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
  55. self.llm = llm
  56. self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
  57. self.criterion_ce = LabelSmoothingLoss(
  58. size=speech_token_size + 1,
  59. padding_idx=IGNORE_ID,
  60. smoothing=lsm_weight,
  61. normalize_length=length_normalized_loss,
  62. )
  63. # 3. [Optional] build speech token related modules
  64. self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
  65. self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
  66. # 4. sampling method
  67. self.sampling = sampling
  68. def encode(
  69. self,
  70. text: torch.Tensor,
  71. text_lengths: torch.Tensor,
  72. ):
  73. encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
  74. encoder_out_lens = encoder_mask.squeeze(1).sum(1)
  75. encoder_out = self.text_encoder_affine_layer(encoder_out)
  76. return encoder_out, encoder_out_lens
  77. def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
  78. text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
  79. speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
  80. lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
  81. for i in range(len(text_token))]
  82. lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
  83. lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
  84. return lm_input, lm_input_len
  85. def forward(
  86. self,
  87. batch: dict,
  88. device: torch.device,
  89. ) -> Dict[str, Optional[torch.Tensor]]:
  90. """
  91. Args:
  92. text: (B, L, D)
  93. text_lengths: (B,)
  94. audio: (B, T, N) or (B, T)
  95. audio_lengths: (B,)
  96. """
  97. text_token = batch['text_token'].to(device)
  98. text_token_len = batch['text_token_len'].to(device)
  99. speech_token = batch['speech_token'].to(device)
  100. speech_token_len = batch['speech_token_len'].to(device)
  101. embedding = batch['embedding'].to(device)
  102. # 1. prepare llm_target
  103. lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
  104. [self.speech_token_size]) for i in range(text_token.size(0))]
  105. lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
  106. # 1. encode text_token
  107. text_token = self.text_embedding(text_token)
  108. text_token, text_token_len = self.encode(text_token, text_token_len)
  109. # 2. embedding projection
  110. embedding = F.normalize(embedding, dim=1)
  111. embedding = self.spk_embed_affine_layer(embedding)
  112. embedding = embedding.unsqueeze(1)
  113. # 3. eos and task_id
  114. sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
  115. task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
  116. # 4. encode speech_token
  117. speech_token = self.speech_embedding(speech_token)
  118. # 5. unpad and pad
  119. lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
  120. task_id_emb, speech_token, speech_token_len)
  121. # 6. run lm forward
  122. lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
  123. logits = self.llm_decoder(lm_output)
  124. loss = self.criterion_ce(logits, lm_target)
  125. acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
  126. return {'loss': loss, 'acc': acc}
  127. def sampling_ids(
  128. self,
  129. weighted_scores: torch.Tensor,
  130. decoded_tokens: List,
  131. sampling: int,
  132. ignore_eos: bool = True,
  133. ):
  134. num_trials, max_trials = 0, 100
  135. while True:
  136. top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
  137. if (not ignore_eos) or (self.speech_token_size not in top_ids):
  138. break
  139. num_trials += 1
  140. if num_trials > max_trials:
  141. raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
  142. return top_ids
  143. @torch.inference_mode()
  144. def inference(
  145. self,
  146. text: torch.Tensor,
  147. text_len: torch.Tensor,
  148. prompt_text: torch.Tensor,
  149. prompt_text_len: torch.Tensor,
  150. prompt_speech_token: torch.Tensor,
  151. prompt_speech_token_len: torch.Tensor,
  152. embedding: torch.Tensor,
  153. sampling: int = 25,
  154. max_token_text_ratio: float = 20,
  155. min_token_text_ratio: float = 2,
  156. ) -> Generator[torch.Tensor, None, None]:
  157. device = text.device
  158. text = torch.concat([prompt_text, text], dim=1)
  159. text_len += prompt_text_len
  160. text = self.text_embedding(text)
  161. # 1. encode text
  162. text, text_len = self.encode(text, text_len)
  163. # 2. encode embedding
  164. if embedding.shape[0] != 0:
  165. embedding = F.normalize(embedding, dim=1)
  166. embedding = self.spk_embed_affine_layer(embedding)
  167. embedding = embedding.unsqueeze(dim=1)
  168. else:
  169. embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
  170. # 3. concat llm_input
  171. sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
  172. task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
  173. if prompt_speech_token_len != 0:
  174. prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
  175. else:
  176. prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
  177. lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
  178. # 4. cal min/max_length
  179. min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
  180. max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
  181. # 5. step by step decode
  182. out_tokens = []
  183. offset = 0
  184. att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
  185. for i in range(max_len):
  186. y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
  187. att_cache=att_cache, cnn_cache=cnn_cache,
  188. att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
  189. device=lm_input.device)).to(torch.bool))
  190. logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
  191. # force continue decode first token
  192. if i == 0:
  193. logp[:, self.speech_token_size] = -float('inf')
  194. top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
  195. if top_ids == self.speech_token_size:
  196. break
  197. # in stream mode, yield token one by one
  198. yield top_ids
  199. out_tokens.append(top_ids)
  200. offset += lm_input.size(1)
  201. lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
  202. class Qwen2Encoder(torch.nn.Module):
  203. def __init__(self, pretrain_path):
  204. super().__init__()
  205. self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
  206. def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
  207. T = xs.size(1)
  208. masks = ~make_pad_mask(xs_lens, T)
  209. outs = self.model(
  210. inputs_embeds=xs,
  211. attention_mask=masks,
  212. output_hidden_states=True,
  213. return_dict=True,
  214. )
  215. return outs.hidden_states[-1], masks.unsqueeze(1)
  216. def forward_one_step(self, xs, masks, cache=None):
  217. input_masks = masks[:, -1, :]
  218. outs = self.model(
  219. inputs_embeds=xs,
  220. attention_mask=input_masks,
  221. output_hidden_states=True,
  222. return_dict=True,
  223. use_cache=True,
  224. past_key_values=cache,
  225. )
  226. xs = outs.hidden_states[-1]
  227. new_cache = outs.past_key_values
  228. return xs, new_cache
  229. class Qwen2LM(TransformerLM):
  230. def __init__(
  231. self,
  232. llm_input_size: int,
  233. llm_output_size: int,
  234. speech_token_size: int,
  235. llm: torch.nn.Module,
  236. sampling: Callable,
  237. length_normalized_loss: bool = True,
  238. lsm_weight: float = 0.0,
  239. mix_ratio: List[int] = [5, 15],
  240. ):
  241. torch.nn.Module.__init__(self)
  242. self.llm_input_size = llm_input_size
  243. self.llm_output_size = llm_output_size
  244. self.speech_token_size = speech_token_size
  245. # 2. build speech token language model related modules
  246. self.sos_eos = 0
  247. self.task_id = 1
  248. self.fill_token = 2
  249. self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
  250. self.llm = llm
  251. self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
  252. self.criterion_ce = LabelSmoothingLoss(
  253. size=speech_token_size + 3,
  254. padding_idx=IGNORE_ID,
  255. smoothing=lsm_weight,
  256. normalize_length=length_normalized_loss,
  257. )
  258. # 3. [Optional] build speech token related modules
  259. self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
  260. # 4. sampling method
  261. self.sampling = sampling
  262. self.mix_ratio = mix_ratio
  263. def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
  264. lm_target, lm_input = [], []
  265. text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
  266. speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
  267. text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
  268. speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
  269. for i in range(len(text_token)):
  270. # bistream sequence
  271. if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
  272. this_lm_target, this_lm_input = [], []
  273. this_lm_target.append(IGNORE_ID)
  274. this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
  275. for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
  276. this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
  277. this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
  278. if len(this_text_token) == self.mix_ratio[0]:
  279. assert len(this_speech_token) == self.mix_ratio[1]
  280. this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
  281. this_lm_target += this_speech_token
  282. this_lm_target.append(self.speech_token_size + 2)
  283. this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
  284. this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
  285. else:
  286. this_lm_target += [-1] * len(this_text_token)
  287. this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
  288. this_lm_target.append(self.speech_token_size)
  289. this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
  290. this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
  291. this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
  292. this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
  293. # unistream sequence
  294. else:
  295. this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
  296. this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
  297. self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
  298. lm_target.append(this_lm_target)
  299. lm_input.append(this_lm_input)
  300. lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
  301. lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
  302. lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
  303. return lm_target, lm_input, lm_input_len
  304. def forward(
  305. self,
  306. batch: dict,
  307. device: torch.device,
  308. ) -> Dict[str, Optional[torch.Tensor]]:
  309. """
  310. Args:
  311. text: (B, L, D)
  312. text_lengths: (B,)
  313. audio: (B, T, N) or (B, T)
  314. audio_lengths: (B,)
  315. """
  316. text_token = batch['text_token'].to(device)
  317. text_token_len = batch['text_token_len'].to(device)
  318. speech_token = batch['speech_token'].to(device)
  319. speech_token_len = batch['speech_token_len'].to(device)
  320. # 1. encode text_token
  321. text_token_emb = self.llm.model.model.embed_tokens(text_token)
  322. # 2. encode speech_token
  323. speech_token_emb = self.speech_embedding(speech_token)
  324. # 3. prepare llm_input/target
  325. lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
  326. lm_target = lm_target.to(device)
  327. # 4. run lm forward
  328. lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
  329. logits = self.llm_decoder(lm_output)
  330. loss = self.criterion_ce(logits, lm_target.to(device))
  331. acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
  332. return {'loss': loss, 'acc': acc}
  333. @torch.inference_mode()
  334. def inference(
  335. self,
  336. text: torch.Tensor,
  337. text_len: torch.Tensor,
  338. prompt_text: torch.Tensor,
  339. prompt_text_len: torch.Tensor,
  340. prompt_speech_token: torch.Tensor,
  341. prompt_speech_token_len: torch.Tensor,
  342. embedding: torch.Tensor,
  343. sampling: int = 25,
  344. max_token_text_ratio: float = 20,
  345. min_token_text_ratio: float = 2,
  346. ) -> Generator[torch.Tensor, None, None]:
  347. device = text.device
  348. text = torch.concat([prompt_text, text], dim=1)
  349. text_len += prompt_text_len
  350. text = self.llm.model.model.embed_tokens(text)
  351. # 3. concat llm_input
  352. sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
  353. task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
  354. if prompt_speech_token_len != 0:
  355. prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
  356. else:
  357. prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
  358. lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
  359. # 4. cal min/max_length
  360. min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
  361. max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
  362. # 5. step by step decode
  363. out_tokens = []
  364. cache = None
  365. for i in range(max_len):
  366. y_pred, cache = self.llm.forward_one_step(lm_input,
  367. masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
  368. cache=cache)
  369. logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
  370. top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
  371. if top_ids == self.speech_token_size:
  372. break
  373. if top_ids > self.speech_token_size:
  374. continue
  375. # in stream mode, yield token one by one
  376. yield top_ids
  377. out_tokens.append(top_ids)
  378. lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
  379. @torch.inference_mode()
  380. def inference_bistream(
  381. self,
  382. text: Generator,
  383. prompt_text: torch.Tensor,
  384. prompt_text_len: torch.Tensor,
  385. prompt_speech_token: torch.Tensor,
  386. prompt_speech_token_len: torch.Tensor,
  387. embedding: torch.Tensor,
  388. sampling: int = 25,
  389. max_token_text_ratio: float = 20,
  390. min_token_text_ratio: float = 2,
  391. ) -> Generator[torch.Tensor, None, None]:
  392. device = prompt_text.device
  393. # 1. prepare input
  394. sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
  395. task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
  396. if prompt_speech_token_len != 0:
  397. prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
  398. else:
  399. prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
  400. lm_input = torch.concat([sos_eos_emb], dim=1)
  401. # 2. iterate text
  402. out_tokens = []
  403. cache = None
  404. # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
  405. text_cache = self.llm.model.model.embed_tokens(prompt_text)
  406. next_fill_index = -1
  407. for this_text in text:
  408. text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
  409. # prompt_speech_token_emb not empty, try append to lm_input
  410. while prompt_speech_token_emb.size(1) != 0:
  411. if text_cache.size(1) >= self.mix_ratio[0]:
  412. lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
  413. logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
  414. lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
  415. text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
  416. else:
  417. logging.info('not enough text token to decode, wait for more')
  418. break
  419. # no prompt_speech_token_emb remain, can decode some speech token
  420. if prompt_speech_token_emb.size(1) == 0:
  421. if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
  422. logging.info('get fill token, need to append more text token')
  423. if text_cache.size(1) >= self.mix_ratio[0]:
  424. lm_input_text = text_cache[:, :self.mix_ratio[0]]
  425. logging.info('append {} text token'.format(lm_input_text.size(1)))
  426. if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
  427. lm_input = lm_input_text
  428. else:
  429. lm_input = torch.concat([lm_input, lm_input_text], dim=1)
  430. text_cache = text_cache[:, self.mix_ratio[0]:]
  431. else:
  432. logging.info('not enough text token to decode, wait for more')
  433. continue
  434. while True:
  435. seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
  436. y_pred, cache = self.llm.forward_one_step(lm_input,
  437. masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
  438. cache=cache)
  439. logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
  440. if next_fill_index != -1 and len(out_tokens) == next_fill_index:
  441. top_ids = self.speech_token_size + 2
  442. next_fill_index += (self.mix_ratio[1] + 1)
  443. else:
  444. top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
  445. if top_ids == self.speech_token_size + 2:
  446. next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
  447. logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
  448. out_tokens.append(top_ids)
  449. if top_ids >= self.speech_token_size:
  450. if top_ids == self.speech_token_size + 2:
  451. break
  452. else:
  453. raise ValueError('should not get token {}'.format(top_ids))
  454. yield top_ids
  455. lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
  456. # 3. final decode
  457. lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
  458. logging.info('no more text token, decode until met eos')
  459. while True:
  460. seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
  461. y_pred, cache = self.llm.forward_one_step(lm_input,
  462. masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
  463. cache=cache)
  464. logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
  465. top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
  466. out_tokens.append(top_ids)
  467. if top_ids >= self.speech_token_size:
  468. if top_ids == self.speech_token_size:
  469. break
  470. else:
  471. raise ValueError('should not get token {}'.format(top_ids))
  472. # in stream mode, yield token one by one
  473. yield top_ids
  474. lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)