Przeglądaj źródła

Merge pull request #821 from FunAudioLLM/dev/lyuxiang.lx

update
Xiang Lyu 11 miesięcy temu
rodzic
commit
9e156428e2

+ 2 - 2
README.md

@@ -132,7 +132,7 @@ import torchaudio
 
 **CosyVoice2 Usage**
 ```python
-cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False)
+cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
 
 # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
 # zero_shot usage
@@ -151,7 +151,7 @@ for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来
 
 **CosyVoice Usage**
 ```python
-cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
+cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
 # sft usage
 print(cosyvoice.list_available_spks())
 # change stream=True for chunk stream inference

+ 35 - 18
cosyvoice/bin/export_jit.py

@@ -23,7 +23,7 @@ import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 
 
 def get_args():
@@ -37,6 +37,16 @@ def get_args():
     return args
 
 
+def get_optimized_script(model, preserved_attrs=[]):
+    script = torch.jit.script(model)
+    if preserved_attrs != []:
+        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
+    else:
+        script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    return script
+
+
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -46,28 +56,35 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+    try:
+        model = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            model = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
 
-    # 1. export llm text_encoder
-    llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
-    script = torch.jit.script(llm_text_encoder)
-    script = torch.jit.freeze(script)
-    script = torch.jit.optimize_for_inference(script)
-    script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+    if not isinstance(model, CosyVoice2):
+        # 1. export llm text_encoder
+        llm_text_encoder = model.model.llm.text_encoder
+        script = get_optimized_script(llm_text_encoder)
+        script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_text_encoder.half())
+        script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
 
-    # 2. export llm llm
-    llm_llm = cosyvoice.model.llm.llm.half()
-    script = torch.jit.script(llm_llm)
-    script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
-    script = torch.jit.optimize_for_inference(script)
-    script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+        # 2. export llm llm
+        llm_llm = model.model.llm.llm
+        script = get_optimized_script(llm_llm, ['forward_chunk'])
+        script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
+        script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
 
     # 3. export flow encoder
-    flow_encoder = cosyvoice.model.flow.encoder
-    script = torch.jit.script(flow_encoder)
-    script = torch.jit.freeze(script)
-    script = torch.jit.optimize_for_inference(script)
+    flow_encoder = model.model.flow.encoder
+    script = get_optimized_script(flow_encoder)
     script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+    script = get_optimized_script(flow_encoder.half())
+    script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
 
 
 if __name__ == '__main__':

+ 18 - 14
cosyvoice/bin/export_onnx.py

@@ -27,7 +27,7 @@ from tqdm import tqdm
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 
 
 def get_dummy_input(batch_size, seq_len, out_channels, device):
@@ -56,14 +56,20 @@ def main():
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+    try:
+        model = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            model = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
 
     # 1. export flow decoder estimator
-    estimator = cosyvoice.model.flow.decoder.estimator
+    estimator = model.model.flow.decoder.estimator
 
-    device = cosyvoice.model.device
-    batch_size, seq_len = 1, 256
-    out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
+    device = model.model.device
+    batch_size, seq_len = 2, 256
+    out_channels = model.model.flow.decoder.estimator.out_channels
     x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
     torch.onnx.export(
         estimator,
@@ -75,13 +81,11 @@ def main():
         input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
         output_names=['estimator_out'],
         dynamic_axes={
-            'x': {0: 'batch_size', 2: 'seq_len'},
-            'mask': {0: 'batch_size', 2: 'seq_len'},
-            'mu': {0: 'batch_size', 2: 'seq_len'},
-            'cond': {0: 'batch_size', 2: 'seq_len'},
-            't': {0: 'batch_size'},
-            'spks': {0: 'batch_size'},
-            'estimator_out': {0: 'batch_size', 2: 'seq_len'},
+            'x': {2: 'seq_len'},
+            'mask': {2: 'seq_len'},
+            'mu': {2: 'seq_len'},
+            'cond': {2: 'seq_len'},
+            'estimator_out': {2: 'seq_len'},
         }
     )
 
@@ -94,7 +98,7 @@ def main():
                                                   sess_options=option, providers=providers)
 
     for _ in tqdm(range(10)):
