lyuxiang.lx 11 mesiacov pred
rodič
commit
6dd68b9d5e

+ 1 - 1
cosyvoice/cli/cosyvoice.py

@@ -140,7 +140,7 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16

+ 14 - 19
cosyvoice/cli/model.py

@@ -59,9 +59,6 @@ class CosyVoiceModel:
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
-        for _ in range(trt_concurrent):
-            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -69,7 +66,6 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
-        self.trt_context_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -98,7 +94,7 @@ class CosyVoiceModel:
         with open(flow_decoder_estimator_model, 'rb') as f:
             estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
         assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent, device=self.device)
 
     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
@@ -125,7 +121,8 @@ class CosyVoiceModel:
                                             prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                             prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device)):
+                                            embedding=llm_embedding.to(self.device),
+                                            uuid=uuid):
                     self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
@@ -180,13 +177,11 @@ class CosyVoiceModel:
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
-            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -240,8 +235,6 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
-            self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()
@@ -273,22 +266,28 @@ class CosyVoice2Model(CosyVoiceModel):
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
-        for _ in range(trt_concurrent):
-            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
-        self.trt_context_dict = {}
 
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
+    def load_vllm(self, model_dir):
+        export_cosyvoice2_vllm(self.llm, model_dir, self.device)
+        from vllm import EngineArgs, LLMEngine
+        engine_args = EngineArgs(model=model_dir,
+                                 skip_tokenizer_init=True,
+                                 enable_prompt_embeds=True,
+                                 gpu_memory_utilization=0.2)
+        self.llm.vllm = LLMEngine.from_engine_args(engine_args)
+        del self.llm.llm.model.model.layers
+
     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
-        with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
+        with torch.cuda.amp.autocast(self.fp16):
             tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                              token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                              prompt_token=prompt_token.to(self.device),
@@ -330,11 +329,9 @@ class CosyVoice2Model(CosyVoiceModel):
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
-            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -388,8 +385,6 @@ class CosyVoice2Model(CosyVoiceModel):
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
-            self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()

+ 23 - 21
cosyvoice/flow/flow_matching.py

@@ -16,6 +16,7 @@ import threading
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
+from cosyvoice.utils.common import set_all_random_seed
 
 
 class ConditionalCFM(BASECFM):
@@ -32,7 +33,6 @@ class ConditionalCFM(BASECFM):
         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
         # Just change the architecture of the estimator here
         self.estimator = estimator
-        self.lock = threading.Lock()
 
     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
@@ -127,26 +127,27 @@ class ConditionalCFM(BASECFM):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
         else:
-            estimator, trt_engine = self.estimator.acquire_estimator()
-            estimator.set_input_shape('x', (2, 80, x.size(2)))
-            estimator.set_input_shape('mask', (2, 1, x.size(2)))
-            estimator.set_input_shape('mu', (2, 80, x.size(2)))
-            estimator.set_input_shape('t', (2,))
-            estimator.set_input_shape('spks', (2, 80))
-            estimator.set_input_shape('cond', (2, 80, x.size(2)))
-            data_ptrs = [x.contiguous().data_ptr(),
-                         mask.contiguous().data_ptr(),
-                         mu.contiguous().data_ptr(),
-                         t.contiguous().data_ptr(),
-                         spks.contiguous().data_ptr(),
-                         cond.contiguous().data_ptr(),
-                         x.data_ptr()]
-            for i, j in enumerate(data_ptrs):
-                estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
-            # run trt engine
-            assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
-            torch.cuda.current_stream().synchronize()
-            self.estimator.release_estimator(estimator)
+            [estimator, stream], trt_engine = self.estimator.acquire_estimator()
+            with stream:
+                estimator.set_input_shape('x', (2, 80, x.size(2)))
+                estimator.set_input_shape('mask', (2, 1, x.size(2)))
+                estimator.set_input_shape('mu', (2, 80, x.size(2)))
+                estimator.set_input_shape('t', (2,))
+                estimator.set_input_shape('spks', (2, 80))
+                estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                data_ptrs = [x.contiguous().data_ptr(),
+                            mask.contiguous().data_ptr(),
+                            mu.contiguous().data_ptr(),
+                            t.contiguous().data_ptr(),
+                            spks.contiguous().data_ptr(),
+                            cond.contiguous().data_ptr(),
+                            x.data_ptr()]
+                for i, j in enumerate(data_ptrs):
+                    estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
+                # run trt engine
+                assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+                torch.cuda.current_stream().synchronize()
+            self.estimator.release_estimator(estimator, stream)
             return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
@@ -194,6 +195,7 @@ class ConditionalCFM(BASECFM):
 class CausalConditionalCFM(ConditionalCFM):
     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
         super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+        set_all_random_seed(0)
         self.rand_noise = torch.randn([1, 80, 50 * 300])
 
     @torch.inference_mode()

