llm_vllm.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import queue
import asyncio
import threading
from typing import List, Generator, AsyncGenerator
import torch
from cosyvoice.utils.file_utils import logging
from cosyvoice.llm.llm import Qwen2LM
# Enable the vLLM V1 engine
import os
os.environ["VLLM_USE_V1"] = '1'
from vllm import ModelRegistry
from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
from vllm.sampling_params import SamplingParams
from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM

ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)

# EngineArgs
ENGINE_ARGS = {
    "block_size": 16,
    "swap_space": 0,
    # "enforce_eager": True,
    "gpu_memory_utilization": 0.4,
    "max_num_batched_tokens": 1024,
    "max_model_len": 1024,
    "max_num_seqs": 256,
    "disable_log_requests": True,
    "disable_log_stats": True,
    "dtype": "float16",
}

from vllm.sampling_params import RequestOutputKind

# SamplingParams
SAMPLING_PARAMS = {
    "temperature": 1,  # do not set below 0.8, otherwise the model produces large amounts of empty audio or fails to generate speech tokens
    "top_p": 1,        # do not set below 0.8, same failure mode as temperature
    "top_k": 25,
    # "min_tokens": 80,         # setting a minimum token count is not supported; enabling it crashes vllm at startup
    # "presence_penalty": 1.0,  # not supported
    # "frequency_penalty": 0.0, # not supported
    "max_tokens": 1024,
    "detokenize": False,  # has no effect in vllm 0.7.3 V1; kept so a later version can skip detokenization work
    "ignore_eos": False,
    "output_kind": RequestOutputKind.DELTA  # set to DELTA; if you change this, also adjust the handling code in llm_inference
}


def tensor_to_list(tensor: torch.Tensor):
    return tensor.view(-1).cpu().numpy().tolist()


class VllmQwen2LM(Qwen2LM):
    def __init__(
            self,
            model_dir,
            mix_ratio: List[int] = [5, 15],
    ):
        self.fp16 = False
        self.half = lambda: None  # make external .half() calls a no-op; dtype is set via ENGINE_ARGS
        self.mix_ratio = mix_ratio
        # ---------------------------------------------
        # vllm engine argument configuration
        engine_args = AsyncEngineArgs(
            model=model_dir,
            **ENGINE_ARGS,
        )
        self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
        self.speech_token_size = 6564   # 6561 + 3
        self.llm_token_size = 151936    # llm vocab_size
        self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
        self.task_token_id = self.sos_eos_token_id + 1
        self.zero_token_id = self.task_token_id + 1
        # vllm inference has to run inside a single, fixed event loop, so start a dedicated background thread for it
        self.loop = asyncio.new_event_loop()
        self.loop_thread = threading.Thread(target=self._run_event_loop, daemon=True)
        self.loop_thread.start()

    def _run_event_loop(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    async def async_llm_inference(self, out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens):
        sampling_params = SamplingParams(**SAMPLING_PARAMS)
        sampling_params.stop_token_ids = stop_token_ids or [6561]
        if max_tokens:
            sampling_params.max_tokens = max_tokens
        async for output in self.llm_engine.generate(
                {
                    "prompt_token_ids": prompt_token_ids,
                },
                sampling_params=sampling_params,
                request_id=request_id or f"{time.time()}",
        ):
            out_queue.put((output.outputs[0], output.finished))

    def llm_inference(self, prompt_token_ids: List[int], request_id: str = None, stop_token_ids=None, max_tokens=None):
        out_queue = queue.Queue()
        asyncio.run_coroutine_threadsafe(
            self.async_llm_inference(out_queue, prompt_token_ids, request_id, stop_token_ids, max_tokens), self.loop
        )
        # receive the results that the async task pushes onto out_queue
        finished = False
        while not finished:
            (output, finished) = out_queue.get_nowait() if not out_queue.empty() else out_queue.get()
            yield output

    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor | int, None, None]:
        # shift text token ids above the speech-token range (speech_token_size = 6564) so the two vocabularies do not collide
        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
        prompt_speech_token = tensor_to_list(prompt_speech_token)
        text = tensor_to_list(text + torch.tensor(6564))
        prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
            [self.task_token_id] + prompt_speech_token
        max_tokens = len(text) * 20
        for output in self.llm_inference(
                prompt_token_ids,
                stop_token_ids=[6561],
                max_tokens=max_tokens,
        ):
            # 6561 is the stop token for speech generation; do not yield it
            if output.token_ids[-1] == 6561:
                need_add_tokens = output.token_ids[:-1]
            else:
                need_add_tokens = output.token_ids
            for token in need_add_tokens:
                yield token

    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
        prompt_speech_token = tensor_to_list(prompt_speech_token)
        last_tokens = []
        prompt_token_ids = [self.sos_eos_token_id]
        text_tokens_cache = prompt_text
        for this_text in text:
            this_text = tensor_to_list(this_text + torch.tensor(6564))
            # each streamed chunk must already be token ids
            assert isinstance(this_text, list), "text need token ids List[int]."
            text_tokens_cache += this_text
            # interleave text and prompt speech tokens at the configured mix_ratio
            while len(prompt_speech_token) != 0:
                if len(text_tokens_cache) >= self.mix_ratio[0]:
                    text_input_token = text_tokens_cache[:self.mix_ratio[0]]
                    speech_input_token = prompt_speech_token[:self.mix_ratio[1]]
                    prompt_token_ids += text_input_token + speech_input_token
                    # reset the cache to whatever is left after this chunk
                    text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                    prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
                else:
                    break
            if len(prompt_speech_token) == 0:
                if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
                    if len(text_tokens_cache) >= self.mix_ratio[0]:
                        text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
                        prompt_token_ids += text_tokens_temp
                        text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                    else:
                        continue
                for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
                    last_tokens = output.token_ids
                    if last_tokens[-1] == 6563:
                        need_add_tokens = last_tokens[:-1]
                    else:
                        need_add_tokens = last_tokens
                    for token in need_add_tokens:
                        yield token
                    prompt_token_ids.extend(need_add_tokens)
        # flush the remaining text, then generate until the end-of-speech token (6561)
        prompt_token_ids += text_tokens_cache + [self.task_token_id]
        for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
            if output.token_ids[-1] == 6561:
                need_add_tokens = output.token_ids[:-1]
            else:
                need_add_tokens = output.token_ids
            for token in need_add_tokens:
                yield token
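

# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it only shows how the
# non-streaming inference() generator is consumed. The model path, the text
# token ids, and the embedding size are illustrative placeholders; running it
# requires a vLLM-loadable CosyVoice2 checkpoint and a GPU.
if __name__ == "__main__":
    llm = VllmQwen2LM("pretrained_models/CosyVoice2-0.5B")  # hypothetical path
    # dummy text token ids; empty tensors stand in for "no prompt"
    text = torch.tensor([[101, 102, 103]], dtype=torch.int64)
    empty = torch.zeros(1, 0, dtype=torch.int64)
    for speech_token in llm.inference(
            text=text,
            text_len=torch.tensor([text.shape[1]]),
            prompt_text=empty,
            prompt_text_len=torch.tensor([0]),
            prompt_speech_token=empty,
            prompt_speech_token_len=torch.tensor([0]),
            embedding=torch.zeros(1, 192),  # placeholder speaker embedding (unused in this vLLM path)
    ):
        print(speech_token)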