# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

from cosyvoice.utils.common import IGNORE_ID, th_accuracy
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path


class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token input related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )
        # 2. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = self.speech_token_size
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len,
                           task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len
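
    # NOTE (added note, illustrative lengths): for a batch of two samples with
    # text lengths [3, 2] and speech lengths [6, 4], each packed row is
    #   [sos][spk_emb][t1 t2 t3][task_id][s1 ... s6]          -> length 12
    #   [sos][spk_emb][t1 t2]   [task_id][s1 ... s4][pad ...] -> length 9, padded to 12
    # Only the layout is taken from the code above; the numbers are hypothetical.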

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict with
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T)
                speech_token_len: (B,)
                embedding: (B, spk_embed_dim)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)
        # 1. prepare lm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)
        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)
        # 4. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)
        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)
        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}
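
    # NOTE (added note): lm_target is aligned so that the logits at the task_id
    # input position supervise the first speech token: the (2 + text_len) leading
    # IGNORE_ID entries cover [sos][spk_emb][text...], and the trailing
    # self.speech_token_size entry is the EOS class of llm_decoder.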

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        if ignore_eos:
            weighted_scores[self.speech_token_size] = -float('inf')
        top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
        return top_ids
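
    # NOTE (added note): ignore_eos masks only the EOS logit (index
    # self.speech_token_size). inference() keeps it True while fewer than
    # min_len tokens have been produced, so decoding cannot terminate early.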

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)
        # 1. encode text
        text, text_len = self.encode(text, text_len)
        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        # 3. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device)
        cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len)
            if top_ids == self.eos_token:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
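
    # Usage sketch (hypothetical; the tensors, the sampling callable and the
    # downstream stage are placeholders, not part of this module):
    #
    #   lm = TransformerLM(..., sampling=my_top_k_sampling)  # my_top_k_sampling is assumed
    #   for token in lm.inference(text, text_len, prompt_text, prompt_text_len,
    #                             prompt_speech_token, prompt_speech_token_len,
    #                             embedding, sampling=25):
    #       feed_to_flow_decoder(token)  # assumed downstream consumer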


class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T)
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], masks.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache
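
    # NOTE (added note): forward_one_step implements incremental decoding.
    # `cache` is the HuggingFace past_key_values object: on the first call it is
    # None and xs holds the whole prefix; on later calls xs holds only the new
    # step(s) and the returned cache is passed back in.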


class Qwen2LM(TransformerLM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = speech_token_size
        self.fill_token = speech_token_size + 2
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 4. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        if online_feature:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
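
    # NOTE (added note): output head layout for Qwen2LM (llm_decoder has
    # speech_token_size + 3 classes):
    #   [0, speech_token_size)   speech tokens
    #   speech_token_size        eos_token
    #   speech_token_size + 1    unnamed in this file, but still a stop id
    #   speech_token_size + 2    fill_token, used by the bistream schedule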

    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                speech_token, speech_token_emb, speech_token_len,
                                instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
        lm_target, lm_input = [], []
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
        # NOTE add instruct_token in CosyVoice3
        if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
            instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
            instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
        else:
            instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
            instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
            instruct_token_len = torch.zeros(len(text_token)).to(text_token_len)
        for i in range(len(text_token)):
            # bistream sequence
            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
                this_lm_target += [IGNORE_ID] * instruct_token_len[i]
                this_lm_input.append(instruct_token_emb[i])
                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                    if len(this_text_token) == self.mix_ratio[0]:
                        assert len(this_speech_token) == self.mix_ratio[1]
                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                        this_lm_target += this_speech_token
                        this_lm_target.append(self.fill_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                    else:
                        this_lm_target += [IGNORE_ID] * len(this_text_token)
                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                        this_lm_target.append(self.eos_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                        this_lm_input.append(task_id_emb.squeeze(dim=0))
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return lm_target, lm_input, lm_input_len
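
    # NOTE (added note, illustrative): with mix_ratio = [5, 15], a bistream
    # training row interleaves 5 text embeddings with 15 speech embeddings per
    # round, and the target teaches the model to emit fill_token once each
    # 15-token speech chunk is exhausted:
    #   input : [sos][instr][t1..t5][s1..s15][t6..t10][s16..s30]...[t_rest][task_id][s_rest]
    #   target: IGNORE_ID over text positions, speech tokens over speech positions,
    #           fill_token after each full chunk, eos_token at the very end.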

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict with
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T) and speech_token_len: (B,), or
                    whisper_feat/whisper_feat_len when tokens are extracted online
                instruct_token/instruct_token_len: CosyVoice3LM only
        """
        # 1. encode text_token
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. encode speech_token
        if 'speech_token' not in batch:
            speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
        else:
            speech_token = batch['speech_token'].to(device)
            speech_token_len = batch['speech_token_len'].to(device)
        speech_token_emb = self.speech_embedding(speech_token)
        # 3. sos and task_id
        if self.__class__.__name__ == 'CosyVoice3LM':
            sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
            task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        elif self.__class__.__name__ == 'Qwen2LM':
            sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        else:
            raise ValueError
        # 4. prepare llm_input/target
        if self.__class__.__name__ == 'CosyVoice3LM':
            instruct_token = batch['instruct_token'].to(device)
            instruct_token_len = batch['instruct_token_len'].to(device)
            instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                             speech_token, speech_token_emb, speech_token_len,
                                                                             instruct_token, instruct_token_emb, instruct_token_len)
        elif self.__class__.__name__ == 'Qwen2LM':
            lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                             speech_token, speech_token_emb, speech_token_len)
        else:
            raise ValueError
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.llm_decoder.out_features), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def forward_dpo(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        reject_speech_token = batch['reject_speech_token'].to(device)
        reject_speech_token_len = batch['reject_speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token, stacking chosen and rejected sequences in one batch
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
        speech_token_combined = speech_token + reject_speech_token
        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
        speech_token_combined_emb = self.speech_embedding(speech_token_combined)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
                                                                         task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        chosen_logits = logits[:text_token.shape[0]]
        rejected_logits = logits[text_token.shape[0]:]
        chosen_lm_target = lm_target[:text_token.shape[0]]
        rejected_lm_target = lm_target[text_token.shape[0]:]
        loss = self.criterion_ce(chosen_logits, chosen_lm_target)
        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
        # 6. calculate per-sequence log-probs for dpo
        chosen_lm_mask = chosen_lm_target == IGNORE_ID
        rejected_lm_mask = rejected_lm_target == IGNORE_ID
        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        # average log-prob over supervised (non-IGNORE_ID) positions only
        chosen_logps = (chosen_logps * ~chosen_lm_mask).sum(dim=-1) / (~chosen_lm_mask).sum(dim=-1)
        rejected_logps = (rejected_logps * ~rejected_lm_mask).sum(dim=-1) / (~rejected_lm_mask).sum(dim=-1)
        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
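
    # NOTE (added note): chosen_logps / rejected_logps are mean log-probs of the
    # chosen and rejected speech sequences. A DPO trainer outside this module
    # (assumed, not shown here) would combine them with a frozen reference model:
    #   loss_dpo = -log(sigmoid(beta * ((chosen - rejected) - (chosen_ref - rejected_ref))))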

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        # 1. encode text
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text_emb = self.llm.model.model.embed_tokens(text)
        if self.__class__.__name__ == 'CosyVoice3LM':
            # NOTE temporary hardcode, 151646 is <|endofprompt|> token
            assert 151646 in text, '<|endofprompt|> not detected in CosyVoice3 text or prompt_text, check your input!'
        # 2. concat llm_input
        if self.__class__.__name__ == 'CosyVoice3LM':
            sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
            task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        elif self.__class__.__name__ == 'Qwen2LM':
            sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        else:
            raise ValueError
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text_emb.dtype).to(device)
        lm_input = torch.concat([sos_emb, text_emb, task_id_emb, prompt_speech_token_emb], dim=1)
        # 3. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 4. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token

    @torch.inference_mode()
    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
        if hasattr(self, 'vllm'):
            from vllm import SamplingParams, RequestOutput
            sampling_params = SamplingParams(top_k=sampling,
                                             stop_token_ids=self.stop_token_ids,
                                             min_tokens=min_len,
                                             max_tokens=max_len)
            with self.lock:
                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
                self.vllm_output_queue[uuid] = queue.Queue()
            out_tokens = []
            while True:
                with self.lock:
                    if self.vllm_output_queue[uuid].empty():
                        request_outputs: List[RequestOutput] = self.vllm.step()
                        for request_output in request_outputs:
                            top_ids = list(request_output.outputs[0].token_ids)[-1]
                            self.vllm_output_queue[request_output.request_id].put(top_ids)
                if not self.vllm_output_queue[uuid].empty():
                    top_ids = self.vllm_output_queue[uuid].get()
                    if top_ids in self.stop_token_ids:
                        break
                    # in stream mode, yield token one by one
                    yield top_ids
                    out_tokens.append(top_ids)
                    if len(out_tokens) == max_len:
                        break
                time.sleep(0.001)
            with self.lock:
                self.vllm_output_queue.pop(uuid)
        else:
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                          cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len)
                if top_ids in self.stop_token_ids:
                    break
                # in stream mode, yield token one by one
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
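
    # NOTE (added note): inference_wrapper offers two decode paths with the same
    # streaming contract. If a vllm engine has been attached as self.vllm
    # (together with self.lock; both are set up outside this module), tokens are
    # pulled from the engine's step loop; otherwise it decodes token by token
    # with forward_one_step and a HuggingFace KV cache.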

    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device
        # 1. prepare input
        if self.__class__.__name__ == 'CosyVoice3LM':
            sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
            task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        elif self.__class__.__name__ == 'Qwen2LM':
            sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        else:
            raise ValueError
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_emb], dim=1)
        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE init prompt_text as text_cache as it is basically impossible that prompt_speech_token/prompt_text < 15/5
        if self.__class__.__name__ == 'CosyVoice3LM':
            # NOTE temporary hardcode, 151646 is <|endofprompt|> token
            assert 151646 in prompt_text, '<|endofprompt|> not detected in CosyVoice3 prompt_text, check your input!'
            eop_index = prompt_text.flatten().tolist().index(151646)
            lm_input = torch.concat([lm_input, self.llm.model.model.embed_tokens(prompt_text[:, :eop_index + 1])], dim=1)
            prompt_text = prompt_text[:, eop_index + 1:]
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not empty, try to append it to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remains, can decode some speech token
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('get fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.fill_token
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
                    if top_ids == self.fill_token:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.fill_token:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until eos is met')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.eos_token:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
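
    # NOTE (added note): the bistream schedule above enforces the trained
    # mix_ratio[0]:mix_ratio[1] text/speech interleave at decode time:
    # next_fill_index forces a fill_token after every mix_ratio[1] speech tokens
    # so the model pauses for the next mix_ratio[0] text tokens, and the final
    # loop appends task_id_emb and decodes until eos once the text generator is
    # exhausted.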


class CosyVoice3LM(Qwen2LM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos = speech_token_size + 0
        self.eos_token = speech_token_size + 1
        self.task_id = speech_token_size + 2
        self.fill_token = speech_token_size + 3
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 200,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 4. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(200)]
        self.vllm_output_queue = {}
        if online_feature:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
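
    # NOTE (added note): compared with Qwen2LM, CosyVoice3LM moves every special
    # token into the speech-token id space: sos, eos_token, task_id and
    # fill_token live at speech_token_size + 0..3 (so sos/task_id embeddings come
    # from speech_embedding rather than llm_embedding), the head and embedding
    # are widened to speech_token_size + 200, and all 200 extra ids are stop ids.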