浏览代码

Merge pull request #670 from FunAudioLLM/dev/lyuxiang.lx

fix bug
Xiang Lyu 1 年之前
父节点
当前提交
07352a50b3
共有 2 个文件被更改,包括 7 次插入0 次删除
  1. 5 0
      cosyvoice/cli/cosyvoice.py
  2. 2 0
      cosyvoice/utils/common.py

+ 5 - 0
cosyvoice/cli/cosyvoice.py

@@ -16,6 +16,7 @@ import time
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
+import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
 from cosyvoice.utils.file_utils import logging
@@ -37,6 +38,10 @@ class CosyVoice:
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
                                           configs['allowed_special'])
+        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+            load_jit = False
+            fp16 = False
+            logging.warning('cpu do not support fp16 and jit, force set to False')
         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),

+ 2 - 0
cosyvoice/utils/common.py

@@ -141,6 +141,8 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
     device = fade_in_mel.device
     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
     mel_overlap_len = int(window.shape[0] / 2)
+    if fade_in_mel.device == torch.device('cpu'):
+        fade_in_mel = fade_in_mel.clone()
     fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)