# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID, th_accuracy
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask


class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )
        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens
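
    # The per-sample sequence assembled below is laid out as
    #   [sos_eos][spk embedding][text tokens][task_id][speech tokens]
    # which is why forward() pads the first (2 + text_len) positions of
    # lm_target with IGNORE_ID before the speech tokens and the final eos.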
    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text_token: (B, L)
            text_token_len: (B,)
            speech_token: (B, T)
            speech_token_len: (B,)
            embedding: (B, spk_embed_dim)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)
        # 1. prepare llm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)
        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)
        # 4. sos_eos and task_id
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)
        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)
        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}
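
    # self.sampling is an injected callable: it takes the next-token
    # log-probabilities, the tokens decoded so far and an int parameter, and
    # returns the chosen token id(s). sampling_ids retries it (up to 100
    # times) whenever eos shows up while ignore_eos is set, so no eos can be
    # emitted before min_len tokens have been produced. A hedged example of
    # such a callable is sketched at the end of this file.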
    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reached max_trials {} and still got eos while ignore_eos is True, check your input!'.format(max_trials))
        return top_ids
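
    # Streaming decode: the first step feeds the whole prompt
    # [sos_eos][spk emb][text][task_id][prompt speech] through the LM; every
    # later step feeds only the embedding of the last sampled token while
    # offset/att_cache carry the attention state forward. The min/max output
    # length is tied to the input text length via the token/text ratios.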
    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)
        # 1. encode text
        text, text_len = self.encode(text, text_len)
        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            # force continue decode first token
            if i == 0:
                logp[:, self.speech_token_size] = -float('inf')
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len).item()
            if top_ids == self.speech_token_size:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)


class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T)
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], masks.unsqueeze(1)
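
    # forward_one_step performs one incremental decoding step: past_key_values
    # carries the HuggingFace KV cache, so each call only consumes the newly
    # appended embeddings, and masks[:, -1, :] keeps the attention-mask row
    # for the current position.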
    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache


class Qwen2LM(TransformerLM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 4. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()
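
    # Token id layout for Qwen2LM: ids [0, speech_token_size) are speech
    # tokens, speech_token_size is eos, and speech_token_size + 2 is the fill
    # token that hands the turn back to text in bistream mode; all three
    # extra ids act as vLLM stop tokens above.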
    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
        lm_target, lm_input = [], []
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
        for i in range(len(text_token)):
            # bistream sequence
            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                this_lm_target, this_lm_input = [], []
                this_lm_target.append(IGNORE_ID)
                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                    if len(this_text_token) == self.mix_ratio[0]:
                        assert len(this_speech_token) == self.mix_ratio[1]
                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                        this_lm_target += this_speech_token
                        this_lm_target.append(self.speech_token_size + 2)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                    else:
                        this_lm_target += [IGNORE_ID] * len(this_text_token)
                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                        this_lm_target.append(self.speech_token_size)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return lm_target, lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text_token: (B, L)
            text_token_len: (B,)
            speech_token: (B, T)
            speech_token_len: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 3. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len,
                                                                         speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)
        # 4. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 1. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 2. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 3. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token
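
    # inference_wrapper decodes with one of two backends: a vLLM engine when
    # self.vllm has been attached (the engine is stepped under a lock and its
    # outputs are fanned out to per-uuid queues, so concurrent requests can
    # share one engine), otherwise plain step-by-step decoding through
    # forward_one_step.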
    @torch.inference_mode()
    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
        if hasattr(self, 'vllm'):
            from vllm import SamplingParams, RequestOutput
            sampling_params = SamplingParams(top_k=sampling,
                                             stop_token_ids=self.stop_token_ids,
                                             min_tokens=min_len,
                                             max_tokens=max_len)
            with self.lock:
                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
                self.vllm_output_queue[uuid] = queue.Queue()
            out_tokens = []
            while True:
                with self.lock:
                    if self.vllm_output_queue[uuid].empty():
                        request_outputs: List[RequestOutput] = self.vllm.step()
                        for request_output in request_outputs:
                            top_ids = list(request_output.outputs[0].token_ids)[-1]
                            self.vllm_output_queue[request_output.request_id].put(top_ids)
                if not self.vllm_output_queue[uuid].empty():
                    top_ids = self.vllm_output_queue[uuid].get()
                    if top_ids in self.stop_token_ids:
                        break
                    # in stream mode, yield token one by one
                    yield top_ids
                    out_tokens.append(top_ids)
                time.sleep(0.001)
            with self.lock:
                self.vllm_output_queue.pop(uuid)
        else:
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                      device=lm_input.device)).to(torch.bool),
                                                          cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len).item()
                if top_ids == self.speech_token_size:
                    break
                if top_ids > self.speech_token_size:
                    continue
                # in stream mode, yield token one by one
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
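
    # Bistream inference consumes text as a generator and interleaves it with
    # speech on the fly: every mix_ratio[0] text tokens admit mix_ratio[1]
    # speech tokens, and a fill token (speech_token_size + 2) is produced
    # whenever the model needs the next batch of text; the remaining text plus
    # task_id trigger the final decode until eos.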
    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device
        # 1. prepare input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb], dim=1)
        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE: seed text_cache with prompt_text; prompt_speech_token/prompt_text is essentially never < 15/5
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb is not empty yet, try to append it to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text tokens {} speech tokens'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text tokens to decode, waiting for more')
                    break
            # no prompt_speech_token_emb remains, so some speech tokens can be decoded
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('got fill token, need to append more text tokens')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text tokens'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text tokens to decode, waiting for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.speech_token_size + 2
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                    if top_ids == self.speech_token_size + 2:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.speech_token_size + 2:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text tokens, decode until eos is met')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.speech_token_size:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
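

# ---------------------------------------------------------------------------
# NOTE: illustrative sketch, not part of the original module. The `sampling`
# callable injected into TransformerLM/Qwen2LM is assumed here to take the
# next-token log-probabilities (shape [vocab]), the tokens decoded so far and
# an int parameter (interpreted as top-k), and to return the sampled token
# id(s) as a tensor, which is what sampling_ids() expects. The project wires
# in its own sampling function via configuration; this hypothetical
# _example_top_k_sampling only demonstrates the expected interface.
def _example_top_k_sampling(weighted_scores: torch.Tensor, decoded_tokens: List[int], top_k: int) -> torch.Tensor:
    # keep the top_k highest-scoring tokens and renormalize them
    top_scores, top_indices = weighted_scores.topk(top_k)
    probs = top_scores.softmax(dim=-1)
    # draw a single token from the truncated distribution
    choice = torch.multinomial(probs, num_samples=1)
    # return the token id as a 1-element tensor, suitable for
    # sampling_ids(logp.squeeze(dim=0), out_tokens, 25).item()
    return top_indices[choice]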