Ver Fonte

load jit to device

lyuxiang.lx há 1 ano atrás
pai
commit
ba3d9693da
1 ficheiros alterados com 3 adições e 3 exclusões
  1. 3 3
      cosyvoice/cli/model.py

+ 3 - 3
cosyvoice/cli/model.py

@@ -63,11 +63,11 @@ class CosyVoiceModel:
         self.hift.to(self.device).eval()
 
     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
-        llm_text_encoder = torch.jit.load(llm_text_encoder_model)
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
         self.llm.text_encoder = llm_text_encoder
-        llm_llm = torch.jit.load(llm_llm_model)
+        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
         self.llm.llm = llm_llm
-        flow_encoder = torch.jit.load(flow_encoder_model)
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
     def load_onnx(self, flow_decoder_estimator_model):