# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
import time
from typing import List, Generator, AsyncGenerator
import torch
from cosyvoice.utils.file_utils import logging
from cosyvoice.llm.llm import Qwen2LM

# Enable the vLLM V1 engine
import os
os.environ["VLLM_USE_V1"] = '1'

from vllm import ModelRegistry
from vllm import LLMEngine, AsyncLLMEngine, CompletionOutput
from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
from vllm.sampling_params import SamplingParams

from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
# Register the CosyVoice2 LM class so vLLM can resolve it by architecture name
ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)

# EngineArgs
ENGINE_ARGS = {
    "block_size": 16,
    "swap_space": 0,
    # "enforce_eager": True,
    "gpu_memory_utilization": 0.4,
    "max_num_batched_tokens": 1024,
    "max_model_len": 1024,
    "max_num_seqs": 256,
    "disable_log_requests": True,
    "disable_log_stats": True,
}

from vllm.sampling_params import RequestOutputKind

# SamplingParams
SAMPLING_PARAMS = {
    "temperature": 1,  # do not go below 0.8, otherwise the model emits lots of empty audio or fails to produce speech tokens
    "top_p": 1,        # do not go below 0.8, otherwise the model emits lots of empty audio or fails to produce speech tokens
    "top_k": 25,
    # "min_tokens": 80,          # setting a minimum token count is not supported; enabling it crashes vllm at startup
    # "presence_penalty": 1.0,   # not supported
    # "frequency_penalty": 0.0,  # not supported
    "max_tokens": 1024,
    "detokenize": False,  # currently has no effect in the vllm 0.7.3 V1 engine; revisit in a later release to save compute
    "ignore_eos": False,
    "output_kind": RequestOutputKind.DELTA  # DELTA output; if you change this, also update the handling code in llm_inference
}


def tensor_to_list(tensor: torch.Tensor):
    # Flatten the tensor and return it as a plain Python list of ints
    return tensor.view(-1).cpu().numpy().tolist()


class VllmQwen2LM(Qwen2LM):
    def __init__(
            self,
            model_dir,
            mix_ratio: List[int] = [5, 15],
    ):
        self.fp16 = False
        self.half = lambda: None
        self.mix_ratio = mix_ratio
        # ---------------------------------------------
        # vLLM engine configuration
        engine_args = AsyncEngineArgs(
            model=model_dir,
            **ENGINE_ARGS,
        )
        self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
        self.speech_token_size = 6564   # 6561 + 3
        self.llm_token_size = 151936    # llm vocab_size
        self.sos_eos_token_id = self.speech_token_size + self.llm_token_size + 1
        self.task_token_id = self.sos_eos_token_id + 1
        self.zero_token_id = self.task_token_id + 1

    async def async_llm_inference(self, prompt_token_ids: List[int], request_id: str = None,
                                  stop_token_ids=None, max_tokens=None) -> AsyncGenerator[CompletionOutput, None]:
        assert isinstance(prompt_token_ids, list), "prompt_token_ids should be List[int]"
        invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
        assert invalid is None, f"Error in prompt_token_ids, non-int element at index {invalid}: {prompt_token_ids[invalid]}"
        # logging.debug('prompt_token_ids:', prompt_token_ids)
        # TODO: add cancellation handling so in-flight requests can be aborted
        sampling_params = SamplingParams(**SAMPLING_PARAMS)
        sampling_params.stop_token_ids = stop_token_ids or [6561]
        if max_tokens:
            sampling_params.max_tokens = max_tokens
        async for output in self.llm_engine.generate(
                {
                    "prompt_token_ids": prompt_token_ids,
                },
                sampling_params=sampling_params,
                request_id=request_id or f"{time.time()}",
        ):
            yield output.outputs[0]

    def llm_inference(self, prompt_token_ids: List[int], request_id: str = None,
                      stop_token_ids=None, max_tokens=None) -> Generator[CompletionOutput, None, None]:
        assert isinstance(prompt_token_ids, list), "prompt_token_ids should be List[int]"
        invalid = next((i for i, x in enumerate(prompt_token_ids) if not isinstance(x, int)), None)
        assert invalid is None, f"Error in prompt_token_ids, non-int element at index {invalid}: {prompt_token_ids[invalid]}"
        # logging.debug('prompt_token_ids:', prompt_token_ids)
        # TODO: add cancellation handling so in-flight requests can be aborted
        sampling_params = SamplingParams(**SAMPLING_PARAMS)
        sampling_params.stop_token_ids = stop_token_ids or [6561]
        if max_tokens:
            sampling_params.max_tokens = max_tokens
        # create a dedicated event loop so the async engine can be driven synchronously
        loop = asyncio.new_event_loop()
        async_gen = None
        try:
            asyncio.set_event_loop(loop)
            # initialize the async generator
            async_gen = self.llm_engine.generate(
                {
                    "prompt_token_ids": prompt_token_ids,
                },
                sampling_params=sampling_params,
                request_id=request_id or f"{time.time()}",
            )
            while True:
                try:
                    # pull the next async result synchronously
                    output = loop.run_until_complete(async_gen.__anext__())
                    yield output.outputs[0]
                except StopAsyncIteration:
                    break
                except GeneratorExit:
                    if async_gen is not None:
                        loop.run_until_complete(async_gen.aclose())
                    raise
        finally:
            # clean up the generator and the event loop
            logging.debug("cleaning up llm_inference resources...")
            if async_gen is not None:
                loop.run_until_complete(async_gen.aclose())
            loop.close()
            logging.debug("llm_inference resources released")

    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor | int, None, None]:
        # offset text token ids by the speech vocabulary size (6564) so they do not collide with speech token ids
        prompt_text = tensor_to_list(prompt_text + torch.tensor(6564))
        prompt_speech_token = tensor_to_list(prompt_speech_token)
        text = tensor_to_list(text + torch.tensor(6564))
        prompt_token_ids = [self.sos_eos_token_id] + prompt_text + text + \
            [self.task_token_id] + prompt_speech_token
        max_tokens = int(len(text) * max_token_text_ratio)
        for output in self.llm_inference(
            prompt_token_ids,
            stop_token_ids=[6561],
            max_tokens=max_tokens,
        ):
            if output.token_ids[-1] == 6561:
                need_add_tokens = output.token_ids[:-1]
            else:
                need_add_tokens = output.token_ids
            # looping over single tokens is slow; batching with extend() in the model is preferable
            # yield need_add_tokens
            for token in need_add_tokens:
                yield token

    def inference_bistream(
            self,
            text: Generator,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[torch.Tensor, None, None]:
        last_tokens = []
        prompt_token_ids = [self.sos_eos_token_id]
        # offset prompt text token ids by the speech vocabulary size, mirroring inference()
        text_tokens_cache = tensor_to_list(prompt_text + torch.tensor(6564))
        llm_prompt_speech_token = tensor_to_list(prompt_speech_token)
        for this_text in text:
            this_text = tensor_to_list(this_text + torch.tensor(6564))
            # incoming text chunks must already be token ids
            assert isinstance(this_text, list), "text need token ids List[int]."
            text_tokens_cache += this_text
            while len(llm_prompt_speech_token) != 0:
                if len(text_tokens_cache) >= self.mix_ratio[0]:
                    text_input_token = text_tokens_cache[:self.mix_ratio[0]]
                    speech_input_token = llm_prompt_speech_token[:self.mix_ratio[1]]
                    prompt_token_ids += text_input_token + speech_input_token
                    # reset the last cache
                    text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                    llm_prompt_speech_token = llm_prompt_speech_token[self.mix_ratio[1]:]
                else:
                    logging.info('not enough text token to decode, wait for more')
                    break
            if len(llm_prompt_speech_token) == 0:
                if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
                    logging.info('get fill token, need to append more text token')
                    if len(text_tokens_cache) >= self.mix_ratio[0]:
                        text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
                        prompt_token_ids += text_tokens_temp
                        logging.info('append {} text token'.format(len(text_tokens_temp)))
                        text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                    else:
                        logging.info('not enough text token to decode, wait for more')
                        continue
                for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
                    last_tokens = output.token_ids
                    if last_tokens[-1] == 6563:
                        need_add_tokens = last_tokens[:-1]
                    else:
                        need_add_tokens = last_tokens
                    # looping over single tokens is slow; batching with extend() in the model is preferable
                    # yield need_add_tokens
                    for token in need_add_tokens:
                        yield token
                    prompt_token_ids.extend(need_add_tokens)
        # no speech prompt left and no more text: append the remaining text and decode until EOS
        prompt_token_ids += text_tokens_cache + [self.task_token_id]
        logging.info('no more text token, decode until met eos')
        for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
            if output.token_ids[-1] == 6561:
                need_add_tokens = output.token_ids[:-1]
            else:
                need_add_tokens = output.token_ids
            # looping over single tokens is slow; batching with extend() in the model is preferable
            # yield need_add_tokens
            for token in need_add_tokens:
                yield token
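

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module):
# it assumes a CosyVoice2 LLM checkpoint exported in a vLLM-loadable format at
# MODEL_DIR, and feeds dummy token ids just to show the call shape of
# inference(), which streams speech token ids one by one.
if __name__ == "__main__":
    MODEL_DIR = "/path/to/CosyVoice2-0.5B/vllm"  # placeholder path, adjust to your setup
    llm = VllmQwen2LM(MODEL_DIR)
    dummy_text = torch.randint(0, 1000, (1, 12))              # fake text token ids
    dummy_prompt_text = torch.randint(0, 1000, (1, 6))        # fake prompt text token ids
    dummy_prompt_speech = torch.randint(0, 6561, (1, 30))     # fake prompt speech token ids
    speech_tokens = list(llm.inference(
        text=dummy_text,
        text_len=torch.tensor([dummy_text.shape[1]]),
        prompt_text=dummy_prompt_text,
        prompt_text_len=torch.tensor([dummy_prompt_text.shape[1]]),
        prompt_speech_token=dummy_prompt_speech,
        prompt_speech_token_len=torch.tensor([dummy_prompt_speech.shape[1]]),
        embedding=torch.zeros(1, 192),  # unused by the vLLM path, required by the interface
    ))
    print(f"generated {len(speech_tokens)} speech tokens")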