|
|
@@ -86,46 +86,35 @@ class VllmQwen2LM(Qwen2LM):
|
|
|
self.task_token_id = self.sos_eos_token_id + 1
|
|
|
self.zero_token_id = self.task_token_id + 1
|
|
|
|
|
|
- # 不能直接在同步函数正确的使用 异步的生成器函数,即使使用协程也会对vllm造成崩溃
|
|
|
- # 使用 queue 的方式,后台线程运行推理任务
|
|
|
- self.task_queue = queue.Queue()
|
|
|
+ # vLLM inference tasks must run on one fixed event loop, so start a background thread dedicated to running them.
|
|
|
self.loop = asyncio.new_event_loop()
|
|
|
self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
|
|
|
self.loop_thread.start()
|
|
|
- # 运行后台协程,用于处理任务队列中的任务
|
|
|
- # TODO: 目前只能单任务运行,多任务运行需要对 inference_processor 进行修改
|
|
|
- asyncio.run_coroutine_threadsafe(self.inference_processor(self.task_queue), self.loop)
|
|
|
|
|
|
def _run_event_loop(self):
|
|
|
asyncio.set_event_loop(self.loop)
|
|
|
self.loop.run_forever()
|
|
|
|
|
|
async def inference_processor(self, task_queue):
    """Serve inference requests taken from ``task_queue``, one at a time, forever.

    Each queue item is a 5-tuple
    ``(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens)``.
    Every output produced by ``self.llm_engine.generate`` is forwarded to the
    request's ``out_queue`` as ``(output.outputs[0], output.finished)``.

    Any exception raised while handling a request is logged and the loop
    moves on to the next queue item.
    """
    while True:
        try:
            logging.debug("inference_processor")
            # NOTE: task_queue.get() is a blocking call made from inside a coroutine.
            request = task_queue.get()
            out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens = request

            sampling_params = SamplingParams(**SAMPLING_PARAMS)
            # Fall back to token id 6561 when the caller supplies no stop tokens
            # (presumably the model's end-of-speech token; confirm against the tokenizer).
            sampling_params.stop_token_ids = stop_token_ids if stop_token_ids else [6561]
            if max_tokens:
                sampling_params.max_tokens = max_tokens

            prompt = {"prompt_token_ids": prompt_token_ids}
            # A timestamp serves as the request id when the caller gave none.
            effective_request_id = request_id if request_id else f"{time.time()}"
            result_stream = self.llm_engine.generate(
                prompt,
                sampling_params=sampling_params,
                request_id=effective_request_id,
            )
            async for output in result_stream:
                out_queue.put((output.outputs[0], output.finished))
        except Exception as e:
            logging.error(f"Error in inference_processor: {e}")
|
|
|
async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens):
    """Run one generation request and stream its outputs into ``out_queue``.

    Args:
        out_queue: Queue that receives one ``(output.outputs[0], output.finished)``
            tuple per item yielded by ``self.llm_engine.generate``.
        prompt_token_ids: Token ids passed to the engine as the prompt.
        request_id: Engine request id; a ``time.time()`` timestamp string is
            used when this is falsy.
        stop_token_ids: Stop token ids; ``[6561]`` is used when this is falsy.
        max_tokens: When truthy, overrides ``max_tokens`` from ``SAMPLING_PARAMS``.

    Raises:
        Exception: Anything raised while building the sampling parameters or
            while generating is logged with its traceback and re-raised.
    """
    # Local import: the file's import block is not visible from this block.
    import logging

    try:
        sampling_params = SamplingParams(**SAMPLING_PARAMS)
        sampling_params.stop_token_ids = stop_token_ids or [6561]
        if max_tokens:
            sampling_params.max_tokens = max_tokens
        async for output in self.llm_engine.generate(
            {
                "prompt_token_ids": prompt_token_ids,
            },
            sampling_params=sampling_params,
            request_id=request_id or f"{time.time()}",
        ):
            out_queue.put((output.outputs[0], output.finished))
    except Exception:
        # llm_inference schedules this coroutine with run_coroutine_threadsafe and
        # drops the returned future, so an exception would otherwise vanish
        # without a trace. Log it here, then re-raise so the future still
        # carries the error for any caller that does keep it.
        # NOTE(review): the consumer in llm_inference waits on out_queue until it
        # sees finished=True, so a failed request still leaves it waiting;
        # consider an error sentinel once the consumer side can handle one.
        logging.exception("Error in async_llm_inference (request_id=%s)", request_id)
        raise
|
|
|
|
|
|
def llm_inference(self, prompt_token_ids: List[int], request_id: str=None, stop_token_ids=None, max_tokens=None):
|
|
|
- # 使用 同步转异步 会导致vllm崩溃,目前选择 queue 的方式,后台线程运行推理任务
|
|
|
- # 提交推理任务到队列中
|
|
|
out_queue = queue.Queue()
|
|
|
- self.task_queue.put((out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens))
|
|
|
- # 将 out_queue 的结果返回
|
|
|
+ asyncio.run_coroutine_threadsafe(
|
|
|
+ self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop
|
|
|
+ )
|
|
|
+ # 接收 out_queue 返回的结果
|
|
|
finished = False
|
|
|
while not finished:
|
|
|
(output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
|