-        x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
         output_pytorch = estimator(x, mask, mu, t, spks, cond)
         ort_inputs = {
             'x': x.cpu().numpy(),

+ 1 - 0
cosyvoice/bin/export_trt.sh

@@ -6,4 +6,5 @@ TRT_DIR=<YOUR_TRT_DIR>
 MODEL_DIR=<COSYVOICE2_MODEL_DIR>
 
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
+$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
 $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw

+ 19 - 23
cosyvoice/cli/cosyvoice.py

@@ -25,14 +25,15 @@ from cosyvoice.utils.class_utils import get_model_type
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
+        self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
             configs = load_hyperpyyaml(f)
-        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
+        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
@@ -40,20 +41,19 @@ class CosyVoice:
                                           '{}/spk2info.pt'.format(model_dir),
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
-            load_jit = False
-            fp16 = False
-            logging.warning('cpu do not support fp16 and jit, force set to False')
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
         self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-                                '{}/llm.llm.fp16.zip'.format(model_dir),
-                                '{}/flow.encoder.fp32.zip'.format(model_dir))
-        if load_onnx:
-            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         del configs
 
     def list_available_spks(self):
@@ -123,9 +123,10 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
+        self.fp16 = fp16
         if not os.path.exists(model_dir):
             model_dir = snapshot_download(model_dir)
         with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
@@ -138,22 +139,17 @@ class CosyVoice2(CosyVoice):
                                           '{}/spk2info.pt'.format(model_dir),
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and load_jit is True:
-            load_jit = False
-            logging.warning('cpu do not support jit, force set to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+            load_jit, load_trt, fp16 = False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
-        if load_trt is True and load_onnx is True:
-            load_onnx = False
-            logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
-        if load_onnx:
-            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+            self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
+            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         del configs
 
     def inference_instruct(self, *args, **kwargs):

+ 18 - 56
cosyvoice/cli/model.py

@@ -33,6 +33,8 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
+        self.llm.fp16 = fp16
+        self.flow.fp16 = fp16
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
@@ -61,17 +63,17 @@ class CosyVoiceModel:
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()
-        if self.fp16 is True:
-            self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
 
     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
-        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
         self.llm.text_encoder = llm_text_encoder
         llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
@@ -79,18 +81,16 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def load_onnx(self, flow_decoder_estimator_model):
-        import onnxruntime
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    def load_trt(self, flow_decoder_estimator_model):
         del self.flow.decoder.estimator
-        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        if self.flow.decoder.estimator_engine is None:
+            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        if self.fp16 is True:
-            llm_embedding = llm_embedding.half()
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
                                         text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
@@ -259,16 +259,20 @@ class CosyVoiceModel:
             self.hift_cache_dict.pop(this_uuid)
 
 
-class CosyVoice2Model:
+class CosyVoice2Model(CosyVoiceModel):
 
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
-                 hift: torch.nn.Module):
+                 hift: torch.nn.Module,
+                 fp16: bool):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.fp16 = fp16
+        self.llm.fp16 = fp16
+        self.flow.fp16 = fp16
         self.token_hop_len = 2 * self.flow.input_frame_rate
         # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
         self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
@@ -287,52 +291,10 @@ class CosyVoice2Model:
         self.llm_end_dict = {}
         self.hift_cache_dict = {}
 
-    def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
-        self.llm.to(self.device).eval()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
-        self.flow.to(self.device).eval()
-        self.flow.decoder.fp16 = False
-        # in case hift_model is a hifigan model
-        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
-        self.hift.load_state_dict(hift_state_dict, strict=True)
-        self.hift.to(self.device).eval()
-
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def load_onnx(self, flow_decoder_estimator_model):
-        import onnxruntime
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-        del self.flow.decoder.estimator
-        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
-
-    def load_trt(self, flow_decoder_estimator_model):
-        del self.flow.decoder.estimator
-        import tensorrt as trt
-        with open(flow_decoder_estimator_model, 'rb') as f:
-            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        if self.flow.decoder.estimator_engine is None:
-            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
-        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
-        self.flow.decoder.fp16 = True
-
-    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context:
-            for i in self.llm.inference(text=text.to(self.device),
-                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_text=prompt_text.to(self.device),
-                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-                                        embedding=llm_embedding.to(self.device)):
-                self.tts_speech_token_dict[uuid].append(i)
-        self.llm_end_dict[uuid] = True
-
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
         tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),

+ 12 - 4
cosyvoice/flow/flow.py

@@ -111,6 +111,10 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat_len,
                   embedding,
                   flow_cache):
+        if self.fp16 is True:
+            prompt_feat = prompt_feat.half()
+            embedding = embedding.half()
+
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -129,7 +133,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
 
         # get conditions
-        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
@@ -145,7 +149,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat, flow_cache
+        return feat.float(), flow_cache
 
 
 class CausalMaskedDiffWithXvec(torch.nn.Module):
@@ -196,6 +200,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat_len,
                   embedding,
                   finalize):
