Przeglądaj źródła

mv AsyncLLMEngine init to CosyVoice2

雾聪 1 rok temu
rodzic
commit
9b3f351496
2 zmienionych plików z 23 dodań i 21 usunięć
  1. 22 0
      cosyvoice/cli/cosyvoice.py
  2. 1 21
      cosyvoice/llm/llm_vllm.py

+ 22 - 0
cosyvoice/cli/cosyvoice.py

@@ -166,7 +166,29 @@ class CosyVoice2(CosyVoice):
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         if use_vllm:
             try:
+                os.environ["VLLM_USE_V1"] = '1'
+                from vllm import AsyncLLMEngine
+                from vllm.engine.arg_utils import AsyncEngineArgs                
+                # EngineArgs
+                ENGINE_ARGS = {
+                    "block_size": 16,
+                    "swap_space": 0,
+                    # "enforce_eager": True,
+                    "gpu_memory_utilization": 0.4,
+                    "max_num_batched_tokens": 1024,
+                    "max_model_len": 1024,
+                    "max_num_seqs": 256,
+                    "disable_log_requests": True,
+                    "disable_log_stats": True,
+                    "dtype": "bfloat16"
+                }
                 self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
+                engine_args = AsyncEngineArgs(
+                    model=model_dir,
+                    **ENGINE_ARGS,
+                )
+                self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
+                self.model.llm_engine = self.llm_engine
             except Exception as e:
                 logging.warning(f'use vllm inference failed. \n{e}')
                 raise e

+ 1 - 21
cosyvoice/llm/llm_vllm.py

@@ -31,20 +31,6 @@ from vllm.sampling_params import SamplingParams
 from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
 ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
 
-# EngineArgs
-ENGINE_ARGS = {
-    "block_size": 16,
-    "swap_space": 0,
-    # "enforce_eager": True,
-    "gpu_memory_utilization": 0.4,
-    "max_num_batched_tokens": 1024,
-    "max_model_len": 1024,
-    "max_num_seqs": 256,
-    "disable_log_requests": True,
-    "disable_log_stats": True,
-    "dtype": "float16"
-}
-
 from vllm.sampling_params import RequestOutputKind
 # SamplingParams
 SAMPLING_PARAMS = {
@@ -72,13 +58,7 @@ class VllmQwen2LM(Qwen2LM):
         self.fp16 = False
         self.half = lambda: None
         self.mix_ratio = mix_ratio
-        # ---------------------------------------------
-        # vllm engine 的参数配置
-        engine_args = AsyncEngineArgs(
-            model=model_dir,
-            **ENGINE_ARGS,
-        )
-        self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
+        self.llm_engine = None
 
         self.speech_token_size = 6564       # 6561 + 3
         self.llm_token_size = 151936        # llm  vocab_size