+ 59 - 17
cosyvoice/llm/llm.py

@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import queue
 import random
+import time
+import threading
 from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
@@ -170,6 +173,7 @@ class TransformerLM(torch.nn.Module):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            uuid: str = '',
     ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -270,7 +274,6 @@ class Qwen2LM(TransformerLM):
         self.llm_input_size = llm_input_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
-
         # 2. build speech token language model related modules
         self.sos_eos = 0
         self.task_id = 1
@@ -292,6 +295,11 @@ class Qwen2LM(TransformerLM):
         # 4. sampling method
         self.sampling = sampling
         self.mix_ratio = mix_ratio
+        
+        # 5. vllm related
+        self.stop_token_ids = [speech_token_size + i for i in range(3)]
+        self.vllm_output_queue = {}
+        self.lock = threading.Lock()
 
     def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
@@ -382,6 +390,7 @@ class Qwen2LM(TransformerLM):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            uuid: str = '',
     ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -402,22 +411,55 @@ class Qwen2LM(TransformerLM):
         max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
 
         # 5. step by step decode
-        out_tokens = []
-        cache = None
-        for i in range(max_len):
-            y_pred, cache = self.llm.forward_one_step(lm_input,
-                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
-                                                      cache=cache)
-            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
-                break
-            if top_ids > self.speech_token_size:
-                continue
-            # in stream mode, yield token one by one
-            yield top_ids
-            out_tokens.append(top_ids)
-            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
+            yield token
+
+    @torch.inference_mode()
+    def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
+        if hasattr(self, 'vllm'):
+            from vllm import SamplingParams, RequestOutput
+            sampling_params = SamplingParams(top_k=sampling,
+                                             stop_token_ids=self.stop_token_ids,
+                                             min_tokens=min_len,
+                                             max_tokens=max_len)
+            with self.lock:
+                self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
+                self.vllm_output_queue[uuid] = queue.Queue()
+            out_tokens = []
+            while True:
+                with self.lock:
+                    if self.vllm_output_queue[uuid].empty() is True:
+                        request_outputs: List[RequestOutput] = self.vllm.step()
+                        for request_output in request_outputs:
+                            top_ids = list(request_output.outputs[0].token_ids)[-1]
+                            self.vllm_output_queue[request_output.request_id].put(top_ids)
+                if self.vllm_output_queue[uuid].empty() is False:
+                    top_ids = self.vllm_output_queue[uuid].get()
+                    if top_ids in self.stop_token_ids:
+                        break
+                    # in stream mode, yield token one by one
+                    yield top_ids
+                    out_tokens.append(top_ids)
+                time.sleep(0.001)
+            with self.lock:
+                self.vllm_output_queue.pop(uuid)
+        else:
+            out_tokens = []
+            cache = None
+            for i in range(max_len):
+                y_pred, cache = self.llm.forward_one_step(lm_input,
+                                                        masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
+                                                        cache=cache)
+                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
+                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
+                if top_ids == self.speech_token_size:
+                    break
+                if top_ids > self.speech_token_size:
+                    continue
+                # in stream mode, yield token one by one
+                yield top_ids
+                out_tokens.append(top_ids)
+                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
 
     @torch.inference_mode()
     def inference_bistream(

+ 6 - 5
cosyvoice/utils/common.py

@@ -169,17 +169,18 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
 
 
 class TrtContextWrapper:
-    def __init__(self, trt_engine, trt_concurrent=1):
-        self.trt_context_pool = queue.Queue()
+    def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
+        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
         self.trt_engine = trt_engine
         for _ in range(trt_concurrent):
             trt_context = trt_engine.create_execution_context()
+            trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
             assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
-            self.trt_context_pool.put(trt_context)
+            self.trt_context_pool.put([trt_context, trt_stream])
         assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
 
     def acquire_estimator(self):
         return self.trt_context_pool.get(), self.trt_engine
 
-    def release_estimator(self, context):
-        self.trt_context_pool.put(context)
+    def release_estimator(self, context, stream):
+        self.trt_context_pool.put([context, stream])

+ 2 - 1
cosyvoice/utils/file_utils.py

@@ -58,7 +58,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
     network = builder.create_network(network_flags)
     parser = trt.OnnxParser(network, logger)
     config = builder.create_builder_config()
-    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31)  # 1GB
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)  # 4GB
     if fp16:
         config.set_flag(trt.BuilderFlag.FP16)
     profile = builder.create_optimization_profile()
@@ -122,6 +122,7 @@ def export_cosyvoice2_vllm(model, model_path, device):
     model.llm.model.config.tie_word_embeddings = False
     model.llm.model.config.use_bias = True
     model.llm.model.save_pretrained(model_path)
+    os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
     model.llm.model.config.vocab_size = tmp_vocab_size
     model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
     model.llm.model.set_input_embeddings(embed_tokens)