@@ -56,14 +56,14 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
         self.llm.to(self.device).eval()
         if self.fp16 is True:
             self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device)}
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=False)
         self.hift.to(self.device).eval()
 
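For context on the two changes: `strict=False` lets `load_state_dict` accept a checkpoint whose keys do not exactly match the module, and `.items()` is needed because iterating a dict directly yields only its keys, so the `(k, v)` unpack in the comprehension would fail. Below is a minimal, self-contained sketch of the same pattern using plain `torch.nn` modules; the module and key names are illustrative, not the CosyVoice code itself.

    # Sketch of the checkpoint-loading pattern from the diff above (assumed example).
    import torch
    import torch.nn as nn

    generator = nn.Linear(4, 4)
    # Simulate a HiFi-GAN-style checkpoint whose keys carry a 'generator.' prefix.
    checkpoint = {'generator.' + k: v for k, v in generator.state_dict().items()}

    model = nn.Linear(4, 4)
    # .items() yields (key, value) pairs; iterating the dict itself would yield keys only.
    state_dict = {k.replace('generator.', ''): v for k, v in checkpoint.items()}
    # strict=False tolerates missing or unexpected keys instead of raising.
    model.load_state_dict(state_dict, strict=False)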