# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import random
import threading
import time
from typing import Dict, Optional, Callable, List, Generator

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from transformers import Qwen2ForCausalLM

from cosyvoice.utils.common import IGNORE_ID, th_accuracy
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask


class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text-token input related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )
        # 2. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = self.speech_token_size
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len
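
    # Sequence layout produced by pad_unpad_sequence() and consumed by forward(),
    # shown per utterance (input and target positions align one-to-one):
    #
    #   lm_input : [sos_emb][spk_embedding][text_emb ...][task_id_emb][speech_emb ...]
    #   lm_target: [IGNORE_ID x (2 + text_len)           ][speech tokens ...][eos]
    #
    # so the cross-entropy loss is computed only on the speech-token continuation
    # and the final eos; all prefix positions are masked with IGNORE_ID.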

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict containing
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T)
                speech_token_len: (B,)
                embedding: (B, spk_embed_dim)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)
        # 1. prepare lm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)
        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)
        # 4. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)
        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)
        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reached max_trials {} and still got eos while ignore_eos is True; check your input!'.format(max_trials))
        return top_ids
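
    # A minimal sketch of a callable compatible with the `sampling` constructor
    # argument (hypothetical; the project wires in its own strategy, e.g. a
    # repetition-aware top-k sampler). It receives the next-token log-probs and
    # the tokens decoded so far, and returns a tensor holding the chosen id:
    #
    #   def top_k_sampling(weighted_scores: torch.Tensor,
    #                      decoded_tokens: List[int],
    #                      top_k: int) -> torch.Tensor:
    #       probs, indices = weighted_scores.softmax(dim=-1).topk(top_k)
    #       return indices[torch.multinomial(probs / probs.sum(), 1)]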

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)
        # 1. encode text
        text, text_len = self.encode(text, text_len)
        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        # 3. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device)
        cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(
                lm_input, offset=offset, required_cache_size=-1,
                att_cache=att_cache, cnn_cache=cnn_cache,
                att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                               device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len).item()
            if top_ids == self.eos_token:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
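
# Usage sketch for the streaming interface (hypothetical caller code): inference()
# yields one speech token id at a time, so synthesis can start before decoding ends:
#
#   for token_id in lm.inference(text, text_len, prompt_text, prompt_text_len,
#                                prompt_speech_token, prompt_speech_token_len,
#                                embedding, sampling=25):
#       token_buffer.append(token_id)  # e.g. feed a streaming token-to-wav decoder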


class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T)
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], masks.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache
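
# NOTE on Qwen2Encoder above: forward() runs a full padded batch for training and
# returns the last hidden states plus a (B, 1, T) mask, while forward_one_step()
# decodes incrementally with HF past_key_values as the cache. The LM head inside
# Qwen2ForCausalLM is bypassed; hidden states are projected by llm_decoder instead.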


class Qwen2LM(TransformerLM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        # skip TransformerLM.__init__ on purpose: Qwen2LM has no text encoder
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = speech_token_size
        self.fill_token = speech_token_size + 2
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 4. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        # the lock guards the shared vllm engine and output queues in inference_wrapper
        self.lock = threading.Lock()

    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
        lm_target, lm_input = [], []
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
        for i in range(len(text_token)):
            # bistream sequence
            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                this_lm_target, this_lm_input = [], []
                this_lm_target.append(IGNORE_ID)
                this_lm_input.append(self.llm_embedding.weight[self.sos].reshape(1, -1))
                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                    if len(this_text_token) == self.mix_ratio[0]:
                        assert len(this_speech_token) == self.mix_ratio[1]
                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                        this_lm_target += this_speech_token
                        this_lm_target.append(self.fill_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                    else:
                        this_lm_target += [IGNORE_ID] * len(this_text_token)
                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                        this_lm_target.append(self.eos_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos].reshape(1, -1), text_token_emb[i],
                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return lm_target, lm_input, lm_input_len
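
    # With the default mix_ratio [5, 15], a bistream training sequence interleaves
    # fixed-size chunks of text and speech embeddings:
    #
    #   input : [sos][text x5][speech x15][text x5][speech x15]...[text rest][task_id][speech rest]
    #   target: IGNORE_ID at every text position; each full speech chunk is followed
    #           by fill_token, and the trailing partial chunk ends with eos_token.
    #
    # The random.random() < 0.5 switch above trains each eligible utterance either
    # in this interleaved (streaming) layout or in the plain unistream layout.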

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict containing
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T)
                speech_token_len: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 3. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len,
                                                                         speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)
        # 4. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def forward_dpo(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        reject_speech_token = batch['reject_speech_token'].to(device)
        reject_speech_token_len = batch['reject_speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. encode speech_token, stacking chosen and rejected sequences into one batch
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
        speech_token_combined = speech_token + reject_speech_token
        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
        speech_token_combined_emb = self.speech_embedding(speech_token_combined)
        # 3. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
                                                                         speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
        lm_target = lm_target.to(device)
        # 4. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        chosen_logits = logits[:text_token.shape[0]]
        rejected_logits = logits[text_token.shape[0]:]
        chosen_lm_target = lm_target[:text_token.shape[0]]
        rejected_lm_target = lm_target[text_token.shape[0]:]
        loss = self.criterion_ce(chosen_logits, chosen_lm_target)
        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
        # 5. calculate per-sequence mean log-probs over valid (non-ignored) target positions
        chosen_lm_mask = chosen_lm_target != IGNORE_ID
        rejected_lm_mask = rejected_lm_target != IGNORE_ID
        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2,
                                    index=chosen_lm_target.masked_fill(~chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2,
                                      index=rejected_lm_target.masked_fill(~rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
        rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
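
    # forward_dpo() returns per-sequence mean log-probs for the chosen and rejected
    # continuations; the preference objective itself lives in the training loop.
    # A typical DPO loss (sketch, assuming a frozen reference model that supplies
    # ref_chosen_logps/ref_rejected_logps and a temperature beta) would be:
    #
    #   loss_dpo = -F.logsigmoid(beta * ((chosen_logps - ref_chosen_logps)
    #                                    - (rejected_logps - ref_rejected_logps))).mean()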

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 1. concat llm_input (embedding is unused here: Qwen2LM conditions on prompt tokens only)
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 2. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 3. step by step decode
        yield from self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid)

    @torch.inference_mode()
    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
        if hasattr(self, 'vllm'):
            # NOTE this path requires a vllm build that accepts prompt embeddings;
            # the engine is shared across requests, so every interaction with it
            # (and with the per-request output queues) happens under self.lock
            from vllm import SamplingParams, RequestOutput
            sampling_params = SamplingParams(top_k=sampling,
                                             stop_token_ids=self.stop_token_ids,
                                             min_tokens=min_len,
                                             max_tokens=max_len)
            with self.lock:
                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
                self.vllm_output_queue[uuid] = queue.Queue()
            out_tokens = []
            while True:
                with self.lock:
                    if self.vllm_output_queue[uuid].empty():
                        request_outputs: List[RequestOutput] = self.vllm.step()
                        for request_output in request_outputs:
                            top_ids = list(request_output.outputs[0].token_ids)[-1]
                            self.vllm_output_queue[request_output.request_id].put(top_ids)
                if not self.vllm_output_queue[uuid].empty():
                    top_ids = self.vllm_output_queue[uuid].get()
                    if top_ids in self.stop_token_ids:
                        break
                    # in stream mode, yield token one by one
                    yield top_ids
                    out_tokens.append(top_ids)
                    if len(out_tokens) == max_len:
                        break
                time.sleep(0.001)
            with self.lock:
                self.vllm_output_queue.pop(uuid)
        else:
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                      device=lm_input.device)).to(torch.bool),
                                                          cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len).item()
                if top_ids in self.stop_token_ids:
                    break
                # in stream mode, yield token one by one
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
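
    # inference_bistream() below mirrors the bistream training layout at inference
    # time: text arrives chunk by chunk from a generator, the model emits
    # mix_ratio[1] speech tokens per mix_ratio[0] text tokens, and an emitted
    # fill_token signals that the model wants the next text chunk before it
    # produces more speech.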

    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device
        # 1. prepare input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_emb], dim=1)
        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE: seed text_cache with prompt_text, since prompt_speech_token/prompt_text
        # is practically never below 15/5
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not yet empty, try to append it to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text tokens {} speech tokens'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text tokens to decode, waiting for more')
                    break
            # no prompt_speech_token_emb remains, we can decode some speech tokens
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('got fill token, need to append more text tokens')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text tokens'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text tokens to decode, waiting for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len),
                                                                                          device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.fill_token
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                    if top_ids == self.fill_token:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.fill_token:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text tokens, decoding until eos is met')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len),
                                                                                  device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.eos_token:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
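
    # Usage sketch for inference_bistream (hypothetical names): `text` is a
    # generator of (1, N) text-token tensors, e.g. produced piece by piece by an
    # upstream text frontend:
    #
    #   def text_stream():
    #       for piece in text_pieces:
    #           yield text_tokenizer(piece)  # (1, N) LongTensor
    #
    #   for token_id in lm.inference_bistream(text_stream(), prompt_text,
    #                                         prompt_text_len, prompt_speech_token,
    #                                         prompt_speech_token_len, embedding):
    #       token_buffer.append(token_id)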


class CosyVoice3LM(Qwen2LM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        # skip Qwen2LM.__init__ on purpose: the special token ids differ, see below
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos = speech_token_size + 0
        self.eos_token = speech_token_size + 1
        self.task_id = speech_token_size + 2
        self.fill_token = speech_token_size + 3
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 200,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 4. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(4)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()
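
    # Unlike Qwen2LM, CosyVoice3LM keeps no separate llm_embedding table: sos,
    # eos_token, task_id and fill_token are ordinary rows of speech_embedding at
    # ids speech_token_size + 0..3, and both llm_decoder and speech_embedding
    # reserve 200 rows beyond the speech codebook for special tokens.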

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 1. concat llm_input; sos/task_id embeddings come from the speech table here
        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 2. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 3. step by step decode
        yield from self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid)