|
@@ -63,12 +63,12 @@ class CosyVoiceModel:
|
|
|
self.silent_tokens = []
|
|
self.silent_tokens = []
|
|
|
|
|
|
|
|
def load(self, llm_model, flow_model, hift_model):
|
|
def load(self, llm_model, flow_model, hift_model):
|
|
|
- self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
|
|
|
|
|
|
|
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
|
|
|
self.llm.to(self.device).eval()
|
|
self.llm.to(self.device).eval()
|
|
|
- self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
|
|
|
|
|
|
|
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
|
|
|
self.flow.to(self.device).eval()
|
|
self.flow.to(self.device).eval()
|
|
|
# in case hift_model is a hifigan model
|
|
# in case hift_model is a hifigan model
|
|
|
- hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
|
|
|
|
|
|
|
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
|
|
|
self.hift.load_state_dict(hift_state_dict, strict=True)
|
|
self.hift.load_state_dict(hift_state_dict, strict=True)
|
|
|
self.hift.to(self.device).eval()
|
|
self.hift.to(self.device).eval()
|
|
|
|
|
|