+        if self.fp16 is True:
+            prompt_feat = prompt_feat.half()
+            embedding = embedding.half()
+
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -214,7 +222,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         h = self.encoder_proj(h)
 
         # get conditions
-        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
         conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
@@ -228,4 +236,4 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat, None
+        return feat.float(), None

+ 17 - 38
cosyvoice/flow/flow_matching.py

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import onnxruntime
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -52,7 +51,7 @@ class ConditionalCFM(BASECFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
 
-        z = torch.randn_like(mu) * temperature
+        z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
         cache_size = flow_cache.shape[2]
         # fix prompt and overlap part mu and z
         if cache_size != 0:
@@ -89,36 +88,29 @@ class ConditionalCFM(BASECFM):
         # Or in future might add like a return_all_steps flag
         sol = []
 
-        if self.inference_cfg_rate > 0:
-            # Do not use concat, it may cause memory format changed and trt infer with wrong results!
-            x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-            mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
-            mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-            t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
-            spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
-            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        else:
-            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
-            if self.inference_cfg_rate > 0:
-                x_in[:] = x
-                mask_in[:] = mask
-                mu_in[0] = mu
-                t_in[:] = t.unsqueeze(0)
-                spks_in[0] = spks
-                cond_in[0] = cond
-            else:
-                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+            x_in[:] = x
+            mask_in[:] = mask
+            mu_in[0] = mu
+            t_in[:] = t.unsqueeze(0)
+            spks_in[0] = spks
+            cond_in[0] = cond
             dphi_dt = self.forward_estimator(
                 x_in, mask_in,
                 mu_in, t_in,
                 spks_in,
                 cond_in
             )
-            if self.inference_cfg_rate > 0:
-                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
-                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)
@@ -130,17 +122,6 @@ class ConditionalCFM(BASECFM):
     def forward_estimator(self, x, mask, mu, t, spks, cond):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator.forward(x, mask, mu, t, spks, cond)
-        elif isinstance(self.estimator, onnxruntime.InferenceSession):
-            ort_inputs = {
-                'x': x.cpu().numpy(),
-                'mask': mask.cpu().numpy(),
-                'mu': mu.cpu().numpy(),
-                't': t.cpu().numpy(),
-                'spks': spks.cpu().numpy(),
-                'cond': cond.cpu().numpy()
-            }
-            output = self.estimator.run(None, ort_inputs)[0]
-            return torch.tensor(output, dtype=x.dtype, device=x.device)
         else:
             self.estimator.set_input_shape('x', (2, 80, x.size(2)))
             self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
@@ -225,9 +206,7 @@ class CausalConditionalCFM(ConditionalCFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
 
-        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
-        if self.fp16 is True:
-            z = z.half()
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
         # fix prompt and overlap part mu and z
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':

+ 5 - 2
cosyvoice/llm/llm.py

@@ -164,6 +164,9 @@ class TransformerLM(torch.nn.Module):
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
     ) -> Generator[torch.Tensor, None, None]:
+        if self.fp16 is True:
+            embedding = embedding.half()
+
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
@@ -178,7 +181,7 @@ class TransformerLM(torch.nn.Module):
             embedding = self.spk_embed_affine_layer(embedding)
             embedding = embedding.unsqueeze(dim=1)
         else:
-            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
 
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
@@ -310,7 +313,7 @@ class Qwen2LM(torch.nn.Module):
         text = self.llm.model.model.embed_tokens(text)
 
         # 2. encode embedding
-        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+        embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
 
         # 3. concat llm_input
         sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)

+ 1 - 0
cosyvoice/utils/class_utils.py

@@ -75,6 +75,7 @@ COSYVOICE_ATTENTION_CLASSES = {
 
 
 def get_model_type(configs):
+    # NOTE CosyVoice2Model inherits CosyVoiceModel
     if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
         return CosyVoiceModel
     if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):

+ 36 - 1
cosyvoice/utils/mask.py

@@ -86,7 +86,7 @@ def subsequent_mask(
     return mask
 
 
-def subsequent_chunk_mask(
+def subsequent_chunk_mask_deprecated(
         size: int,
         chunk_size: int,
         num_left_chunks: int = -1,
@@ -124,6 +124,41 @@ def subsequent_chunk_mask(
     return ret
 
 
+def subsequent_chunk_mask(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+       this is for streaming encoder
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
+    # actually this is not needed after we have inference cache implemented, will remove it later
+    pos_idx = torch.arange(size, device=device)
+    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
+    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+    return ret
+
+
 def add_optional_chunk_mask(xs: torch.Tensor,
                             masks: torch.Tensor,
                             use_dynamic_chunk: bool,