# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator

import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

from cosyvoice.utils.common import IGNORE_ID, th_accuracy
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask


class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )
        # 2. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = self.speech_token_size
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len
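
    # NOTE: each packed training sample produced above is laid out as
    #   [sos_emb][spk_emb][text_emb ...][task_id_emb][speech_emb ...],
    # and forward() builds the matching target: IGNORE_ID over the prefix positions,
    # then the speech token ids, then a trailing eos (= speech_token_size).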
    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch['text_token']: (B, L)
            batch['text_token_len']: (B,)
            batch['speech_token']: (B, T)
            batch['speech_token_len']: (B,)
            batch['embedding']: (B, spk_embed_dim)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)
        # 1. prepare llm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)
        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)
        # 4. sos and task_id embeddings
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)
        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)
        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}
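
    # When ignore_eos is set, sampling_ids keeps re-drawing until a regular speech token
    # (< speech_token_size) is sampled, giving up after max_trials draws; this is what keeps
    # eos out of the output until min_len is reached during inference.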
    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (top_ids < self.speech_token_size):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reached max_trials {} and still got eos when ignore_eos is True, check your input!'.format(max_trials))
        return top_ids
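
    # Streaming decode: the full [sos][spk][text][task_id][prompt speech] prefix is fed on the
    # first step, then one newly sampled speech token per step via the attention cache; tokens
    # are yielded as soon as they are sampled, and eos is suppressed until min_len tokens have
    # been produced.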
    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)
        # 1. encode text
        text, text_len = self.encode(text, text_len)
        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        # 3. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.eos_token:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
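

# Qwen2Encoder wraps a Hugging Face Qwen2ForCausalLM so it can be driven with precomputed
# input embeddings: forward() runs a full padded batch and returns the last hidden states
# plus the padding mask, while forward_one_step() performs incremental decoding with the
# model's past_key_values cache.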
class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T)
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], masks.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache
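

# Qwen2LM replaces the transformer text encoder of TransformerLM with the Qwen2 backbone's
# own token embeddings and supports two training layouts: a plain "unistream"
# [sos][text][task_id][speech] sequence, and a "bistream" sequence that interleaves text and
# speech in chunks of mix_ratio[0] text tokens to mix_ratio[1] speech tokens for streaming
# synthesis. Decoding can optionally be offloaded to a vllm engine attached as self.vllm.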
class Qwen2LM(TransformerLM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = speech_token_size
        self.fill_token = speech_token_size + 2
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 4. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()  # guards the vllm engine and per-request output queues in inference_wrapper
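
    # prepare_lm_input_target builds one packed sequence per sample. With probability 0.5, and
    # only when enough speech tokens exist (speech_len / text_len > mix_ratio[1] / mix_ratio[0]),
    # a bistream sequence is built: chunks of mix_ratio[0] text tokens followed by mix_ratio[1]
    # speech tokens, with fill_token marking each chunk boundary in the target; otherwise a
    # unistream [sos][text][task_id][speech] sequence is used.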
    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
        lm_target, lm_input = [], []
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
        for i in range(len(text_token)):
            # bistream sequence
            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                this_lm_target, this_lm_input = [], []
                this_lm_target.append(IGNORE_ID)
                this_lm_input.append(sos_emb.squeeze(dim=0))
                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                    if len(this_text_token) == self.mix_ratio[0]:
                        assert len(this_speech_token) == self.mix_ratio[1]
                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                        this_lm_target += this_speech_token
                        this_lm_target.append(self.fill_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                    else:
                        this_lm_target += [IGNORE_ID] * len(this_text_token)
                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                        this_lm_target.append(self.eos_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                        this_lm_input.append(task_id_emb.squeeze(dim=0))
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return lm_target, lm_input, lm_input_len
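
    # NOTE on the output vocabulary: ids [0, speech_token_size) are speech tokens,
    # speech_token_size is eos, speech_token_size + 2 is the bistream fill token, and
    # speech_token_size + 1 is not used in this file; all three extra ids are treated as stop
    # tokens when decoding through vllm.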
    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch['text_token']: (B, L)
            batch['text_token_len']: (B,)
            batch['speech_token']: (B, T)
            batch['speech_token_len']: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id embeddings
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                         speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}
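
    # forward_dpo runs chosen and rejected speech-token sequences through the LM in one batch
    # (chosen first, rejected second), keeps the cross-entropy loss/accuracy on the chosen half,
    # and returns sequence-level log-probabilities for both halves (chosen_logps / rejected_logps)
    # for a DPO loss computed by the caller.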
    def forward_dpo(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        reject_speech_token = batch['reject_speech_token'].to(device)
        reject_speech_token_len = batch['reject_speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id embeddings
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token, stacking chosen and rejected sequences
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
        speech_token_combined = speech_token + reject_speech_token
        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
        speech_token_combined_emb = self.speech_embedding(speech_token_combined)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
                                                                         task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        chosen_logits = logits[:text_token.shape[0]]
        rejected_logits = logits[text_token.shape[0]:]
        chosen_lm_target = lm_target[:text_token.shape[0]]
        rejected_lm_target = lm_target[text_token.shape[0]:]
        loss = self.criterion_ce(chosen_logits, chosen_lm_target)
        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
        # 6. calculate dpo logits
        chosen_lm_mask = chosen_lm_target == IGNORE_ID
        rejected_lm_mask = rejected_lm_target == IGNORE_ID
        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
        rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
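
    # inference() prepends the prompt text and prompt speech tokens to build the decoding prefix,
    # derives min/max output lengths from the text length, and then streams tokens from
    # inference_wrapper(), which uses the attached vllm engine when available and the plain
    # forward_one_step loop otherwise.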
    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 1. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 2. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 3. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token

    @torch.inference_mode()
    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
        if hasattr(self, 'vllm'):
            from vllm import SamplingParams, RequestOutput
            sampling_params = SamplingParams(top_k=sampling,
                                             stop_token_ids=self.stop_token_ids,
                                             min_tokens=min_len,
                                             max_tokens=max_len)
            with self.lock:
                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
                self.vllm_output_queue[uuid] = queue.Queue()
            out_tokens = []
            while True:
                with self.lock:
                    if self.vllm_output_queue[uuid].empty() is True:
                        request_outputs: List[RequestOutput] = self.vllm.step()
                        for request_output in request_outputs:
                            top_ids = list(request_output.outputs[0].token_ids)[-1]
                            self.vllm_output_queue[request_output.request_id].put(top_ids)
                if self.vllm_output_queue[uuid].empty() is False:
                    top_ids = self.vllm_output_queue[uuid].get()
                    if top_ids in self.stop_token_ids:
                        break
                    # in stream mode, yield token one by one
                    yield top_ids
                    out_tokens.append(top_ids)
                    if len(out_tokens) == max_len:
                        break
                time.sleep(0.001)
            with self.lock:
                self.vllm_output_queue.pop(uuid)
        else:
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                          cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
                if top_ids in self.stop_token_ids:
                    break
                # in stream mode, yield token one by one
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
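
    # inference_bistream consumes text as a generator: incoming text tokens are buffered in
    # text_cache and released to the LM in chunks of mix_ratio[0] tokens, each chunk followed by
    # up to mix_ratio[1] sampled speech tokens; fill_token marks the point where the model must
    # wait for more text, and a final pass with task_id_emb decodes until eos.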
    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device
        # 1. prepare input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_emb], dim=1)
        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE: seed text_cache with prompt_text, as prompt_speech_token/prompt_text essentially never falls below mix_ratio (15/5)
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not empty, try to append it to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remains, so some speech tokens can be decoded
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('get fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.fill_token
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                    if top_ids == self.fill_token:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.fill_token:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until met eos')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.eos_token:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
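

# CosyVoice3LM keeps the Qwen2LM decoding logic but stores its control tokens at the top of an
# enlarged speech embedding table instead of a separate llm_embedding: sos, eos, task_id and
# fill_token occupy speech_token_size + 0..3, with 200 extra ids reserved in total, all of which
# are treated as stop ids for vllm decoding.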
class CosyVoice3LM(Qwen2LM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos = speech_token_size + 0
        self.eos_token = speech_token_size + 1
        self.task_id = speech_token_size + 2
        self.fill_token = speech_token_size + 3
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 200,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 4. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(200)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()  # guards the vllm engine and per-request output queues in inference_wrapper
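
    # forward() mirrors Qwen2LM.forward(); the only differences are that sos/task_id embeddings
    # are looked up from speech_embedding rather than a separate llm_embedding table, and the
    # output vocabulary is speech_token_size + 200.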
    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch['text_token']: (B, L)
            batch['text_token_len']: (B,)
            batch['speech_token']: (B, T)
            batch['speech_token_len']: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id embeddings
        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                         speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        # NOTE: the class dimension must match llm_decoder's output size (speech_token_size + 200)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 1. concat llm_input
        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 2. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 3. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token
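

# Usage sketch (hypothetical, for illustration only): the CosyVoice CLI normally drives these
# classes; assuming `llm` is a loaded Qwen2LM/CosyVoice3LM and the prompt tensors have been
# prepared by the frontend, speech tokens are streamed roughly like this:
#
#     speech_tokens = []
#     for token in llm.inference(text, text_len, prompt_text, prompt_text_len,
#                                prompt_speech_token, prompt_speech_token_len,
#                                embedding, sampling=25, uuid='demo-request'):
#         speech_tokens.append(token)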