# vllm_use_cosyvoice2_model.py

# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  24. """Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import T
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import (AutoWeightsLoader, maybe_prefix,
                                              merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)

IGNORE_ID = -1


class CosyVoice2Model(nn.Module):
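    """Speech-token language model for CosyVoice2 on a Qwen2 backbone.

    Speech tokens, shifted text tokens, and a few special markers share one
    unified token-id space so vLLM can schedule them as ordinary ids;
    get_input_embeddings() routes each id range to its own embedding table.
    """
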
    # Fused-projection mapping consumed by vLLM's weight loader.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.llm_input_size = 896
        self.llm_output_size = 896
        # 6561 speech codes plus 3 special entries.
        self.speech_token_size = 6561 + 3
        self.llm_token_size = config.vocab_size

        # 2. build speech token language model related modules
        self.sos_eos = 0
        self.task_id = 1
        self.fill_token = 2
        self.allow_patterns_overrides = ["llm.*"]
        self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        # self.llm_decoder = nn.Linear(self.llm_output_size,
        #                              self.speech_token_size)
        self.llm_decoder = ParallelLMHead(self.speech_token_size,
                                          self.llm_output_size,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(
                                              prefix, "llm_decoder"))
        self.logits_processor = LogitsProcessor(self.speech_token_size)
        # length_normalized_loss: bool = True,
        # lsm_weight: float = 0.0,
        # self.criterion_ce = LabelSmoothingLoss(
        #     size=self.speech_token_size,
        #     padding_idx=IGNORE_ID,
        #     smoothing=lsm_weight,
        #     normalize_length=length_normalized_loss,
        # )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(self.speech_token_size,
                                                   self.llm_input_size)

        # 4. sampling method: use vLLM's sampler
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        self.mix_ratio: List[int] = [5, 15]

        # Special-token id constants in the unified id space.
        self.llm_token_id_delta = torch.tensor(self.speech_token_size,
                                               dtype=torch.int32)
        # e.g. 6564 + 163840 + 1 = 170405 when vocab_size == 163840
        self.sos_eos_token_id = (self.llm_token_id_delta +
                                 self.llm_token_size + 1)
        self.task_token_id = self.sos_eos_token_id + 1  # 170406
        self.zero_token_id = self.task_token_id + 1  # 170407
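
        # Resulting unified id layout (derived from the constants above):
        #   [0, speech_token_size)       -> speech tokens
        #   [delta, delta + vocab_size)  -> text tokens (shifted by delta)
        #   delta + vocab_size + 1       -> sos/eos marker
        #   delta + vocab_size + 2       -> task marker
        #   delta + vocab_size + 3       -> empty-audio marker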
        # All-zero embeddings served for the empty-audio marker, sized for
        # the largest possible batch of sequences.
        self.zero_embed_buffer = torch.zeros(
            (vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
            dtype=self.llm_embedding.weight.dtype,
            device=self.llm_embedding.weight.device,
        )
        # Scratch buffer reused by get_input_embeddings(), sized for the
        # largest possible batch of tokens.
        self.inputs_embed_buffer = torch.zeros(
            (vllm_config.scheduler_config.max_num_batched_tokens,
             self.llm_input_size),
            dtype=self.llm_embedding.weight.dtype,
            device=self.llm_embedding.weight.device,
        )

    def get_sos_eos_emb(self):
        return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)

    def get_task_id_emb(self):
        return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[T] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> torch.Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.
        """
        # Mask marking which token ids are speech tokens.
        mask = input_ids < self.speech_token_size
        # Remember the original shape of input_ids.
        input_shape = input_ids.shape
        # Flatten input_ids and the mask so both can be indexed uniformly.
        flat_input_ids = input_ids.view(-1)
        flat_mask = mask.view(-1)
        inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
        inputs_embeds.zero_()

        # Speech tokens: look up the dedicated speech embedding table.
        if flat_mask.any():
            speech_token_ids = flat_input_ids[flat_mask]
            inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)

        # Ids at or above the speech range: text tokens and special markers.
        if (~flat_mask).any():
            llm_token_ids = flat_input_ids[~flat_mask]
            llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])

            sos_eos_mask = llm_token_ids == self.sos_eos_token_id
            task_mask = llm_token_ids == self.task_token_id
            zero_mask = llm_token_ids == self.zero_token_id
            normal_mask = ~(sos_eos_mask | task_mask | zero_mask)

            # 1) SOS/EOS marker.
            if sos_eos_mask.any():
                llm_embeds[sos_eos_mask] = self.llm_embedding.weight[
                    self.sos_eos].unsqueeze(0)
            # 2) Task marker.
            if task_mask.any():
                llm_embeds[task_mask] = self.llm_embedding.weight[
                    self.task_id].unsqueeze(0)
            # 3) Empty-audio marker: an all-zero embedding.
            if zero_mask.any():
                llm_embeds[zero_mask] = self.zero_embed_buffer[:int(
                    zero_mask.sum())]
            # 4) Regular text tokens: shift back into the Qwen2 vocabulary.
            if normal_mask.any():
                original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
                llm_embeds[normal_mask] = self.model.get_input_embeddings(
                    original_ids)

            inputs_embeds[~flat_mask] = llm_embeds

        inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)

        # Merge multimodal embeddings, if any.
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.audio_token_index)
        return inputs_embeds
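
    # Illustrative helper added for clarity (not in the original file): a
    # minimal sketch of how a caller might compose prompt ids in the unified
    # id space above, assuming the CosyVoice2 LM ordering of
    # [sos/eos, text, task, prompt speech]. The real pipeline's prompt
    # construction may differ.
    def build_prompt_token_ids(
        self,
        text_token_ids: List[int],
        prompt_speech_ids: List[int],
    ) -> List[int]:
        delta = int(self.llm_token_id_delta)
        ids = [int(self.sos_eos_token_id)]  # leading sos/eos marker
        ids += [t + delta for t in text_token_ids]  # shift text ids by delta
        ids.append(int(self.task_token_id))  # task marker
        ids += list(prompt_speech_ids)  # speech ids pass through unshifted
        return ids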

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                attn_metadata=attn_metadata,
            )
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Project onto the speech vocabulary only; generation never emits
        # text ids.
        logits = self.logits_processor(self.llm_decoder, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    @staticmethod
    def convert_weights(
        weights: Iterable[Tuple[str, torch.Tensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        for name, param in weights:
            # Remap Qwen2 backbone weights and skip all other "llm."-prefixed
            # entries; names without that prefix (llm_decoder, llm_embedding,
            # speech_embedding, ...) pass through unchanged.
            if name.startswith("llm."):
                if name.startswith("llm.model.model."):
                    name = name.replace("llm.model.model.", "model.")
                else:
                    continue
            yield name, param
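
    # For example (assuming the CosyVoice2 llm checkpoint layout):
    #   "llm.model.model.layers.0.self_attn.q_proj.weight"
    #       -> "model.layers.0.self_attn.q_proj.weight"
    #   "llm.model.lm_head.weight"   -> skipped
    #   "speech_embedding.weight"    -> unchanged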

    def load_weights(self,
                     weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        weights = self.convert_weights(weights)
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
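

# Minimal out-of-tree registration sketch (an assumption about how this module
# is wired into vLLM, not shown in the original source): vLLM resolves
# architecture names through its ModelRegistry, so a plugin can map the
# checkpoint's architecture string to this class.
def register():
    from vllm import ModelRegistry
    ModelRegistry.register_model("CosyVoice2Model", CosyVoice2Model)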