llm.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask


class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )
        # 2. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = self.speech_token_size
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len
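    # Each per-utterance sequence assembled above is
    # [sos_emb][speaker embedding][text embeddings][task_id_emb][speech embeddings],
    # re-padded into a batch with IGNORE_ID as the padding value.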

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict containing
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T)
                speech_token_len: (B,)
                embedding: (B, spk_embed_dim)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)
        # 1. prepare llm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)
        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)
        # 4. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)
        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)
        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}
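    # lm_target above is [IGNORE_ID] * (2 + text_len) + speech_tokens + [eos]: the ignored
    # prefix covers the sos, speaker-embedding and text positions of lm_input, so the loss
    # is only computed from the task_id position onwards, where the model predicts the
    # speech tokens and the final eos.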

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (top_ids < self.speech_token_size):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
        return top_ids
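    # While ignore_eos is set, sampling_ids() rejects any id >= speech_token_size (eos and
    # other special ids) and resamples, up to max_trials times; this is how the decoding
    # loops enforce their minimum output length.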

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)
        # 1. encode text
        text, text_len = self.encode(text, text_len)
        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
        # 3. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.eos_token:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
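
    # Illustrative caller-side sketch (assumed usage, not part of this module): inference()
    # is a generator, so speech tokens can be consumed as soon as they are sampled, e.g.
    #   for token_id in lm.inference(text, text_len, prompt_text, prompt_text_len,
    #                                prompt_speech_token, prompt_speech_token_len, embedding):
    #       token_buffer.append(token_id)  # lm / token_buffer are hypothetical caller objects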


class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T)
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], masks.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache
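
    # forward() encodes a whole padded batch in one pass and returns the last hidden states
    # plus the (B, 1, T) mask; forward_one_step() instead feeds only the new embeddings and
    # threads the Qwen2 key/value cache (past_key_values) through successive calls, which is
    # what the step-by-step decoding loops below rely on.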


class Qwen2LM(TransformerLM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 2. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = speech_token_size
        self.fill_token = speech_token_size + 2
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
        # 4. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 5. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()  # guards the vllm request/output bookkeeping in inference_wrapper()
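    # Token id layout used below: speech tokens occupy [0, speech_token_size), eos_token is
    # speech_token_size, and fill_token (speech_token_size + 2) separates interleaved
    # text/speech chunks in bistream sequences, so llm_decoder and speech_embedding are sized
    # speech_token_size + 3.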

    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
        lm_target, lm_input = [], []
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
        for i in range(len(text_token)):
            # bistream sequence
            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                this_lm_target, this_lm_input = [], []
                this_lm_target.append(IGNORE_ID)
                this_lm_input.append(sos_emb.squeeze(dim=0))
                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                    if len(this_text_token) == self.mix_ratio[0]:
                        assert len(this_speech_token) == self.mix_ratio[1]
                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                        this_lm_target += this_speech_token
                        this_lm_target.append(self.fill_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                    else:
                        this_lm_target += [-1] * len(this_text_token)
                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                        this_lm_target.append(self.eos_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                        this_lm_input.append(task_id_emb.squeeze(dim=0))
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return lm_target, lm_input, lm_input_len
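    # With the default mix_ratio [5, 15], a bistream sequence interleaves chunks of 5 text
    # embeddings and 15 speech embeddings; the target emits the speech tokens of each full
    # chunk followed by fill_token, and the trailing partial chunk appends task_id to the
    # input and closes the target with the remaining speech tokens plus eos_token.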

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict containing
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T)
                speech_token_len: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len,
                                                                         task_id_emb, speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target.to(device))
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def forward_dpo(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        reject_speech_token = batch['reject_speech_token'].to(device)
        reject_speech_token_len = batch['reject_speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token, stacking chosen and rejected sequences into one batch
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
        speech_token_combined = speech_token + reject_speech_token
        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
        speech_token_combined_emb = self.speech_embedding(speech_token_combined)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
                                                                         task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        chosen_logits = logits[:text_token.shape[0]]
        rejected_logits = logits[text_token.shape[0]:]
        chosen_lm_target = lm_target[:text_token.shape[0]]
        rejected_lm_target = lm_target[text_token.shape[0]:]
        loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
        # 6. calculate dpo logits
        chosen_lm_mask = chosen_lm_target == IGNORE_ID
        rejected_lm_mask = rejected_lm_target == IGNORE_ID
        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
        rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 3. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token

    @torch.inference_mode()
    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
        if hasattr(self, 'vllm'):
            from vllm import SamplingParams, RequestOutput
            sampling_params = SamplingParams(top_k=sampling,
                                             stop_token_ids=self.stop_token_ids,
                                             min_tokens=min_len,
                                             max_tokens=max_len)
            with self.lock:
                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
                self.vllm_output_queue[uuid] = queue.Queue()
            out_tokens = []
            while True:
                with self.lock:
                    if self.vllm_output_queue[uuid].empty() is True:
                        request_outputs: List[RequestOutput] = self.vllm.step()
                        for request_output in request_outputs:
                            top_ids = list(request_output.outputs[0].token_ids)[-1]
                            self.vllm_output_queue[request_output.request_id].put(top_ids)
                if self.vllm_output_queue[uuid].empty() is False:
                    top_ids = self.vllm_output_queue[uuid].get()
                    if top_ids in self.stop_token_ids:
                        break
                    # in stream mode, yield token one by one
                    yield top_ids
                    out_tokens.append(top_ids)
                    if len(out_tokens) == max_len:
                        break
                time.sleep(0.001)
            with self.lock:
                self.vllm_output_queue.pop(uuid)
        else:
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                          cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
                if top_ids in self.stop_token_ids:
                    break
                # in stream mode, yield token one by one
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
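    # inference_wrapper() decodes through one of two backends: if a vllm engine has been
    # attached as self.vllm (done outside this module), requests are scheduled there and
    # tokens are drained from a per-request queue under self.lock; otherwise it falls back
    # to step-by-step decoding with Qwen2Encoder.forward_one_step() and its key/value cache.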

    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device
        # 1. prepare input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_emb], dim=1)
        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE: initialize text_cache with prompt_text, since prompt_speech_token / prompt_text is essentially never below 15 / 5
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb is not empty yet, try to append it to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remains, so speech tokens can be decoded
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('get fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.fill_token
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                        if top_ids == self.fill_token:
                            next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                            logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.fill_token:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until met eos')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.eos_token:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
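
    # Illustrative caller-side sketch (assumed usage, not part of this module):
    # inference_bistream() takes `text` as a generator of text-token tensors, so text can be
    # streamed in while speech tokens stream out, e.g.
    #   for token_id in lm.inference_bistream(text_chunks(), prompt_text, prompt_text_len,
    #                                         prompt_speech_token, prompt_speech_token_len, embedding):
    #       token_buffer.append(token_id)  # text_chunks / token_buffer / lm are hypothetical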


class CosyVoice3LM(Qwen2LM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 2. build speech token language model related modules
        self.sos = speech_token_size + 0
        self.eos_token = speech_token_size + 1
        self.task_id = speech_token_size + 2
        self.fill_token = speech_token_size + 3
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 200,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
        # 4. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 5. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(200)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()  # guards the vllm request/output bookkeeping in inference_wrapper()
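    # CosyVoice3LM places its special ids above the speech codebook: sos = speech_token_size,
    # eos_token = speech_token_size + 1, task_id = speech_token_size + 2 and
    # fill_token = speech_token_size + 3, with speech_token_size + 200 output/embedding slots
    # overall; each of those 200 extra ids is treated as a stop token by inference_wrapper().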

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict containing
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T)
                speech_token_len: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id
        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len,
                                                                         task_id_emb, speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target.to(device))
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 3. concat llm_input
        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token