@@ -0,0 +1,263 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only CosyVoice2 speech-token model (Qwen2-based) compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from vllm.attention import AttentionMetadata
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from vllm.model_executor.models.interfaces import T
+from vllm.model_executor.models.qwen2 import Qwen2Model
+
+from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
+
+logger = init_logger(__name__)
+
+IGNORE_ID = -1
+
+
+class CosyVoice2Model(nn.Module):
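+    """Inference-only speech-token language model used by CosyVoice2.
+
+    A Qwen2 backbone consumes a unified id space of speech tokens, shifted
+    text tokens and a few special markers, and an LM head over the speech
+    token vocabulary (``llm_decoder``) predicts the next speech token.
+    """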
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        self.lora_config = lora_config
+        self.quant_config = quant_config
+
+        self.llm_input_size = 896
+        self.llm_output_size = 896
+
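+        # 6561 speech tokens plus 3 reserved special ids (presumably
+        # end-of-speech / fill markers in the original CosyVoice2 checkpoint).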
+        self.speech_token_size = 6561 + 3
+        self.llm_token_size = config.vocab_size
+
+        # 2. build speech token language model related modules
+        self.sos_eos = 0
+        self.task_id = 1
+        self.fill_token = 2
+
+        self.allow_patterns_overrides = ["llm.*"]
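+        # Two learned special-token embeddings: index 0 is sos/eos
+        # (self.sos_eos) and index 1 is the task token (self.task_id).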
+        self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
+        self.model = Qwen2Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+
+        # self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size)
+        self.llm_decoder = ParallelLMHead(self.speech_token_size,
+                                          self.llm_output_size,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "llm_decoder"))
+        self.logits_processor = LogitsProcessor(self.speech_token_size)
+
+        # Training-only criterion from the original CosyVoice2 code, kept for
+        # reference (unused at inference time):
+        # length_normalized_loss: bool = True,
+        # lsm_weight: float = 0.0,
+        # self.criterion_ce = LabelSmoothingLoss(
+        #     size=self.speech_token_size,
+        #     padding_idx=IGNORE_ID,
+        #     smoothing=lsm_weight,
+        #     normalize_length=length_normalized_loss,
+        # )
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size)
+
+        # 4. sampling method
+        # use vLLM's sampler
+        self.sampler = get_sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
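+        # [text, speech] interleaving ratio used by CosyVoice2's streaming
+        # ("bistream") mode in the upstream implementation.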
+        self.mix_ratio: List[int] = [5, 15]
+
+        # Special token-id constants. The unified id space is laid out as:
+        # speech tokens occupy [0, speech_token_size); text (LLM) token ids
+        # are shifted up by llm_token_id_delta; the sos/eos, task and zero
+        # placeholder ids sit above the shifted text vocabulary.
+        self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32)
+        self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32)  # 163840 + 6564 = 170404
+        self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32)  # 170405
+        self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32)
+
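+        # Buffers preallocated to the scheduler limits and reused by
+        # get_input_embeddings to avoid per-step allocations.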
+        self.zero_embed_buffer = torch.zeros(
+            (vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
+            dtype=self.llm_embedding.weight.dtype,
+            device=self.llm_embedding.weight.device,
+        )
+        self.inputs_embed_buffer = torch.zeros(
+            (vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size),
+            dtype=self.llm_embedding.weight.dtype,
+            device=self.llm_embedding.weight.device,
+        )
+
+    def get_sos_eos_emb(self):
+        return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+
+    def get_task_id_emb(self):
+        return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[T] = None,
+        attn_metadata: Optional["AttentionMetadata"] = None,
+    ) -> torch.Tensor:
+        """
+        Return the input embeddings obtained by merging the text embeddings
+        from input_ids with the multimodal embeddings generated from
+        multimodal kwargs.
+        """
+        # Mask marking which token ids are speech (audio) tokens.
+        mask = input_ids < self.speech_token_size
+
+        # Remember the original shape of input_ids.
+        input_shape = input_ids.shape
+        # Flatten input_ids and the mask for uniform processing.
+        flat_input_ids = input_ids.view(-1)
+        flat_mask = mask.view(-1)
+
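+        # Take a view of the preallocated buffer sized to the current batch
+        # of flattened tokens and clear it before filling it below.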
+        inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
+        inputs_embeds.zero_()
+
+        # Process speech tokens
+        if flat_mask.any():
+            speech_token_ids = flat_input_ids[flat_mask]
+            inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)
+
+        # Process token ids above the delta (text / special LLM tokens).
+        if (~flat_mask).any():
+            llm_token_ids = flat_input_ids[~flat_mask]
+            llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])
+
+            sos_eos_mask = llm_token_ids == self.sos_eos_token_id
+            task_mask = llm_token_ids == self.task_token_id
+            zero_mask = llm_token_ids == self.zero_token_id
+            normal_mask = ~(sos_eos_mask | task_mask | zero_mask)
+
+            # Handle each special-token class in turn.
+            # 1. SOS/EOS tokens
+            if sos_eos_mask.any():
+                llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0)
+
+            # 2. Task tokens
+            if task_mask.any():
+                llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0)
+
+            # 3. Zero (empty audio) placeholder tokens
+            if zero_mask.any():
+                llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])]
+
+            # 4. Regular LLM (text) tokens: shift back to the original vocab ids
+            if normal_mask.any():
+                original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
+                # print('original_ids: ', original_ids)
+                llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids)
+
+            inputs_embeds[~flat_mask] = llm_embeds
+
+        inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)
+
+        # Merge multimodal embeddings, if any.
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.config.audio_token_index
+            )
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                attn_metadata=attn_metadata,
+            )
+        return self.model(input_ids, positions, kv_caches,
+                          attn_metadata, intermediate_tensors,
+                          inputs_embeds)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.llm_decoder, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    @staticmethod
+    def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]:
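+        """Remap checkpoint weight names for AutoWeightsLoader.
+
+        Weights under the "llm.model.model." prefix are renamed to "model."
+        so they load into the Qwen2 backbone; weights outside the "llm."
+        prefix are skipped.
+        """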
+        for name, param in weights:
+            # Keep only the language-model ("llm.") weights and map the
+            # Qwen2 backbone parameters onto self.model.
+            if name.startswith("llm."):
+                if name.startswith("llm.model.model."):
+                    name = name.replace("llm.model.model.", "model.")
+            else:
+                continue
+            # print('weights name: ', name)
+            yield name, param
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        weights = self.convert_weights(weights)
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)