Przeglądaj źródła

refactor(llm): 重构 VLLM 推理方式

- 新增基于队列和线程的异步推理机制
- 优化同步推理接口,使用新机制实现
qihua 1 rok temu
rodzic
commit
d4d187bd8c
1 zmienionych plików z 55 dodań i 54 usunięć
  1. 55 54
      cosyvoice/llm/llm_vllm.py

+ 55 - 54
cosyvoice/llm/llm_vllm.py

@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import asyncio
-import contextlib
 import time
+import queue
+import asyncio
+import threading
 from typing import List, Generator, AsyncGenerator
 import torch
 from cosyvoice.utils.file_utils import logging
@@ -41,6 +42,7 @@ ENGINE_ARGS = {
     "max_num_seqs": 256,
     "disable_log_requests": True,
     "disable_log_stats": True,
+    "dtype": "float16"
 }
 
 from vllm.sampling_params import RequestOutputKind
@@ -84,13 +86,42 @@ class VllmQwen2LM(Qwen2LM):
         self.task_token_id = self.sos_eos_token_id + 1
         self.zero_token_id = self.task_token_id + 1
 
+        # 不能直接在同步函数正确的使用 异步的生成器函数,即使使用协程也会对vllm造成崩溃
+        # 使用 queue 的方式,后台线程运行推理任务
+        self.task_queue = queue.Queue()
+        self.loop = asyncio.new_event_loop()
+        self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
+        self.loop_thread.start()
+        # 运行后台协程,用于处理任务队列中的任务
+        # TODO: 目前只能单任务运行,多任务运行需要对 inference_processor 进行修改
+        asyncio.run_coroutine_threadsafe(self.inference_processor(self.task_queue), self.loop)
+
+    def _run_event_loop(self):
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
+
+    async def inference_processor(self, task_queue):
+        while True:
+            try:
+                print(f"inference_processor")
+                out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = task_queue.get()
+                sampling_params = SamplingParams(**SAMPLING_PARAMS)
+                sampling_params.stop_token_ids = stop_token_ids or [6561]
+                if max_tokens:
+                    sampling_params.max_tokens = max_tokens
+                async for output in self.llm_engine.generate(
+                        {
+                            "prompt_token_ids": prompt_token_ids,
+                        },
+                        sampling_params=sampling_params,
+                        request_id=request_id or f"{time.time()}",
+                ):
+                    out_queue.put((output.outputs[0], output.finished))
+            except Exception as e:
+                logging.error(f"Error in inference_processor: {e}")
+
     async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
             -> AsyncGenerator[CompletionOutput, None]:
-        assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
-        invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
-        assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
-        # logging.debug('prompt_token_ids:', prompt_token_ids)
-        # TODO: 增加上下文控制,取消请求时
         sampling_params = SamplingParams(**SAMPLING_PARAMS)
         sampling_params.stop_token_ids = stop_token_ids or [6561]
         if max_tokens:
@@ -104,49 +135,16 @@ class VllmQwen2LM(Qwen2LM):
         ):
             yield output.outputs[0]
 
-
-    def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None)\
-            -> Generator[CompletionOutput, None, None]:
-        assert isinstance(prompt_token_ids, list) , "prompt_token_ids should be List[int]"
-        invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
-        assert invalid is None, f"Error in prompt_token_ids, Non-int element at index {invalid}: {prompt_token_ids[invalid]}"
-        # logging.debug('prompt_token_ids:', prompt_token_ids)
-        # TODO: 增加上下文控制,取消请求时
-        sampling_params = SamplingParams(**SAMPLING_PARAMS)
-        sampling_params.stop_token_ids = stop_token_ids or [6561]
-        if max_tokens:
-            sampling_params.max_tokens = max_tokens
-
-        # 创建独立事件循环
-        loop = asyncio.new_event_loop()
-        try:
-            asyncio.set_event_loop(loop)
-            # 初始化异步生成器
-            async_gen = self.llm_engine.generate(
-                    {
-                        "prompt_token_ids": prompt_token_ids,
-                    },
-                    sampling_params=sampling_params,
-                    request_id=request_id or f"{time.time()}",
-            )
-            while True:
-                try:
-                    # 同步获取异步结果
-                    output = loop.run_until_complete(async_gen.__anext__())
-                    yield output.outputs[0]
-                except StopAsyncIteration:
-                    break
-        except GeneratorExit:
-            if async_gen is not None:
-                loop.run_until_complete(async_gen.aclose())
-            raise
-        finally:
-            # 资源清理
-            print("资源清理...")
-            if async_gen is not None:
-                loop.run_until_complete(async_gen.aclose())
-                loop.close()
-            print("资源清理成功")
+    def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None):
+        # 使用 同步转异步 会导致vllm崩溃,目前选择 queue 的方式,后台线程运行推理任务
+        # 提交推理任务到队列中
+        out_queue = queue.Queue()
+        self.task_queue.put((out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens))
+        # 将 out_queue 的结果返回
+        finished = False
+        while not finished:
+            (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
+            yield output
 
     def inference(
             self,
@@ -194,6 +192,9 @@ class VllmQwen2LM(Qwen2LM):
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
     ) -> Generator[torch.Tensor, None, None]:
+        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
+        prompt_speech_token = tensor_to_list(prompt_speech_token)
+
         last_tokens = []
         prompt_token_ids = [self.sos_eos_token_id]
         text_tokens_cache = prompt_text
@@ -202,18 +203,18 @@ class VllmQwen2LM(Qwen2LM):
             # text need tokens
             assert isinstance(this_text, list), "text need token ids List[int]."
             text_tokens_cache += this_text
-            while len(llm_prompt_speech_token) != 0:
+            while len(prompt_speech_token) != 0:
                 if len(text_tokens_cache) >= self.mix_ratio[0]:
                     text_input_token = text_tokens_cache[:self.mix_ratio[0]]
-                    speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]]
+                    speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
                     prompt_token_ids += text_input_token + speech_input_token
                     # reset the last cache
                     text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
-                    llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:]
+                    prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
                 else:
                     logging.info('not enough text token to decode, wait for more')
                     break
-            if len(llm_prompt_speech_token) == 0:
+            if len(prompt_speech_token) == 0:
                 if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
                     logging.info('get fill token, need to append more text token')
                     if len(text_tokens_cache) >= self.mix_ratio[0]: