Selaa lähdekoodia

Revert "mv AsyncLLMEngine init to CosyVoice2"

This reverts commit 9b3f35149620681af225c3a61e614f307ac5aacd.
雾聪 1 vuosi sitten
vanhempi
commit
96950745a6
2 muutettua tiedostoa joissa 21 lisäystä ja 23 poistoa
  1. 0 22
      cosyvoice/cli/cosyvoice.py
  2. 21 1
      cosyvoice/llm/llm_vllm.py

+ 0 - 22
cosyvoice/cli/cosyvoice.py

@@ -166,29 +166,7 @@ class CosyVoice2(CosyVoice):
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         if use_vllm:
             try:
-                os.environ["VLLM_USE_V1"] = '1'
-                from vllm import AsyncLLMEngine
-                from vllm.engine.arg_utils import AsyncEngineArgs                
-                # EngineArgs
-                ENGINE_ARGS = {
-                    "block_size": 16,
-                    "swap_space": 0,
-                    # "enforce_eager": True,
-                    "gpu_memory_utilization": 0.4,
-                    "max_num_batched_tokens": 1024,
-                    "max_model_len": 1024,
-                    "max_num_seqs": 256,
-                    "disable_log_requests": True,
-                    "disable_log_stats": True,
-                    "dtype": "bfloat16"
-                }
                 self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
-                engine_args = AsyncEngineArgs(
-                    model=model_dir,
-                    **ENGINE_ARGS,
-                )
-                self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
-                self.model.llm_engine = self.llm_engine
             except Exception as e:
                 logging.warning(f'use vllm inference failed. \n{e}')
                 raise e

+ 21 - 1
cosyvoice/llm/llm_vllm.py

@@ -31,6 +31,20 @@ from vllm.sampling_params import SamplingParams
 from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
 ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
 
+# EngineArgs
+ENGINE_ARGS = {
+    "block_size": 16,
+    "swap_space": 0,
+    # "enforce_eager": True,
+    "gpu_memory_utilization": 0.4,
+    "max_num_batched_tokens": 1024,
+    "max_model_len": 1024,
+    "max_num_seqs": 256,
+    "disable_log_requests": True,
+    "disable_log_stats": True,
+    "dtype": "float16"
+}
+
 from vllm.sampling_params import RequestOutputKind
 # SamplingParams
 SAMPLING_PARAMS = {
@@ -58,7 +72,13 @@ class VllmQwen2LM(Qwen2LM):
         self.fp16 = False
         self.half = lambda: None
         self.mix_ratio = mix_ratio
-        self.llm_engine = None
+        # ---------------------------------------------
+        # vllm engine 的参数配置
+        engine_args = AsyncEngineArgs(
+            model=model_dir,
+            **ENGINE_ARGS,
+        )
+        self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
 
         self.speech_token_size = 6564       # 6561 + 3
         self.llm_token_size = 151936        # llm  vocab_size