@@ -47,11 +47,18 @@ class CosyVoiceModel:
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
         self.llm.to(self.device).eval()
+        self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         self.flow.to(self.device).eval()
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()

+    def load_script(self, llm_text_encoder_model, llm_llm_model):
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model)
+        self.llm.text_encoder = llm_text_encoder
+        llm_llm = torch.jit.load(llm_llm_model)
+        self.llm.llm = llm_llm
+
     def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
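This hunk does two things: `load` now casts the LLM to fp16 after loading its weights, and the new `load_script` method swaps the eager-mode `text_encoder` and `llm` submodules for TorchScript versions. A hedged sketch of the surrounding workflow follows; the export step, file names, and constructor arguments are assumptions for illustration, not part of this patch, and presume both submodules are scriptable:

    import torch
    from cosyvoice.cli.model import CosyVoiceModel  # the module this patch edits

    # Assumed offline export step: script the two LLM submodules once so
    # load_script() can later replace the eager modules with compiled ones.
    model = CosyVoiceModel(llm, flow, hift)  # hypothetical construction
    model.load('llm.pt', 'flow.pt', 'hift.pt')
    torch.jit.script(model.llm.text_encoder).save('llm.text_encoder.fp16.zip')
    torch.jit.script(model.llm.llm).save('llm.llm.fp16.zip')

    # At serving time, load the scripted modules in place of the originals.
    model.load_script('llm.text_encoder.fp16.zip', 'llm.llm.fp16.zip')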
@@ -60,7 +67,7 @@ class CosyVoiceModel:
                                         prompt_text_len=prompt_text_len.to(self.device),
                                         prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                         prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
-                                        embedding=llm_embedding.to(self.device),
+                                        embedding=llm_embedding.to(self.device).half(),
                                         sampling=25,
                                         max_token_text_ratio=30,
                                         min_token_text_ratio=3):
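The second hunk casts the speaker embedding to fp16 to match the `self.llm.half()` call in `load`: PyTorch will not mix fp32 inputs with fp16 weights in a matmul, so every tensor fed to the LLM must share its dtype. A minimal self-contained illustration of that rule, where a `Linear` layer stands in for the fp16 LLM and the sizes are arbitrary:

    import torch

    if torch.cuda.is_available():  # fp16 matmul is reliably supported on CUDA
        layer = torch.nn.Linear(192, 192).cuda().half()  # stand-in for the fp16 LLM
        emb = torch.randn(1, 192, device='cuda')         # fp32 speaker embedding
        try:
            layer(emb)               # fails: fp32 input vs. fp16 weights
        except RuntimeError as err:
            print('fp32 input rejected:', err)
        out = layer(emb.half())      # the patch's fix: cast the input first
        print(out.dtype)             # torch.float16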