@@ -33,6 +33,8 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
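+        # propagate the fp16 flag to the llm and flow modules (llm_job below no longer casts the embedding itself)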
+        self.llm.fp16 = fp16
+        self.flow.fp16 = fp16
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
@@ -61,17 +63,17 @@ class CosyVoiceModel:
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()
-        if self.fp16 is True:
-            self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
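+        # with fp16 enabled, cast both llm and flow to half precision once their weights are loaded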
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()

     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
-        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
         self.llm.text_encoder = llm_text_encoder
         llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
@@ -79,18 +81,16 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_onnx(self, flow_decoder_estimator_model):
-        import onnxruntime
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    def load_trt(self, flow_decoder_estimator_model):
         del self.flow.decoder.estimator
-        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+        import tensorrt as trt
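+        # deserialize a pre-built TensorRT engine for the flow decoder estimator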
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        if self.flow.decoder.estimator_engine is None:
+            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
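+        # an execution context created from the engine replaces the estimator deleted above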
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()

     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        if self.fp16 is True:
-            llm_embedding = llm_embedding.half()
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
                                         text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
@@ -259,16 +259,20 @@ class CosyVoiceModel:
            self.hift_cache_dict.pop(this_uuid)


-class CosyVoice2Model:
+class CosyVoice2Model(CosyVoiceModel):
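+    # load(), load_trt() and llm_job() are now inherited from CosyVoiceModel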

     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
-                 hift: torch.nn.Module):
+                 hift: torch.nn.Module,
+                 fp16: bool):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.fp16 = fp16
+        self.llm.fp16 = fp16
+        self.flow.fp16 = fp16
         self.token_hop_len = 2 * self.flow.input_frame_rate
         # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
         self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
@@ -287,52 +291,10 @@ class CosyVoice2Model:
         self.llm_end_dict = {}
         self.hift_cache_dict = {}

-    def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
-        self.llm.to(self.device).eval()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
-        self.flow.to(self.device).eval()
-        self.flow.decoder.fp16 = False
-        # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
-        self.hift.load_state_dict(hift_state_dict, strict=True)
-        self.hift.to(self.device).eval()
-
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def load_onnx(self, flow_decoder_estimator_model):
-        import onnxruntime
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-        del self.flow.decoder.estimator
-        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
-
-    def load_trt(self, flow_decoder_estimator_model):
-        del self.flow.decoder.estimator
-        import tensorrt as trt
-        with open(flow_decoder_estimator_model, 'rb') as f:
-            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        if self.flow.decoder.estimator_engine is None:
-            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
-        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
-        self.flow.decoder.fp16 = True
-
-    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context:
-            for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device)):
-                self.tts_speech_token_dict[uuid].append(i)
-        self.llm_end_dict[uuid] = True
-
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),