# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask
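

# NOTE: this module defines the autoregressive speech-token language models used by
# CosyVoice. TransformerLM pairs a text encoder with a decoder-only LM head;
# Qwen2Encoder wraps transformers.Qwen2ForCausalLM; Qwen2LM builds on it with
# unistream/bistream training sequences; CosyVoice3LM keeps the special tokens inside
# the speech embedding table and widens the output head. Every inference() method is
# a generator that yields speech token ids one at a time for streaming synthesis.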
class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )
        # 2. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = self.speech_token_size
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens
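
    # NOTE: training sequences are laid out as
    #   [sos] [speaker embedding] [encoded text] [task_id] [speech tokens]
    # pad_unpad_sequence() strips per-sample padding, concatenates the pieces and
    # re-pads the batch so the LM sees one contiguous sequence per utterance.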
    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)
        # 1. prepare llm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)
        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)
        # 4. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)
        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)
        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (top_ids < self.speech_token_size):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reached max_trials {} and still got eos while ignore_eos is True, check your input!'.format(max_trials))
        return top_ids
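
    # NOTE: inference() runs zero-shot/streaming decoding: prompt text and prompt
    # speech tokens are packed in front of the new text, then tokens are sampled one
    # step at a time with forward_chunk() and yielded as they are produced; the
    # output length is bounded by min/max_token_text_ratio times the text length.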
    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)
        # 1. encode text
        text, text_len = self.encode(text, text_len)
        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        # 3. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
            if top_ids == self.eos_token:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
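

# NOTE: Qwen2Encoder feeds pre-computed input embeddings (inputs_embeds) to a
# pretrained Qwen2ForCausalLM and returns the last hidden states; forward_one_step()
# does the same for incremental decoding, carrying the HuggingFace past_key_values
# cache between calls.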
class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T)
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], masks.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache
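

# NOTE: Qwen2LM appends three extra ids after the speech vocabulary: eos_token is
# speech_token_size and fill_token is speech_token_size + 2 (fill_token marks
# text/speech chunk boundaries in bistream sequences). llm_decoder and
# speech_embedding therefore cover speech_token_size + 3 entries, and all three
# extra ids are listed in stop_token_ids for decoding.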
class Qwen2LM(TransformerLM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = speech_token_size
        self.fill_token = speech_token_size + 2
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 4. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        # lock guarding the shared vllm request/output queues used in inference_wrapper
        self.lock = threading.Lock()
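
    # NOTE: prepare_lm_input_target builds two kinds of training sequences:
    #   * unistream: [sos] [instruct] [text] [task_id] [speech], with eos as the final target;
    #   * bistream:  text and speech tokens are interleaved in chunks of mix_ratio[0]
    #     text tokens followed by mix_ratio[1] speech tokens, with fill_token as the
    #     target at each chunk boundary, which teaches the model to ask for more text
    #     during streaming inference.
    # Roughly half of the eligible samples (chosen at random) use the bistream layout.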
    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len,
                                instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
        lm_target, lm_input = [], []
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
        # NOTE add instruct_token in CosyVoice3
        if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
            instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
            instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
        else:
            instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
            instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
            instruct_token_len = torch.zeros(len(text_token)).to(text_token_len)
        for i in range(len(text_token)):
            # bistream sequence
            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
                this_lm_target += [IGNORE_ID] * instruct_token_len[i]
                this_lm_input.append(instruct_token_emb[i])
                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                    if len(this_text_token) == self.mix_ratio[0]:
                        assert len(this_speech_token) == self.mix_ratio[1]
                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                        this_lm_target += this_speech_token
                        this_lm_target.append(self.fill_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                    else:
                        this_lm_target += [IGNORE_ID] * len(this_text_token)
                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                        this_lm_target.append(self.eos_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                        this_lm_input.append(task_id_emb.squeeze(dim=0))
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return lm_target, lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                         speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}
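
    # NOTE: forward_dpo runs the chosen and rejected speech-token sequences through the
    # LM in one doubled batch (text inputs are repeated), keeps the cross-entropy loss
    # on the chosen half, and returns per-sample mean log-probabilities of both halves
    # so the caller can form a DPO-style preference loss.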
    def forward_dpo(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        reject_speech_token = batch['reject_speech_token'].to(device)
        reject_speech_token_len = batch['reject_speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token, stacking chosen and rejected sequences
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
        speech_token_combined = speech_token + reject_speech_token
        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
        speech_token_combined_emb = self.speech_embedding(speech_token_combined)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
                                                                         task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        chosen_logits = logits[:text_token.shape[0]]
        rejected_logits = logits[text_token.shape[0]:]
        chosen_lm_target = lm_target[:text_token.shape[0]]
        rejected_lm_target = lm_target[text_token.shape[0]:]
        loss = self.criterion_ce(chosen_logits, chosen_lm_target)
        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
        # 6. calculate dpo log-probabilities
        # masks are True at IGNORE_ID (prompt/padding) positions; average the token
        # log-probabilities over the complementary, i.e. valid target, positions
        chosen_lm_mask = chosen_lm_target == IGNORE_ID
        rejected_lm_mask = rejected_lm_target == IGNORE_ID
        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        chosen_logps = (chosen_logps * ~chosen_lm_mask).sum(dim=-1) / (~chosen_lm_mask).sum(dim=-1)
        rejected_logps = (rejected_logps * ~rejected_lm_mask).sum(dim=-1) / (~rejected_lm_mask).sum(dim=-1)
        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
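
    # NOTE: unlike TransformerLM.inference, no speaker embedding is injected here; the
    # `embedding` argument is accepted for interface compatibility but otherwise unused.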
    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 1. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 2. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 3. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token
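
    # NOTE: inference_wrapper dispatches between two decoding backends: if a vllm
    # engine has been attached as self.vllm, requests are submitted to it and tokens
    # are drained from a per-uuid queue (self.lock serializes access across threads);
    # otherwise a plain PyTorch sampling loop over forward_one_step is used.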
    @torch.inference_mode()
    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
        if hasattr(self, 'vllm'):
            from vllm import SamplingParams, RequestOutput
            sampling_params = SamplingParams(top_k=sampling,
                                             stop_token_ids=self.stop_token_ids,
                                             min_tokens=min_len,
                                             max_tokens=max_len)
            with self.lock:
                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
                self.vllm_output_queue[uuid] = queue.Queue()
            out_tokens = []
            while True:
                with self.lock:
                    if self.vllm_output_queue[uuid].empty() is True:
                        request_outputs: List[RequestOutput] = self.vllm.step()
                        for request_output in request_outputs:
                            top_ids = list(request_output.outputs[0].token_ids)[-1]
                            self.vllm_output_queue[request_output.request_id].put(top_ids)
                if self.vllm_output_queue[uuid].empty() is False:
                    top_ids = self.vllm_output_queue[uuid].get()
                    if top_ids in self.stop_token_ids:
                        break
                    # in stream mode, yield token one by one
                    yield top_ids
                    out_tokens.append(top_ids)
                    if len(out_tokens) == max_len:
                        break
                time.sleep(0.001)
            with self.lock:
                self.vllm_output_queue.pop(uuid)
        else:
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                          cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
                if top_ids in self.stop_token_ids:
                    break
                # in stream mode, yield token one by one
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
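
    # NOTE: inference_bistream consumes text as a generator and emits speech tokens
    # while text is still arriving: text embeddings are buffered in text_cache and
    # spliced into the LM input mix_ratio[0] tokens at a time (paired with
    # mix_ratio[1] prompt speech tokens while those last); whenever the model emits
    # fill_token it pauses speech generation until more text is available, and once
    # the text generator is exhausted the remaining text plus task_id is appended and
    # decoding continues until eos.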
    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device
        # 1. prepare input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_emb], dim=1)
        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not empty, try append to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remain, can decode some speech token
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('get fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.fill_token
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
                    if top_ids == self.fill_token:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.fill_token:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until met eos')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.eos_token:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
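

# NOTE: CosyVoice3LM reuses the Qwen2LM machinery but stores its special tokens
# directly in the speech embedding table: sos, eos_token, task_id and fill_token are
# speech_token_size + 0..3, the output head and embedding are widened to
# speech_token_size + 200 entries, and every id >= speech_token_size is treated as a
# stop token during vllm decoding.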
class CosyVoice3LM(Qwen2LM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos = speech_token_size + 0
        self.eos_token = speech_token_size + 1
        self.task_id = speech_token_size + 2
        self.fill_token = speech_token_size + 3
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 200,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 4. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(200)]
        self.vllm_output_queue = {}
        # lock guarding the shared vllm request/output queues used in inference_wrapper
        self.lock = threading.Lock()

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # NOTE should append instruct_token to sequence, not implemented yet
        instruct_token = batch['instruct_token'].to(device)
        instruct_token_len = batch['instruct_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
        # 2. sos and task_id
        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                         speech_token, speech_token_emb, speech_token_len,
                                                                         instruct_token, instruct_token_emb, instruct_token_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 1. concat llm_input
        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 2. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 3. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token