1
0

llm.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, Optional, Callable, List, Generator
  15. import torch
  16. from torch import nn
  17. import torch.nn.functional as F
  18. from transformers import Qwen2ForCausalLM
  19. from torch.nn.utils.rnn import pad_sequence, unpad_sequence
  20. from cosyvoice.utils.common import IGNORE_ID
  21. from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
  22. from cosyvoice.utils.common import th_accuracy
  23. class TransformerLM(torch.nn.Module):
  24. def __init__(
  25. self,
  26. text_encoder_input_size: int,
  27. llm_input_size: int,
  28. llm_output_size: int,
  29. text_token_size: int,
  30. speech_token_size: int,
  31. text_encoder: torch.nn.Module,
  32. llm: torch.nn.Module,
  33. sampling: Callable,
  34. length_normalized_loss: bool = True,
  35. lsm_weight: float = 0.0,
  36. spk_embed_dim: int = 192,
  37. ):
  38. super().__init__()
  39. self.llm_input_size = llm_input_size
  40. self.speech_token_size = speech_token_size
  41. # 1. build text token inputs related modules
  42. self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
  43. self.text_encoder = text_encoder
  44. self.text_encoder_affine_layer = nn.Linear(
  45. self.text_encoder.output_size(),
  46. llm_input_size
  47. )
  48. # 2. build speech token language model related modules
  49. self.sos_eos = 0
  50. self.task_id = 1
  51. self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
  52. self.llm = llm
  53. self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
  54. self.criterion_ce = LabelSmoothingLoss(
  55. size=speech_token_size + 1,
  56. padding_idx=IGNORE_ID,
  57. smoothing=lsm_weight,
  58. normalize_length=length_normalized_loss,
  59. )
  60. # 3. [Optional] build speech token related modules
  61. self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
  62. self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
  63. # 4. sampling method
  64. self.sampling = sampling
  65. def encode(
  66. self,
  67. text: torch.Tensor,
  68. text_lengths: torch.Tensor,
  69. ):
  70. encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
  71. encoder_out_lens = encoder_mask.squeeze(1).sum(1)
  72. encoder_out = self.text_encoder_affine_layer(encoder_out)
  73. return encoder_out, encoder_out_lens
  74. def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
  75. text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
  76. speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
  77. lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
  78. for i in range(len(text_token))]
  79. lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
  80. lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
  81. return lm_input, lm_input_len
  82. def forward(
  83. self,
  84. batch: dict,
  85. device: torch.device,
  86. ) -> Dict[str, Optional[torch.Tensor]]:
  87. """
  88. Args:
  89. text: (B, L, D)
  90. text_lengths: (B,)
  91. audio: (B, T, N) or (B, T)
  92. audio_lengths: (B,)
  93. """
  94. text_token = batch['text_token'].to(device)
  95. text_token_len = batch['text_token_len'].to(device)
  96. speech_token = batch['speech_token'].to(device)
  97. speech_token_len = batch['speech_token_len'].to(device)
  98. embedding = batch['embedding'].to(device)
  99. # 1. prepare llm_target
  100. lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
  101. [self.speech_token_size]) for i in range(text_token.size(0))]
  102. lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
  103. # 1. encode text_token
  104. text_token = self.text_embedding(text_token)
  105. text_token, text_token_len = self.encode(text_token, text_token_len)
  106. # 2. embedding projection
  107. embedding = F.normalize(embedding, dim=1)
  108. embedding = self.spk_embed_affine_layer(embedding)
  109. embedding = embedding.unsqueeze(1)
  110. # 3. eos and task_id
  111. sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
  112. task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
  113. # 4. encode speech_token
  114. speech_token = self.speech_embedding(speech_token)
  115. # 5. unpad and pad
  116. lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
  117. task_id_emb, speech_token, speech_token_len)
  118. # 6. run lm forward
  119. lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
  120. logits = self.llm_decoder(lm_output)
  121. loss = self.criterion_ce(logits, lm_target)
  122. acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
  123. return {'loss': loss, 'acc': acc}
  124. def sampling_ids(
  125. self,
  126. weighted_scores: torch.Tensor,
  127. decoded_tokens: List,
  128. sampling: int,
  129. ignore_eos: bool = True,
  130. ):
  131. while True:
  132. top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
  133. if (not ignore_eos) or (self.speech_token_size not in top_ids):
  134. break
  135. return top_ids
  136. @torch.inference_mode()
  137. def inference(
  138. self,
  139. text: torch.Tensor,
  140. text_len: torch.Tensor,
  141. prompt_text: torch.Tensor,
  142. prompt_text_len: torch.Tensor,
  143. prompt_speech_token: torch.Tensor,
  144. prompt_speech_token_len: torch.Tensor,
  145. embedding: torch.Tensor,
  146. sampling: int = 25,
  147. max_token_text_ratio: float = 20,
  148. min_token_text_ratio: float = 2,
  149. ) -> Generator[torch.Tensor, None, None]:
  150. device = text.device
  151. text = torch.concat([prompt_text, text], dim=1)
  152. text_len += prompt_text_len
  153. text = self.text_embedding(text)
  154. # 1. encode text
  155. text, text_len = self.encode(text, text_len)
  156. # 2. encode embedding
  157. if embedding.shape[0] != 0:
  158. embedding = F.normalize(embedding, dim=1)
  159. embedding = self.spk_embed_affine_layer(embedding)
  160. embedding = embedding.unsqueeze(dim=1)
  161. else:
  162. embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
  163. # 3. concat llm_input
  164. sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
  165. task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
  166. if prompt_speech_token_len != 0:
  167. prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
  168. else:
  169. prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
  170. lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
  171. # 4. cal min/max_length
  172. min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
  173. max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
  174. # 5. step by step decode
  175. out_tokens = []
  176. offset = 0
  177. att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
  178. for i in range(max_len):
  179. y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
  180. att_cache=att_cache, cnn_cache=cnn_cache,
  181. att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
  182. device=lm_input.device)).to(torch.bool))
  183. logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
  184. # force continue decode first token
  185. if i == 0:
  186. logp[:, self.speech_token_size] = -float('inf')
  187. top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
  188. if top_ids == self.speech_token_size:
  189. break
  190. # in stream mode, yield token one by one
  191. yield top_ids
  192. out_tokens.append(top_ids)
  193. offset += lm_input.size(1)
  194. lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
  195. class Qwen2Encoder(torch.nn.Module):
  196. def __init__(self, pretrain_path):
  197. super().__init__()
  198. self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
  199. def forward_one_step(self, xs, masks, cache=None):
  200. input_masks = masks[:, -1, :]
  201. outs = self.model(
  202. inputs_embeds=xs,
  203. attention_mask=input_masks,
  204. output_hidden_states=True,
  205. return_dict=True,
  206. use_cache=True,
  207. past_key_values=cache,
  208. )
  209. xs = outs.hidden_states[-1]
  210. new_cache = outs.past_key_values
  211. return xs, new_cache
  212. class Qwen2LM(torch.nn.Module):
  213. def __init__(
  214. self,
  215. llm_input_size: int,
  216. llm_output_size: int,
  217. speech_token_size: int,
  218. llm: torch.nn.Module,
  219. sampling: Callable,
  220. length_normalized_loss: bool = True,
  221. lsm_weight: float = 0.0,
  222. ):
  223. super().__init__()
  224. self.llm_input_size = llm_input_size
  225. self.llm_output_size = llm_output_size
  226. self.speech_token_size = speech_token_size
  227. # 2. build speech token language model related modules
  228. self.sos_eos = 0
  229. self.task_id = 1
  230. self.fill_token = 2
  231. self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
  232. self.llm = llm
  233. self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
  234. self.criterion_ce = LabelSmoothingLoss(
  235. size=speech_token_size + 3,
  236. padding_idx=IGNORE_ID,
  237. smoothing=lsm_weight,
  238. normalize_length=length_normalized_loss,
  239. )
  240. # 3. [Optional] build speech token related modules
  241. self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
  242. # 4. sampling method
  243. self.sampling = sampling
  244. def sampling_ids(
  245. self,
  246. weighted_scores: torch.Tensor,
  247. decoded_tokens: List,
  248. sampling: int,
  249. ignore_eos: bool = True,
  250. ):
  251. while True:
  252. top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
  253. if (not ignore_eos) or (self.speech_token_size not in top_ids):
  254. break
  255. return top_ids
  256. @torch.inference_mode()
  257. def inference(
  258. self,
  259. text: torch.Tensor,
  260. text_len: torch.Tensor,
  261. prompt_text: torch.Tensor,
  262. prompt_text_len: torch.Tensor,
  263. prompt_speech_token: torch.Tensor,
  264. prompt_speech_token_len: torch.Tensor,
  265. embedding: torch.Tensor,
  266. sampling: int = 25,
  267. max_token_text_ratio: float = 20,
  268. min_token_text_ratio: float = 2,
  269. ) -> Generator[torch.Tensor, None, None]:
  270. device = text.device
  271. text = torch.concat([prompt_text, text], dim=1)
  272. text_len += prompt_text_len
  273. text = self.llm.model.model.embed_tokens(text)
  274. # 2. encode embedding
  275. embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
  276. # 3. concat llm_input
  277. sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
  278. task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
  279. if prompt_speech_token_len != 0:
  280. prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
  281. else:
  282. prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
  283. lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
  284. # 4. cal min/max_length
  285. min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
  286. max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
  287. # 5. step by step decode
  288. out_tokens = []
  289. cache = None
  290. for i in range(max_len):
  291. y_pred, cache = self.llm.forward_one_step(lm_input,
  292. masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
  293. cache=cache)
  294. logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
  295. top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
  296. if top_ids == self.speech_token_size:
  297. break
  298. if top_ids > self.speech_token_size:
  299. continue
  300. # in stream mode, yield token one by one
  301. yield top_ids
  302. out_tokens.append(top_ids)
  303. lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)