
add vllm export

root, 11 months ago
commit 9f55c5af8f

cosyvoice/cli/cosyvoice.py | +2 -0

@@ -166,6 +166,8 @@ class CosyVoice2(CosyVoice):
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        if load_vllm:
+            self.model.load_vllm('{}/vllm'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
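
A minimal usage sketch, assuming load_vllm is exposed as a CosyVoice2 constructor flag alongside the existing load_jit/load_trt flags (the constructor signature itself is not part of this hunk) and that the exported weights are written to <model_dir>/vllm:

    from cosyvoice.cli.cosyvoice import CosyVoice2

    # load_vllm=True is an assumed keyword argument, mirroring load_jit/load_trt;
    # the model directory below is the usual CosyVoice2 checkpoint layout
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                           load_jit=False, load_trt=False, load_vllm=True)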

cosyvoice/cli/model.py | +1 -1

@@ -23,7 +23,7 @@ from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
 from cosyvoice.utils.common import TrtContextWrapper
 
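The hunk above only brings export_cosyvoice2_vllm into scope; the load_vllm method that cosyvoice.py now calls is not shown in this commit excerpt. A rough sketch of how the pieces could fit together, assuming load_vllm exports the checkpoint on first use and then hands the exported directory to a vLLM engine (the vLLM call below is illustrative, not taken from this commit):

    # hypothetical sketch of a load_vllm method on CosyVoice2Model, not the committed code
    def load_vllm(self, model_dir):
        # write a HuggingFace-style checkpoint that vLLM can read (no-op if it already exists)
        export_cosyvoice2_vllm(self.llm, model_dir, self.device)
        from vllm import LLM  # requires vllm to be installed
        self.vllm_engine = LLM(model=model_dir, dtype='bfloat16')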
 

cosyvoice/llm/vllm_use_cosyvoice2_model.py | +0 -263

@@ -1,263 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-# Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
-# Copyright 2024 The Qwen team.
-# Copyright 2023 The vLLM team.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any
-from typing_extensions import TypeVar
-
-import torch
-from torch import nn
-
-from vllm.attention import AttentionMetadata
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
-
-from vllm.model_executor.models.interfaces import T
-from vllm.model_executor.models.qwen2 import Qwen2Model
-
-from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
-
-logger = init_logger(__name__)
-
-IGNORE_ID = -1
-
-
-class CosyVoice2Model(nn.Module):
-
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        lora_config = vllm_config.lora_config
-
-        self.config = config
-        self.lora_config = lora_config
-        self.quant_config = quant_config
-
-        self.llm_input_size = 896
-        self.llm_output_size = 896
-
-        self.speech_token_size = 6561+3
-        self.llm_token_size = config.vocab_size
-
-        # 2. build speech token language model related modules
-        self.sos_eos = 0
-        self.task_id = 1
-        self.fill_token = 2
-
-
-        self.allow_patterns_overrides = ["llm.*"]
-        self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
-        self.model = Qwen2Model(vllm_config=vllm_config,
-                              prefix=maybe_prefix(prefix, "model"))
-
-        # self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size)
-        self.llm_decoder = ParallelLMHead(self.speech_token_size,
-                                      self.llm_output_size,
-                                      bias=True,
-                                      quant_config=quant_config,
-                                      prefix=maybe_prefix(
-                                          prefix, "llm_decoder"))
-        self.logits_processor = LogitsProcessor(self.speech_token_size)
-
-        # length_normalized_loss: bool = True,
-        # lsm_weight: float = 0.0,
-        # self.criterion_ce = LabelSmoothingLoss(
-        #     size=self.speech_token_size,
-        #     padding_idx=IGNORE_ID,
-        #     smoothing=lsm_weight,
-        #     normalize_length=length_normalized_loss,
-        # )
-
-        # 3. [Optional] build speech token related modules
-        self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size)
-
-        # 4. sampling method
-        ## use vllm sampling method
-        self.sampler = get_sampler()
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-        self.mix_ratio: List[int] = [5, 15]
-
-        # define special-token id constants
-        self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32)
-        self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32)  # 163840 + 6564 = 170404
-        self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32)  # 170405
-        self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32)
-
-        self.zero_embed_buffer = torch.zeros(
-            (vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
-            dtype=self.llm_embedding.weight.dtype,
-            device=self.llm_embedding.weight.device
-        )
-        self.inputs_embed_buffer = torch.zeros(
-            (vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size),
-            dtype=self.llm_embedding.weight.dtype,
-            device=self.llm_embedding.weight.device,
-        )
-
-    def get_sos_eos_emb(self):
-        return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
-
-    def get_task_id_emb(self):
-        return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
-
-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[T] = None,
-        attn_metadata: Optional["AttentionMetadata"] = None,
-    ) -> torch.Tensor:
-        """
-        Returns the input embeddings merged from the text embeddings from
-        input_ids and the multimodal embeddings generated from multimodal
-        kwargs.
-        """
-        # build a mask marking which token_ids are speech (audio) tokens
-        mask = input_ids < self.speech_token_size
-
-        # keep the original shape of input_ids
-        input_shape = input_ids.shape
-        # flatten input_ids and the mask so both can be handled uniformly
-        flat_input_ids = input_ids.view(-1)
-        flat_mask = mask.view(-1)
-
-        inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
-        inputs_embeds.zero_()
-
-        # Process speech tokens
-        if flat_mask.any():
-            speech_token_ids = flat_input_ids[flat_mask]
-            inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)
-
-        # handle token_ids at or above the delta (text and special tokens)
-        if (~flat_mask).any():
-            llm_token_ids = flat_input_ids[~flat_mask]
-            llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])
-
-            sos_eos_mask = llm_token_ids == self.sos_eos_token_id
-            task_mask = llm_token_ids == self.task_token_id
-            zero_mask = llm_token_ids == self.zero_token_id
-            normal_mask = ~(sos_eos_mask | task_mask | zero_mask)
-
-            # tiered handling
-            # first priority: SOS/EOS token
-            if sos_eos_mask.any():
-                llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0)
-
-            # second priority: task token
-            if task_mask.any():
-                llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0)
-
-            # third priority: zero (blank audio) token
-            if zero_mask.any():
-                llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])]
-
-            # regular LLM text tokens
-            if normal_mask.any():
-                original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
-                # print('original_ids: ',original_ids)
-                llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids)
-
-            inputs_embeds[~flat_mask] = llm_embeds
-
-        inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)
-
-        # merge multimodal embeddings (if any)
-        if multimodal_embeddings is not None:
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.config.audio_token_index
-            )
-        return inputs_embeds
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        if inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                attn_metadata=attn_metadata,
-            )
-        return self.model(input_ids, positions, kv_caches,
-                        attn_metadata, intermediate_tensors,
-                        inputs_embeds)
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.llm_decoder, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    @staticmethod
-    def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]:
-        for name, param in weights:
-            # handle the core Qwen2Model parameters
-            if name.startswith("llm."):
-                if name.startswith("llm.model.model."):
-                    name = name.replace("llm.model.model.", "model.")
-                else:
-                    continue
-            # print('weights name: ', name)
-            yield name, param
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights = self.convert_weights(weights)
-        loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)

cosyvoice/utils/file_utils.py | +43 -1

@@ -1,5 +1,6 @@
 # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
 #               2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
+#               2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import json
-import torchaudio
+import torch, torchaudio
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 logging.basicConfig(level=logging.DEBUG,
@@ -83,3 +85,43 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
     with open(trt_model, "wb") as f:
         f.write(engine_bytes)
     logging.info("Succesfully convert onnx to trt...")
+
+
+def export_cosyvoice2_vllm(model, model_path, device):
+    if os.path.exists(model_path):
+        return
+    pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
+    vocab_size = model.speech_embedding.num_embeddings
+    feature_size = model.speech_embedding.embedding_dim
+    pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
+
+    dtype = torch.bfloat16
+    # lm_head
+    new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
+    with torch.no_grad():
+        new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
+        new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
+        new_lm_head.weight[vocab_size:] = 0
+        new_lm_head.bias[vocab_size:] = 0
+    model.llm.model.lm_head = new_lm_head
+    new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+    # embed_tokens
+    embed_tokens = model.llm.model.model.embed_tokens
+    with torch.no_grad():
+        new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
+        new_codec_embed.weight[vocab_size:] = 0
+    model.llm.model.set_input_embeddings(new_codec_embed)
+    model.llm.model.to(device)
+    model.llm.model.to(dtype)
+    tmp_vocab_size = model.llm.model.config.vocab_size
+    tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
+    del model.llm.model.generation_config.eos_token_id
+    del model.llm.model.config.bos_token_id
+    del model.llm.model.config.eos_token_id
+    model.llm.model.config.vocab_size = pad_vocab_size
+    model.llm.model.config.tie_word_embeddings = False
+    model.llm.model.config.use_bias = True
+    model.llm.model.save_pretrained(model_path)
+    model.llm.model.config.vocab_size = tmp_vocab_size
+    model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
+    model.llm.model.set_input_embeddings(embed_tokens)
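
The exporter copies the speech-token embedding (speech_embedding) and the output head (llm_decoder) into zero-padded embed_tokens and lm_head layers, casts the model to bfloat16, drops the bos/eos token ids and temporarily overrides vocab_size, tie_word_embeddings and use_bias in the HuggingFace config, saves everything with save_pretrained, and finally restores the original input embeddings and the vocab_size / tie_word_embeddings fields. With the 6561 + 3 = 6564-entry speech vocabulary used elsewhere in this commit and pad_to = 64, the padded size is ((6564 + 63) // 64) * 64 = 6592, i.e. 28 rows of zero padding. A minimal invocation sketch, assuming cosyvoice.model.llm is the Qwen2-based LLM wrapper whose nested llm.model attribute is a HuggingFace causal LM (as the attribute accesses above imply):

    import torch
    from cosyvoice.cli.cosyvoice import CosyVoice2
    from cosyvoice.utils.file_utils import export_cosyvoice2_vllm

    model_dir = 'pretrained_models/CosyVoice2-0.5B'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cosyvoice = CosyVoice2(model_dir)
    # in the commit itself this export is expected to run inside the model's
    # load_vllm; calling it directly here is purely illustrative
    export_cosyvoice2_vllm(cosyvoice.model.llm, '{}/vllm'.format(model_dir), device)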

requirements.txt | +1 -1

@@ -1,7 +1,7 @@
 --extra-index-url https://download.pytorch.org/whl/cu121
 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
 conformer==0.3.2
-deepspeed==0.14.2; sys_platform == 'linux'
+deepspeed==0.15.1; sys_platform == 'linux'
 diffusers==0.29.0
 gdown==5.1.0
 gradio==5.4.0