
add flow trt wrapper

lyuxiang.lx · 7 months ago
Parent
Current commit
a442317d17

+ 2 - 2
cosyvoice/cli/cosyvoice.py

@@ -137,7 +137,7 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False, trt_concurrent=1):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -159,7 +159,7 @@ class CosyVoice2(CosyVoice):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache, trt_concurrent)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))

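With this wrapper in place, a caller can request several TensorRT execution contexts up front. A minimal usage sketch, assuming a locally downloaded CosyVoice2 model directory (the path below is a placeholder):

    from cosyvoice.cli.cosyvoice import CosyVoice2

    # Two TRT contexts so concurrent synthesis requests do not serialize
    # on a single flow-decoder execution context.
    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B',
                           load_trt=True, fp16=True, trt_concurrent=2)
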
+ 26 - 7
cosyvoice/cli/model.py

@@ -1,4 +1,5 @@
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +14,7 @@
 # limitations under the License.
 import os
 from typing import Generator
+import queue
 import torch
 import numpy as np
 import threading
@@ -22,6 +24,7 @@ from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
 from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.utils.common import TrtContextWrapper
 
 
 class CosyVoiceModel:
@@ -89,9 +92,12 @@ class CosyVoiceModel:
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
-            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
+        if isinstance(self, CosyVoice2Model):
+            self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
+        else:
+            self.flow.decoder.estimator = estimator_engine.create_execution_context()
 
     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
@@ -231,7 +237,9 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.current_stream().synchronize()
 
 
 class CosyVoice2Model(CosyVoiceModel):
@@ -241,13 +249,15 @@ class CosyVoice2Model(CosyVoiceModel):
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
                  fp16: bool = False,
-                 use_flow_cache: bool = False):
+                 use_flow_cache: bool = False,
+                 trt_concurrent: int = 1):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
         self.use_flow_cache = use_flow_cache
+        self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
@@ -261,12 +271,16 @@ class CosyVoice2Model(CosyVoiceModel):
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
+        for _ in range(trt_concurrent):
+            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
+        self.trt_context_dict = {}
 
     def init_flow_cache(self):
         encoder_cache = {'offset': 0,
@@ -304,7 +318,7 @@ class CosyVoice2Model(CosyVoiceModel):
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        with torch.cuda.amp.autocast(self.fp16):
+        with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
             tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                       prompt_token=prompt_token.to(self.device),
@@ -349,6 +363,7 @@ class CosyVoice2Model(CosyVoiceModel):
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
             self.flow_cache_dict[this_uuid] = self.init_flow_cache()
+            self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -405,4 +420,8 @@ class CosyVoice2Model(CosyVoiceModel):
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-        torch.cuda.empty_cache()
+            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
+            self.trt_context_dict.pop(this_uuid)
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.current_stream().synchronize()

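The model now keeps a bounded pool of CUDA stream contexts: a session checks one out when its uuid is registered and returns it when the session is cleaned up, so up to trt_concurrent flow decodes can run on separate streams. A minimal sketch of that check-out/return pattern (a standalone illustration, not the exact CosyVoice code):

    import queue
    from contextlib import nullcontext
    import torch

    trt_concurrent = 2
    pool = queue.Queue(maxsize=trt_concurrent)
    for _ in range(trt_concurrent):
        # reusable stream context managers, with a no-op fallback on CPU
        pool.put(torch.cuda.stream(torch.cuda.Stream()) if torch.cuda.is_available() else nullcontext())

    ctx = pool.get()      # acquired when a synthesis session starts
    with ctx:
        pass              # flow/TRT work for this session runs on the acquired stream
    pool.put(ctx)         # returned when the session ends
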
+ 52 - 46
cosyvoice/flow/flow_matching.py

@@ -1,4 +1,5 @@
 # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -290,50 +291,55 @@ class CausalConditionalCFM(ConditionalCFM):
             x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
             cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
         else:
-            with self.lock:
-                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
-                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
-                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
-                self.estimator.set_input_shape('t', (2,))
-                self.estimator.set_input_shape('spks', (2, 80))
-                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
-                self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
-                self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
-                self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
-                self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
-                self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
-                self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
-                self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
-                # run trt engine
-                down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
-                mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
-                up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
-                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
-                                                  mask.contiguous().data_ptr(),
-                                                  mu.contiguous().data_ptr(),
-                                                  t.contiguous().data_ptr(),
-                                                  spks.contiguous().data_ptr(),
-                                                  cond.contiguous().data_ptr(),
-                                                  cache['down_blocks_conv_cache'].contiguous().data_ptr(),
-                                                  cache['down_blocks_kv_cache'].contiguous().data_ptr(),
-                                                  cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
-                                                  cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
-                                                  cache['up_blocks_conv_cache'].contiguous().data_ptr(),
-                                                  cache['up_blocks_kv_cache'].contiguous().data_ptr(),
-                                                  cache['final_blocks_conv_cache'].contiguous().data_ptr(),
-                                                  x.data_ptr(),
-                                                  cache['down_blocks_conv_cache'].data_ptr(),
-                                                  down_blocks_kv_cache_out.data_ptr(),
-                                                  cache['mid_blocks_conv_cache'].data_ptr(),
-                                                  mid_blocks_kv_cache_out.data_ptr(),
-                                                  cache['up_blocks_conv_cache'].data_ptr(),
-                                                  up_blocks_kv_cache_out.data_ptr(),
-                                                  cache['final_blocks_conv_cache'].data_ptr()]) is True
-                cache = (cache['down_blocks_conv_cache'],
-                         down_blocks_kv_cache_out,
-                         cache['mid_blocks_conv_cache'],
-                         mid_blocks_kv_cache_out,
-                         cache['up_blocks_conv_cache'],
-                         up_blocks_kv_cache_out,
-                         cache['final_blocks_conv_cache'])
+            estimator, trt_engine = self.estimator.acquire_estimator()
+            estimator.set_input_shape('x', (2, 80, x.size(2)))
+            estimator.set_input_shape('mask', (2, 1, x.size(2)))
+            estimator.set_input_shape('mu', (2, 80, x.size(2)))
+            estimator.set_input_shape('t', (2,))
+            estimator.set_input_shape('spks', (2, 80))
+            estimator.set_input_shape('cond', (2, 80, x.size(2)))
+            estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
+            estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
+            estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
+            estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
+            estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
+            estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
+            estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
+            down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
+            mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
+            up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
+            data_ptrs = [x.contiguous().data_ptr(),
+                         mask.contiguous().data_ptr(),
+                         mu.contiguous().data_ptr(),
+                         t.contiguous().data_ptr(),
+                         spks.contiguous().data_ptr(),
+                         cond.contiguous().data_ptr(),
+                         cache['down_blocks_conv_cache'].contiguous().data_ptr(),
+                         cache['down_blocks_kv_cache'].contiguous().data_ptr(),
+                         cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
+                         cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
+                         cache['up_blocks_conv_cache'].contiguous().data_ptr(),
+                         cache['up_blocks_kv_cache'].contiguous().data_ptr(),
+                         cache['final_blocks_conv_cache'].contiguous().data_ptr(),
+                         x.data_ptr(),
+                         cache['down_blocks_conv_cache'].data_ptr(),
+                         down_blocks_kv_cache_out.data_ptr(),
+                         cache['mid_blocks_conv_cache'].data_ptr(),
+                         mid_blocks_kv_cache_out.data_ptr(),
+                         cache['up_blocks_conv_cache'].data_ptr(),
+                         up_blocks_kv_cache_out.data_ptr(),
+                         cache['final_blocks_conv_cache'].data_ptr()]
+            for i, j in enumerate(data_ptrs):
+                estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
+            # run trt engine
+            assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+            torch.cuda.current_stream().synchronize()
+            self.estimator.release_estimator(estimator)
+            cache = (cache['down_blocks_conv_cache'],
+                     down_blocks_kv_cache_out,
+                     cache['mid_blocks_conv_cache'],
+                     mid_blocks_kv_cache_out,
+                     cache['up_blocks_conv_cache'],
+                     up_blocks_kv_cache_out,
+                     cache['final_blocks_conv_cache'])
         return x, cache

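The flat bindings list passed to execute_v2 is replaced by the name-based tensor API paired with execute_async_v3, so each pooled context enqueues work on the caller's current stream. The general shape of that call sequence, shown standalone (estimator, trt_engine and data_ptrs are assumed to exist as in the diff above):

    # bind every I/O tensor by name, in the order the engine declares them
    for i, ptr in enumerate(data_ptrs):
        estimator.set_tensor_address(trt_engine.get_tensor_name(i), ptr)
    # enqueue on the current torch stream and wait before reading the outputs
    assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
    torch.cuda.current_stream().synchronize()
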
+ 212 - 0
cosyvoice/llm/llm_vllm.py

@@ -0,0 +1,212 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+import queue
+import asyncio
+import threading
+from typing import List, Generator, AsyncGenerator
+import torch
+from cosyvoice.utils.file_utils import logging
+from cosyvoice.llm.llm import Qwen2LM
+
+# Enable the vLLM V1 engine
+import os
+os.environ["VLLM_USE_V1"] = '1'
+from vllm import ModelRegistry
+from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
+from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
+from vllm.sampling_params import SamplingParams
+
+from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
+ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
+
+# EngineArgs
+ENGINE_ARGS = {
+    "block_size": 16,
+    "swap_space": 0,
+    # "enforce_eager": True,
+    "gpu_memory_utilization": 0.4,
+    "max_num_batched_tokens": 1024,
+    "max_model_len": 1024,
+    "max_num_seqs": 256,
+    "disable_log_requests": True,
+    "disable_log_stats": True,
+    "dtype": "float16"
+}
+
+from vllm.sampling_params import RequestOutputKind
+# SamplingParams
+SAMPLING_PARAMS = {
+    "temperature": 1,  # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token
+    "top_p": 1,       # 不能低于0.8, 否则会生成非常多的空音频,或者无法正常生成语音Token
+    "top_k": 25,
+    # "min_tokens": 80,       # 不支持设置最小的tokens数量设置,开启后vllm直接崩溃,无法启动
+    # "presence_penalty": 1.0,    # 不支持设置
+    # "frequency_penalty": 0.0,   # 不支持设置
+    "max_tokens": 1024,
+    "detokenize": False,          # 目前 vllm 0.7.3 v1版本中设置无效,待后续版本更新后减少计算
+    "ignore_eos": False,
+    "output_kind": RequestOutputKind.DELTA  # 设置为DELTA,如调整该参数,请同时调整llm_inference的处理代码
+}
+
+def tensor_to_list(tensor: torch.Tensor):
+    return tensor.view(-1).cpu().numpy().tolist()
+
+class VllmQwen2LM(Qwen2LM):
+    def __init__(
+            self,
+            model_dir,
+            mix_ratio: List[int] = [5, 15],
+    ):
+        self.fp16 = False
+        self.half = lambda: None
+        self.mix_ratio = mix_ratio
+        # ---------------------------------------------
+        # vllm engine configuration
+        engine_args = AsyncEngineArgs(
+            model=model_dir,
+            **ENGINE_ARGS,
+        )
+        self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
+
+        self.speech_token_size = 6564       # 6561 + 3
+        self.llm_token_size = 151936        # llm  vocab_size
+        self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
+        self.task_token_id = self.sos_eos_token_id + 1
+        self.zero_token_id = self.task_token_id + 1
+
+        # vllm inference must run on a single fixed event loop, so start a dedicated background thread for it
+        self.loop = asyncio.new_event_loop()
+        self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
+        self.loop_thread.start()
+
+    def _run_event_loop(self):
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
+
+    async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens):
+        sampling_params = SamplingParams(**SAMPLING_PARAMS)
+        sampling_params.stop_token_ids = stop_token_ids or [6561]
+        if max_tokens:
+            sampling_params.max_tokens = max_tokens
+        async for output in self.llm_engine.generate(
+                {
+                    "prompt_token_ids": prompt_token_ids,
+                },
+                sampling_params=sampling_params,
+                request_id=request_id or f"{time.time()}",
+        ):
+            out_queue.put((output.outputs[0], output.finished))
+
+    def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None):
+        out_queue = queue.Queue()
+        asyncio.run_coroutine_threadsafe(
+            self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop
+        )
+        # consume the results returned through out_queue
+        finished = False
+        while not finished:
+            (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
+            yield output
+
+    def inference(
+            self,
+            text: torch.Tensor,
+            text_len: torch.Tensor,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor|int, None, None]:
+        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
+        prompt_speech_token = tensor_to_list(prompt_speech_token)
+
+        text = tensor_to_list(text + torch.tensor(6564))
+        prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
+                           [self.task_token_id] + prompt_speech_token
+        max_tokens = len(text) * 20
+        for output in self.llm_inference(
+                prompt_token_ids,
+                stop_token_ids=[6561],
+                max_tokens=max_tokens,
+        ):
+            if output.token_ids[-1] == 6561:
+                need_add_tokens = output.token_ids[:-1]
+            else:
+                need_add_tokens = output.token_ids
+            for token in need_add_tokens:
+                yield token
+
+    def inference_bistream(
+            self,
+            text: Generator,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+    ) -> Generator[torch.Tensor, None, None]:
+        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
+        prompt_speech_token = tensor_to_list(prompt_speech_token)
+
+        last_tokens = []
+        prompt_token_ids = [self.sos_eos_token_id]
+        text_tokens_cache = prompt_text
+        for this_text in text:
+            this_text = tensor_to_list(this_text + torch.tensor(6564))
+            # text need tokens
+            assert isinstance(this_text, list), "text need token ids List[int]."
+            text_tokens_cache += this_text
+            while len(prompt_speech_token) != 0:
+                if len(text_tokens_cache) >= self.mix_ratio[0]:
+                    text_input_token = text_tokens_cache[:self.mix_ratio[0]]
+                    speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
+                    prompt_token_ids += text_input_token + speech_input_token
+                    # reset the last cache
+                    text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
+                    prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
+                else:
+                    break
+            if len(prompt_speech_token) == 0:
+                if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
+                    if len(text_tokens_cache) >= self.mix_ratio[0]:
+                        text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
+                        prompt_token_ids += text_tokens_temp
+                        text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
+                    else:
+                        continue
+                for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
+                    last_tokens = output.token_ids
+                    if last_tokens[-1] == 6563:
+                        need_add_tokens = last_tokens[:-1]
+                    else:
+                        need_add_tokens = last_tokens
+                    for token in need_add_tokens:
+                        yield token
+                    prompt_token_ids.extend(need_add_tokens)
+        prompt_token_ids += text_tokens_cache + [self.task_token_id]
+        for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
+            if output.token_ids[-1] == 6561:
+                need_add_tokens = output.token_ids[:-1]
+            else:
+                need_add_tokens = output.token_ids
+            for token in need_add_tokens:
+                yield token

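llm_inference bridges the AsyncLLMEngine, which lives on a background event loop, back to a synchronous generator through a queue. The same pattern in isolation, with a dummy async producer standing in for vllm (all names below are illustrative):

    import asyncio, queue, threading

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def produce(out):
        for i in range(3):           # stands in for "async for output in engine.generate(...)"
            await asyncio.sleep(0)
            out.put((i, i == 2))     # (payload, finished)

    def consume():
        out = queue.Queue()
        asyncio.run_coroutine_threadsafe(produce(out), loop)
        finished = False
        while not finished:
            payload, finished = out.get()
            yield payload

    print(list(consume()))           # [0, 1, 2]
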
+ 263 - 0
cosyvoice/llm/vllm_use_cosyvoice2_model.py

@@ -0,0 +1,263 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2 model compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Set, Tuple, Union, Iterator, overload, TypedDict, Mapping, Any
+from typing_extensions import TypeVar
+
+import torch
+from torch import nn
+
+from vllm.attention import AttentionMetadata
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from vllm.model_executor.models.interfaces import T
+from vllm.model_executor.models.qwen2 import Qwen2Model
+
+from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
+
+logger = init_logger(__name__)
+
+IGNORE_ID = -1
+
+
+class CosyVoice2Model(nn.Module):
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        self.lora_config = lora_config
+        self.quant_config = quant_config
+
+        self.llm_input_size = 896
+        self.llm_output_size = 896
+
+        self.speech_token_size = 6561+3
+        self.llm_token_size = config.vocab_size
+
+        # 2. build speech token language model related modules
+        self.sos_eos = 0
+        self.task_id = 1
+        self.fill_token = 2
+
+
+        self.allow_patterns_overrides = ["llm.*"]
+        self.llm_embedding = torch.nn.Embedding(2, self.llm_input_size)
+        self.model = Qwen2Model(vllm_config=vllm_config,
+                              prefix=maybe_prefix(prefix, "model"))
+
+        # self.llm_decoder = nn.Linear(self.llm_output_size, self.speech_token_size)
+        self.llm_decoder = ParallelLMHead(self.speech_token_size,
+                                      self.llm_output_size,
+                                      bias=True,
+                                      quant_config=quant_config,
+                                      prefix=maybe_prefix(
+                                          prefix, "llm_decoder"))
+        self.logits_processor = LogitsProcessor(self.speech_token_size)
+
+        # length_normalized_loss: bool = True,
+        # lsm_weight: float = 0.0,
+        # self.criterion_ce = LabelSmoothingLoss(
+        #     size=self.speech_token_size,
+        #     padding_idx=IGNORE_ID,
+        #     smoothing=lsm_weight,
+        #     normalize_length=length_normalized_loss,
+        # )
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = torch.nn.Embedding(self.speech_token_size, self.llm_input_size)
+
+        # 4. sampling method
+        ## use vllm sampling method
+        self.sampler = get_sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+        self.mix_ratio: List[int] = [5, 15]
+
+        # special-token id constants
+        self.llm_token_id_delta = torch.tensor(self.speech_token_size, dtype=torch.int32)
+        self.sos_eos_token_id = torch.tensor((self.llm_token_id_delta + self.llm_token_size + 1), dtype=torch.int32)  # 163840 + 6564 = 170404
+        self.task_token_id = self.sos_eos_token_id + torch.tensor(1, dtype=torch.int32)  # 170405
+        self.zero_token_id = self.task_token_id + torch.tensor(1, dtype=torch.int32)
+
+        self.zero_embed_buffer = torch.zeros(
+            (vllm_config.scheduler_config.max_num_seqs, self.llm_input_size),
+            dtype=self.llm_embedding.weight.dtype,
+            device=self.llm_embedding.weight.device
+        )
+        self.inputs_embed_buffer = torch.zeros(
+            (vllm_config.scheduler_config.max_num_batched_tokens, self.llm_input_size),
+            dtype=self.llm_embedding.weight.dtype,
+            device=self.llm_embedding.weight.device,
+        )
+
+    def get_sos_eos_emb(self):
+        return self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+
+    def get_task_id_emb(self):
+        return self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[T] = None,
+        attn_metadata: Optional["AttentionMetadata"] = None,
+    ) -> torch.Tensor:
+        """
+        Returns the input embeddings merged from the text embeddings from
+        input_ids and the multimodal embeddings generated from multimodal
+        kwargs.
+        """
+        # build a mask marking which token ids are speech tokens
+        mask = input_ids < self.speech_token_size
+
+        # keep the original shape of input_ids
+        input_shape = input_ids.shape
+        # flatten input_ids and the mask so they can be handled uniformly
+        flat_input_ids = input_ids.view(-1)
+        flat_mask = mask.view(-1)
+
+        inputs_embeds = self.inputs_embed_buffer[:flat_input_ids.shape[0]]
+        inputs_embeds.zero_()
+
+        # Process speech tokens
+        if flat_mask.any():
+            speech_token_ids = flat_input_ids[flat_mask]
+            inputs_embeds[flat_mask] = self.speech_embedding(speech_token_ids)
+
+        # handle token ids above the speech-token range (LLM and special tokens)
+        if (~flat_mask).any():
+            llm_token_ids = flat_input_ids[~flat_mask]
+            llm_embeds = torch.zeros_like(inputs_embeds[~flat_mask])
+
+            sos_eos_mask = llm_token_ids == self.sos_eos_token_id
+            task_mask = llm_token_ids == self.task_token_id
+            zero_mask = llm_token_ids == self.zero_token_id
+            normal_mask = ~(sos_eos_mask | task_mask | zero_mask)
+
+            # layered handling
+            # first priority: SOS/EOS token
+            if sos_eos_mask.any():
+                llm_embeds[sos_eos_mask] = self.llm_embedding.weight[self.sos_eos].unsqueeze(0)
+
+            # second priority: task token
+            if task_mask.any():
+                llm_embeds[task_mask] = self.llm_embedding.weight[self.task_id].unsqueeze(0)
+
+            # second priority: zero (empty-audio) token
+            if zero_mask.any():
+                llm_embeds[zero_mask] = self.zero_embed_buffer[:len(llm_embeds[zero_mask])]
+
+            # regular LLM tokens
+            if normal_mask.any():
+                original_ids = llm_token_ids[normal_mask] - self.llm_token_id_delta
+                # print('original_ids: ',original_ids)
+                llm_embeds[normal_mask] = self.model.get_input_embeddings(original_ids)
+
+            inputs_embeds[~flat_mask] = llm_embeds
+
+        inputs_embeds = inputs_embeds.view(*input_shape, self.llm_input_size)
+
+        # merge multimodal embeddings (if any)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.config.audio_token_index
+            )
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                attn_metadata=attn_metadata,
+            )
+        return self.model(input_ids, positions, kv_caches,
+                        attn_metadata, intermediate_tensors,
+                        inputs_embeds)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.llm_decoder, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    @staticmethod
+    def convert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> Iterable[Tuple[str, torch.Tensor]]:
+        for name, param in weights:
+            # map the core Qwen2Model parameters
+            if name.startswith("llm."):
+                if name.startswith("llm.model.model."):
+                    name = name.replace("llm.model.model.", "model.")
+                else:
+                    continue
+            # print('weights name: ', name)
+            yield name, param
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        weights = self.convert_weights(weights)
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

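The custom architecture only becomes visible to vLLM after it is registered, which llm_vllm.py does at import time before building the engine. A condensed sketch of that hookup (engine arguments abbreviated; the model path is a placeholder):

    from vllm import ModelRegistry, AsyncLLMEngine
    from vllm.engine.arg_utils import AsyncEngineArgs
    from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM

    # registration must happen before the engine resolves the model architecture
    ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model='path/to/CosyVoice2-vllm-export', dtype='float16'))
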
+ 19 - 0
cosyvoice/utils/common.py

@@ -1,5 +1,6 @@
 # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
 #               2024 Alibaba Inc (authors: Xiang Lyu)
+#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
 # Modified from ESPnet(https://github.com/espnet/espnet)
 """Unility functions for Transformer."""
 
+import queue
 import random
 from typing import List
 
@@ -164,3 +166,20 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     #     chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
     mask = (1.0 - mask) * -1.0e+10
     return mask
+
+
+class TrtContextWrapper:
+    def __init__(self, trt_engine, trt_concurrent=1):
+        self.trt_context_pool = queue.Queue()
+        self.trt_engine = trt_engine
+        for _ in range(trt_concurrent):
+            trt_context = trt_engine.create_execution_context()
+            assert trt_context is not None, 'failed to create trt context, possibly not enough CUDA memory; try reducing trt_concurrent (currently {})'.format(trt_concurrent)
+            self.trt_context_pool.put(trt_context)
+        assert self.trt_context_pool.empty() is False, 'no available estimator context'
+
+    def acquire_estimator(self):
+        return self.trt_context_pool.get(), self.trt_engine
+
+    def release_estimator(self, context):
+        self.trt_context_pool.put(context)

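TrtContextWrapper pairs one deserialized engine with a small pool of execution contexts: callers check a context out, run inference, and put it back. A minimal round-trip sketch, assuming a serialized TensorRT plan already exists at a hypothetical path:

    import tensorrt as trt
    from cosyvoice.utils.common import TrtContextWrapper

    with open('flow.decoder.estimator.fp16.plan', 'rb') as f:   # hypothetical plan path
        engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())

    wrapper = TrtContextWrapper(engine, trt_concurrent=2)
    context, trt_engine = wrapper.acquire_estimator()   # blocks while all contexts are busy
    try:
        ...                                             # set shapes/addresses, run execute_async_v3
    finally:
        wrapper.release_estimator(context)              # always return the context to the pool

Releasing inside a finally block keeps the pool from leaking a context if inference raises.
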
+ 1 - 1
cosyvoice/utils/file_utils.py

@@ -56,7 +56,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
     network = builder.create_network(network_flags)
     parser = trt.OnnxParser(network, logger)
     config = builder.create_builder_config()
-    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33)  # 8GB
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB
     if fp16:
         config.set_flag(trt.BuilderFlag.FP16)
     profile = builder.create_optimization_profile()

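The workspace pool cap drops from 8 GB (1 << 33 bytes) to 1 GB (1 << 30 bytes), so engine builds also fit on GPUs with less free memory. A hedged sketch of how the conversion is typically driven, reusing the shapes from get_trt_kwargs (model stands for a loaded CosyVoice2Model; the .plan/.onnx paths are placeholders):

    from cosyvoice.utils.file_utils import convert_onnx_to_trt

    trt_kwargs = model.get_trt_kwargs()   # min/opt/max shapes plus input names
    convert_onnx_to_trt('flow.decoder.estimator.fp16.plan',   # output TRT plan (placeholder)
                        trt_kwargs,
                        'flow.decoder.estimator.fp32.onnx',   # source ONNX model (placeholder)
                        fp16=True)
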
+ 40 - 0
requirements_vllm.txt

@@ -0,0 +1,40 @@
+vllm==0.7.3
+pydantic==2.10.6
+torch==2.5.1
+torchaudio==2.5.1
+
+conformer==0.3.2
+
+diffusers==0.32.2
+gdown==5.1.0
+grpcio==1.57.0
+grpcio-tools==1.57.0
+hydra-core==1.3.2
+HyperPyYAML==1.2.2
+inflect==7.3.1
+librosa==0.10.2
+
+lightning==2.5.0.post0
+matplotlib==3.7.5
+modelscope==1.15.0
+
+networkx==3.4.2
+omegaconf==2.3.0
+onnx==1.17.0
+
+onnxruntime-gpu==1.19.0; sys_platform == 'linux'
+
+#openai-whisper==20231117
+openai-whisper==20240930
+protobuf==4.25
+pyworld==0.3.4
+rich==13.7.1
+soundfile==0.12.1
+tensorboard==2.14.0
+wget==3.2
+WeTextProcessing==1.0.3
+
+# trt use
+tensorrt-cu12==10.0.1
+tensorrt-cu12-bindings==10.0.1
+tensorrt-cu12-libs==10.0.1