# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask


class TransformerLM(torch.nn.Module):
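    """Autoregressive text-to-speech-token language model.

    Text tokens are embedded and passed through a standalone text encoder,
    the speaker embedding and speech tokens are projected/embedded into the
    same space, and the `llm` module is run causally to predict the next
    speech token, with index `speech_token_size` acting as end of speech.
    """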

    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)

        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict containing
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T)
                speech_token_len: (B,)
                embedding: (B, spk_embed_dim)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)

        # 1. prepare llm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)

        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)

        # 4. sos_eos and task_id
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)

        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)

        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}
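
    # sampling_ids wraps the configured sampling callable and, while ignore_eos
    # is True (i.e. before min_len speech tokens have been produced), resamples
    # up to max_trials times whenever the eos id (speech_token_size) is drawn.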
    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reached max_trials {} and still got eos while ignore_eos is True, check your input!'.format(max_trials))
        return top_ids

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)

        # 1. encode text
        text, text_len = self.encode(text, text_len)

        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)

        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode
        out_tokens = []
        offset = 0
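        # empty attention / cnn caches; forward_chunk grows them on every step,
        # so after the first pass only the newly generated token is fed in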
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            # force continue decode first token
            if i == 0:
                logp[:, self.speech_token_size] = -float('inf')
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len).item()
            if top_ids == self.speech_token_size:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)


class Qwen2Encoder(torch.nn.Module):
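    """Wrapper around a pretrained HuggingFace Qwen2ForCausalLM.

    The model is always driven with `inputs_embeds` (token embedding is done by
    the caller), and `forward_one_step` exposes single-step decoding with a
    `past_key_values` cache for streaming inference.
    """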
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T)
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], masks.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache


class Qwen2LM(TransformerLM):
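    """Speech token LM built on a Qwen2 backbone.

    Unlike TransformerLM there is no separate text encoder: text tokens are
    embedded with the Qwen2 embedding table. The decoder output covers
    speech_token_size + 3 classes, where index speech_token_size is end of
    speech and speech_token_size + 2 is the fill token used to interleave
    text and speech chunks in bistream (streaming text input) sequences.
    """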
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2

        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)

        # 4. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio

        # 5. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        # lock guarding the shared vllm engine and the per-request output
        # queues; self.lock is used throughout inference_wrapper
        self.lock = threading.Lock()

    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
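        # Build per-utterance LM inputs and targets. With probability 0.5, and
        # only if the speech/text length ratio exceeds mix_ratio[1]/mix_ratio[0],
        # an utterance is laid out as a bistream sequence that interleaves
        # mix_ratio[0] text tokens with mix_ratio[1] speech tokens, closing each
        # full chunk with the fill token (speech_token_size + 2); otherwise it
        # is a unistream sequence: sos, all text, task_id, all speech, eos.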
        lm_target, lm_input = [], []
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
        for i in range(len(text_token)):
            # bistream sequence
            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                this_lm_target, this_lm_input = [], []
                this_lm_target.append(IGNORE_ID)
                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                    if len(this_text_token) == self.mix_ratio[0]:
                        assert len(this_speech_token) == self.mix_ratio[1]
                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                        this_lm_target += this_speech_token
                        this_lm_target.append(self.speech_token_size + 2)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                    else:
                        this_lm_target += [-1] * len(this_text_token)
                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                        this_lm_target.append(self.speech_token_size)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return lm_target, lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict containing
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T)
                speech_token_len: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)

        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)

        # 2. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)

        # 3. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len,
                                                                         speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)

        # 4. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}
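
    # forward_dpo runs the chosen and rejected speech sequences through the LM
    # in a single doubled batch (the text is repeated), computes the usual CE
    # loss on the chosen half only, and returns per-sequence log-probs of both
    # halves so the DPO loss can be computed by the caller.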
    def forward_dpo(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        reject_speech_token = batch['reject_speech_token'].to(device)
        reject_speech_token_len = batch['reject_speech_token_len'].to(device)

        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)

        # 2. encode speech_token
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
        speech_token_combined = speech_token + reject_speech_token
        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
        speech_token_combined_emb = self.speech_embedding(speech_token_combined)

        # 3. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
                                                                         speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
        lm_target = lm_target.to(device)

        # 4. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        chosen_logits = logits[:text_token.shape[0]]
        rejected_logits = logits[text_token.shape[0]:]
        chosen_lm_target = lm_target[:text_token.shape[0]]
        rejected_lm_target = lm_target[text_token.shape[0]:]
        loss = self.criterion_ce(chosen_logits, chosen_lm_target)
        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)

        # 5. calculate dpo logits
        chosen_lm_mask = chosen_lm_target == IGNORE_ID
        rejected_lm_mask = rejected_lm_target == IGNORE_ID
        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2,
                                    index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2,
                                      index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        # zero out ignored (padding/prompt) positions of each half with its own
        # mask before averaging per sequence
        chosen_logps = (chosen_logps * ~chosen_lm_mask).mean(dim=-1)
        rejected_logps = (rejected_logps * ~rejected_lm_mask).mean(dim=-1)
        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)

        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 5. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token

    @torch.inference_mode()
    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
        if hasattr(self, 'vllm'):
            from vllm import SamplingParams, RequestOutput
            sampling_params = SamplingParams(top_k=sampling,
                                             stop_token_ids=self.stop_token_ids,
                                             min_tokens=min_len,
                                             max_tokens=max_len)
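            # the shared vllm engine and vllm_output_queue may be accessed
            # concurrently (e.g. by several requests), so every interaction with
            # them below is guarded by self.lock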
            with self.lock:
                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
                self.vllm_output_queue[uuid] = queue.Queue()
            out_tokens = []
            while True:
                with self.lock:
                    if self.vllm_output_queue[uuid].empty():
                        request_outputs: List[RequestOutput] = self.vllm.step()
                        for request_output in request_outputs:
                            top_ids = list(request_output.outputs[0].token_ids)[-1]
                            self.vllm_output_queue[request_output.request_id].put(top_ids)
                if not self.vllm_output_queue[uuid].empty():
                    top_ids = self.vllm_output_queue[uuid].get()
                    if top_ids in self.stop_token_ids:
                        break
                    # in stream mode, yield token one by one
                    yield top_ids
                    out_tokens.append(top_ids)
                    if len(out_tokens) == max_len:
                        break
                time.sleep(0.001)
            with self.lock:
                self.vllm_output_queue.pop(uuid)
        else:
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                          cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len).item()
                if top_ids == self.speech_token_size:
                    break
                if top_ids > self.speech_token_size:
                    continue
                # in stream mode, yield token one by one
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device

        # 1. prepare input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb], dim=1)

        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE: initialize text_cache with prompt_text, since in practice
        # prompt_speech_token/prompt_text is essentially never below 15/5
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
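        # Decoding alternates between text and speech chunks: whenever the model
        # emits (or is forced to emit) the fill token (speech_token_size + 2),
        # the next mix_ratio[0] text tokens are appended to lm_input, and
        # next_fill_index forces a fill token after every mix_ratio[1] generated
        # speech tokens, mirroring the interleaving in prepare_lm_input_target.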
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not fully consumed yet, try to append it to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remains, we can decode some speech tokens
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('get fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.speech_token_size + 2
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                    if top_ids == self.speech_token_size + 2:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.speech_token_size + 2:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    # in stream mode, yield token one by one
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until met eos')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.speech_token_size:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)