# llm.py
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2ForCausalLM
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging


class TransformerLM(torch.nn.Module):
    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
    ):
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )
        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
        # 4. sampling method
        self.sampling = sampling
        # NOTE inference() checks self.fp16; default to False here, callers that run the model in
        # half precision are expected to override this flag
        self.fp16 = False

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len,
                           task_id_emb, speech_token, speech_token_len):
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Args:
            batch: dict containing
                text_token: (B, L)
                text_token_len: (B,)
                speech_token: (B, T)
                speech_token_len: (B,)
                embedding: (B, spk_embed_dim)
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)
        # 1. prepare lm_target
        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
                                  [self.speech_token_size]) for i in range(text_token.size(0))]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
        # 2. encode text_token
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)
        # 3. embedding projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)
        embedding = embedding.unsqueeze(1)
        # 4. sos and task_id
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        # 5. encode speech_token
        speech_token = self.speech_embedding(speech_token)
        # 6. unpad and pad
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)
        # 7. run lm forward
        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        return {'loss': loss, 'acc': acc}

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
    ):
        num_trials, max_trials = 0, 100
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                break
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError('sampling reached max_trials {} and still got eos when ignore_eos is True, check your input!'.format(max_trials))
        return top_ids

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        if self.fp16 is True:
            embedding = embedding.half()
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.text_embedding(text)
        # 1. encode text
        text, text_len = self.encode(text, text_len)
        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = F.normalize(embedding, dim=1)
            embedding = self.spk_embed_affine_layer(embedding)
            embedding = embedding.unsqueeze(dim=1)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        # 3. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 4. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 5. step by step decode
        out_tokens = []
        offset = 0
        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                 device=lm_input.device)).to(torch.bool))
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            # force decoding to continue on the first step (never emit eos immediately)
            if i == 0:
                logp[:, self.speech_token_size] = -float('inf')
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
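

# NOTE (illustrative sketch, not part of the upstream file): `sampling` is injected through the
# constructor and invoked by sampling_ids() as sampling(weighted_scores, decoded_tokens, sampling),
# where weighted_scores is a 1-D tensor of log-probabilities and the last argument is a size
# parameter configured by the caller. The hypothetical function below only demonstrates that
# calling convention with a plain top-k draw; the real project wires in its own implementation
# through its configuration.
def example_top_k_sampling(weighted_scores: torch.Tensor, decoded_tokens: List[int], top_k: int) -> torch.Tensor:
    """Draw one token id from the top_k highest-scoring entries of a 1-D score tensor."""
    probs = weighted_scores.softmax(dim=-1)
    top_probs, top_indices = probs.topk(top_k)
    # multinomial renormalizes the selected probabilities, so no explicit normalization is needed
    choice = torch.multinomial(top_probs, num_samples=1)
    # return a 1-D tensor so that `self.speech_token_size not in top_ids` works inside sampling_ids
    return top_indices[choice]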


class Qwen2Encoder(torch.nn.Module):
    def __init__(self, pretrain_path):
        super().__init__()
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)

    def forward_one_step(self, xs, masks, cache=None):
        # masks is a (1, seq_len, seq_len) causal mask; only its last row is kept and passed to the
        # HF model as a 2-D attention_mask covering all positions seen so far (including the cache)
        input_masks = masks[:, -1, :]
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, new_cache


class Qwen2LM(TransformerLM):
    def __init__(
            self,
            llm_input_size: int,
            llm_output_size: int,
            speech_token_size: int,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            mix_ratio: List[int] = [5, 15],
    ):
        torch.nn.Module.__init__(self)
        self.llm_input_size = llm_input_size
        self.llm_output_size = llm_output_size
        self.speech_token_size = speech_token_size
        # 1. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 3,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # 2. build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
        # 3. sampling method
        self.sampling = sampling
        self.mix_ratio = mix_ratio

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len
        text = self.llm.model.model.embed_tokens(text)
        # 1. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
        # 2. cal min/max_length
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
        # 3. step by step decode
        out_tokens = []
        cache = None
        for i in range(max_len):
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
            if top_ids == self.speech_token_size:
                break
            if top_ids > self.speech_token_size:
                continue
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

    @torch.inference_mode()
    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        device = prompt_text.device
        # 1. prepare input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb], dim=1)
        # 2. iterate text
        out_tokens = []
        cache = None
        # NOTE init text_cache with prompt_text, as it is basically impossible that prompt_speech_token/prompt_text < 15/5
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
        for this_text in text:
            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
            # prompt_speech_token_emb not empty, try to append it to lm_input
            while prompt_speech_token_emb.size(1) != 0:
                if text_cache.size(1) >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            # no prompt_speech_token_emb remains, can decode some speech token
            if prompt_speech_token_emb.size(1) == 0:
                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                    logging.info('got fill token, need to append more text token')
                    if text_cache.size(1) >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
                        logging.info('append {} text token'.format(lm_input_text.size(1)))
                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
                            lm_input = lm_input_text
                        else:
                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                        text_cache = text_cache[:, self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                while True:
                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                    y_pred, cache = self.llm.forward_one_step(lm_input,
                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                              cache=cache)
                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.speech_token_size + 2
                        next_fill_index += (self.mix_ratio[1] + 1)
                    else:
                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
                    if top_ids == self.speech_token_size + 2:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.speech_token_size + 2:
                            break
                        else:
                            raise ValueError('should not get token {}'.format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
        # 3. final decode
        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
        logging.info('no more text token, decode until eos is met')
        while True:
            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
            y_pred, cache = self.llm.forward_one_step(lm_input,
                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                      cache=cache)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.speech_token_size:
                    break
                else:
                    raise ValueError('should not get token {}'.format(top_ids))
            # in stream mode, yield token one by one
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
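

# NOTE (illustrative sketch, not part of the upstream file): inference_bistream() builds its prompt by
# interleaving blocks of mix_ratio[0] text tokens with mix_ratio[1] prompt speech tokens, and a fill
# token (speech_token_size + 2) later marks where decoding pauses until more text arrives. The helper
# below, with hypothetical names and no model weights involved, only mimics that block scheduling on
# plain Python lists so the mix_ratio mechanics can be inspected in isolation.
def example_interleave_schedule(text_tokens: List[int], speech_tokens: List[int],
                                text_block: int = 5, speech_block: int = 15) -> List[List[int]]:
    """Group tokens into alternating [text block, speech block] chunks, mirroring the default mix_ratio of [5, 15]."""
    blocks = []
    t, s = 0, 0
    # pair a full text block with a speech block while both streams still have tokens left
    while len(text_tokens) - t >= text_block and s < len(speech_tokens):
        blocks.append(text_tokens[t:t + text_block])
        blocks.append(speech_tokens[s:s + speech_block])
        t += text_block
        s += speech_block
    # whatever text remains is appended in one piece, as in the final decode stage above
    if t < len(text_tokens):
        blocks.append(text_tokens[t:])
    return blocks


if __name__ == '__main__':
    # quick sanity check of the illustrative scheduler: 12 text tokens and 45 speech tokens
    print(example_interleave_schedule(list(range(12)), list(range(100, 145))))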