lyuxiang.lx 1 年間 前
コミット
7e6d60c24c
2 ファイル変更4 行追加4 行削除
  1. 1 1
      cosyvoice/bin/inference.py
  2. 3 3
      cosyvoice/cli/model.py

+ 1 - 1
cosyvoice/bin/inference.py

@@ -99,7 +99,7 @@ def main():
                                'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                                'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
             tts_speeches = []
-            for model_output in model.inference(**model_input):
+            for model_output in model.tts(**model_input):
                 tts_speeches.append(model_output['tts_speech'])
             tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])

+ 3 - 3
cosyvoice/cli/model.py

@@ -56,14 +56,14 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
         self.llm.to(self.device).eval()
         if self.fp16 is True:
             self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device)}
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=False)
         self.hift.to(self.device).eval()