# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID, th_accuracy
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path


class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )
        # 2. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = self.speech_token_size
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
        # 4. sampling method
        self.sampling = sampling

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)
        # 1. prepare lm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)
        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)
        # 4. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)
        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)
        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (top_ids < self.speech_token_size):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
        return top_ids

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)
        # 1. encode text
        text, text_len = self.encode(text, text_len)
        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        # 3. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=(i < min_len))
            if top_ids == self.eos_token:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
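

# Illustrative usage sketch (not part of this module). It assumes an already
# constructed TransformerLM instance `lm` plus (1, T) LongTensor token ids and a
# speaker embedding prepared by the surrounding CosyVoice pipeline; every name
# below is a placeholder, not an API defined in this file:
#
#     speech_tokens = []
#     for token in lm.inference(text=text_ids, text_len=torch.tensor([text_ids.size(1)]),
#                               prompt_text=prompt_ids, prompt_text_len=torch.tensor([prompt_ids.size(1)]),
#                               prompt_speech_token=prompt_speech_ids,
#                               prompt_speech_token_len=torch.tensor([prompt_speech_ids.size(1)]),
#                               embedding=spk_embedding):
#         speech_tokens.append(token)  # each item is a single speech token id, yielded as it is decoded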


class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T)
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], masks.unsqueeze(1)

    def forward_one_step(self, xs, masks, cache=None):
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache


class Qwen2LM(TransformerLM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 2. build speech token language model related modules
        self.sos = 0
        self.task_id = 1
        self.eos_token = speech_token_size
        self.fill_token = speech_token_size + 2
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
        # 4. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 5. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(3)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()  # guards vllm request/output queue access in inference_wrapper
        if online_feature is True:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))

    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len,
                                instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
        lm_target, lm_input = [], []
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
        speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
        # NOTE add instruct_token in CosyVoice3
        if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
            instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
            instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
        else:
            instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
            instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
            instruct_token_len = torch.zeros(len(text_token)).to(text_token_len)
        for i in range(len(text_token)):
            # bistream sequence
            if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
                this_lm_target += [IGNORE_ID] * instruct_token_len[i]
                this_lm_input.append(instruct_token_emb[i])
                for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                    this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                    this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                    if len(this_text_token) == self.mix_ratio[0]:
                        assert len(this_speech_token) == self.mix_ratio[1]
                        this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                        this_lm_target += this_speech_token
                        this_lm_target.append(self.fill_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                    else:
                        this_lm_target += [IGNORE_ID] * len(this_text_token)
                        this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                        this_lm_target.append(self.eos_token)
                        this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                        this_lm_input.append(task_id_emb.squeeze(dim=0))
                        this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
            # unistream sequence
            else:
                this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
            lm_target.append(this_lm_target)
            lm_input.append(this_lm_input)
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
        return lm_target, lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                         speech_token, speech_token_emb, speech_token_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def forward_dpo(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        reject_speech_token = batch['reject_speech_token'].to(device)
        reject_speech_token_len = batch['reject_speech_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        # 2. sos and task_id
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token, chosen and rejected sequences concatenated along the batch dim
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
        speech_token_combined = speech_token + reject_speech_token
        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
        speech_token_combined_emb = self.speech_embedding(speech_token_combined)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
                                                                         task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        chosen_logits = logits[:text_token.shape[0]]
        rejected_logits = logits[text_token.shape[0]:]
        chosen_lm_target = lm_target[:text_token.shape[0]]
        rejected_lm_target = lm_target[text_token.shape[0]:]
        loss = self.criterion_ce(chosen_logits, chosen_lm_target)
        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
        # 6. calculate dpo log-probs, averaged over valid (non-ignored) target positions
        chosen_lm_mask = chosen_lm_target == IGNORE_ID
        rejected_lm_mask = rejected_lm_target == IGNORE_ID
        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2,
                                    index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2,
                                      index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
        chosen_logps = (chosen_logps * ~chosen_lm_mask).sum(dim=-1) / (~chosen_lm_mask).sum(dim=-1)
        rejected_logps = (rejected_logps * ~rejected_lm_mask).sum(dim=-1) / (~rejected_lm_mask).sum(dim=-1)
        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 3. concat llm_input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token

    @torch.inference_mode()
    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
        if hasattr(self, 'vllm'):
            from vllm import SamplingParams, RequestOutput
            sampling_params = SamplingParams(top_k=sampling,
                                             stop_token_ids=self.stop_token_ids,
                                             min_tokens=min_len,
                                             max_tokens=max_len)
            with self.lock:
                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
                self.vllm_output_queue[uuid] = queue.Queue()
            out_tokens = []
            while True:
                with self.lock:
                    if self.vllm_output_queue[uuid].empty():
                        request_outputs: List[RequestOutput] = self.vllm.step()
                        for request_output in request_outputs:
                            top_ids = list(request_output.outputs[0].token_ids)[-1]
                            self.vllm_output_queue[request_output.request_id].put(top_ids)
                if not self.vllm_output_queue[uuid].empty():
                    top_ids = self.vllm_output_queue[uuid].get()
                    if top_ids in self.stop_token_ids:
                        break
                    # in stream mode, yield token one by one
                    yield top_ids
                    out_tokens.append(top_ids)
                    if len(out_tokens) == max_len:
                        break
                time.sleep(0.001)
            with self.lock:
                self.vllm_output_queue.pop(uuid)
        else:
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                          cache=cache)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=(i < min_len))
                if top_ids in self.stop_token_ids:
                    break
                # in stream mode, yield token one by one
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device
        # 1. prepare input
        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_emb], dim=1)
        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not empty, try append to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remain, can decode some speech token
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('get fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.fill_token
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
                    if top_ids == self.fill_token:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.fill_token:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until met eos')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.eos_token:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
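

# Illustrative bistream usage sketch (not part of this module). inference_bistream
# consumes `text` as a generator of token-id chunks, so text can be streamed in while
# speech tokens are streamed out. The names below (`lm`, `text_chunks`, `handle`, the
# prompt tensors) are placeholders supplied by the caller, not APIs defined in this file:
#
#     def text_stream():
#         for chunk_ids in text_chunks:          # each chunk: LongTensor of shape (1, n)
#             yield chunk_ids
#
#     for token in lm.inference_bistream(text=text_stream(),
#                                        prompt_text=prompt_ids, prompt_text_len=torch.tensor([prompt_ids.size(1)]),
#                                        prompt_speech_token=prompt_speech_ids,
#                                        prompt_speech_token_len=torch.tensor([prompt_speech_ids.size(1)]),
#                                        embedding=spk_embedding):
#         handle(token)                           # speech tokens are yielded as soon as they are decoded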


class CosyVoice3LM(Qwen2LM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 2. build speech token language model related modules
        self.sos = speech_token_size + 0
        self.eos_token = speech_token_size + 1
        self.task_id = speech_token_size + 2
        self.fill_token = speech_token_size + 3
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 200,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
        # 4. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio
        # 5. vllm related
        self.stop_token_ids = [speech_token_size + i for i in range(200)]
        self.vllm_output_queue = {}
        self.lock = threading.Lock()  # guards vllm request/output queue access in inference_wrapper
        if online_feature is True:
            self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            text: (B, L, D)
            text_lengths: (B,)
            audio: (B, T, N) or (B, T)
            audio_lengths: (B,)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        # NOTE should append instruct_token to sequence, not implemented yet
        instruct_token = batch['instruct_token'].to(device)
        instruct_token_len = batch['instruct_token_len'].to(device)
        # 1. encode text_token
        text_token_emb = self.llm.model.model.embed_tokens(text_token)
        instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
        # 2. sos and task_id
        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 3. encode speech_token
        speech_token_emb = self.speech_embedding(speech_token)
        # 4. prepare llm_input/target
        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
                                                                         speech_token, speech_token_emb, speech_token_len,
                                                                         instruct_token, instruct_token_emb, instruct_token_len)
        lm_target = lm_target.to(device)
        # 5. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
            uuid: str = '',
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 3. concat llm_input
        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
            yield token