@@ -14,7 +14,6 @@
 # limitations under the License.
 import os
 from typing import Generator
-import queue
 import torch
 import numpy as np
 import threading
@@ -23,7 +22,7 @@ from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-from cosyvoice.utils.file_utils import convert_onnx_to_trt
+from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
 from cosyvoice.utils.common import TrtContextWrapper


@@ -33,14 +32,12 @@ class CosyVoiceModel:
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False,
-                 trt_concurrent: int = 1):
+                 fp16: bool = False):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
@@ -59,9 +56,6 @@ class CosyVoiceModel:
         self.stream_scale_factor = 1
         assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than or equal to 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
-        for _ in range(trt_concurrent):
-            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
@@ -69,7 +63,6 @@ class CosyVoiceModel:
         self.mel_overlap_dict = {}
         self.flow_cache_dict = {}
         self.hift_cache_dict = {}
-        self.trt_context_dict = {}

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
@@ -89,7 +82,7 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
         if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
             convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
@@ -98,7 +91,7 @@ class CosyVoiceModel:
         with open(flow_decoder_estimator_model, 'rb') as f:
             estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
         assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
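+        # TrtContextWrapper pairs the engine with trt_concurrent execution contexts on this device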
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)

     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
@@ -108,7 +101,7 @@ class CosyVoiceModel:
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
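+        # enable autocast only on the native torch fp16 path; a vLLM-backed llm handles precision itself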
+        with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
                 assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                 for i in self.llm.inference_bistream(text=text,
@@ -125,7 +118,8 @@ class CosyVoiceModel:
                                             prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                             prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                            embedding=llm_embedding.to(self.device)):
+                                            embedding=llm_embedding.to(self.device),
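+                                            # forward the per-session uuid so the llm can track this request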
+                                            uuid=uuid):
                     self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True

@@ -180,13 +174,11 @@ class CosyVoiceModel:
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
             self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
             self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
-            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -240,8 +232,6 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
-            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
-            self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()
@@ -253,14 +243,12 @@ class CosyVoice2Model(CosyVoiceModel):
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False,
-                 trt_concurrent: int = 1):
+                 fp16: bool = False):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
@@ -273,22 +261,28 @@ class CosyVoice2Model(CosyVoiceModel):
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
-        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
-        for _ in range(trt_concurrent):
-            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
         self.lock = threading.Lock()
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
-        self.trt_context_dict = {}

     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

+    def load_vllm(self, model_dir):
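+        # export the torch LLM weights to a vLLM-compatible checkpoint in model_dir,
+        # then serve them from an in-process vLLM engine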
+        export_cosyvoice2_vllm(self.llm, model_dir, self.device)
+        from vllm import EngineArgs, LLMEngine
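+        # the llm is prompted with precomputed embeddings, so no tokenizer is needed;
+        # it is small enough that a modest share of GPU memory suffices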
+        engine_args = EngineArgs(model=model_dir,
+                                 skip_tokenizer_init=True,
+                                 enable_prompt_embeds=True,
+                                 gpu_memory_utilization=0.2)
+        self.llm.vllm = LLMEngine.from_engine_args(engine_args)
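+        # vLLM now holds its own copy of the decoder weights, so free the torch layers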
+        del self.llm.llm.model.model.layers
+
     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
-        with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
+        with torch.cuda.amp.autocast(self.fp16):
             tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                              token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                              prompt_token=prompt_token.to(self.device),
@@ -330,11 +324,9 @@ class CosyVoice2Model(CosyVoiceModel):
             prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        this_trt_context = self.trt_context_pool.get()
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
-            self.trt_context_dict[this_uuid] = this_trt_context
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         else:
@@ -388,8 +380,6 @@ class CosyVoice2Model(CosyVoiceModel):
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
-            self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()