lyuxiang.lx il y a 10 mois
Parent
commit
2a3e033ee1

+ 1 - 1
cosyvoice/bin/average_model.py

@@ -75,7 +75,7 @@ def main():
         print('Processing {}'.format(path))
         states = torch.load(path, map_location=torch.device('cpu'))
         for k in states.keys():
-            if k not in avg.keys():
+            if k not in avg.keys() and k not in ['step', 'epoch']:
                 avg[k] = states[k].clone()
             else:
                 avg[k] += states[k]

+ 1 - 0
cosyvoice/bin/export_jit.py

@@ -98,5 +98,6 @@ def main():
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
         logging.info('successfully export flow_encoder')
 
+
 if __name__ == '__main__':
     main()

+ 19 - 17
cosyvoice/bin/export_onnx.py

@@ -99,7 +99,7 @@ def main():
         option.intra_op_num_threads = 1
         providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
         estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-                                                    sess_options=option, providers=providers)
+                                                      sess_options=option, providers=providers)
 
         for _ in tqdm(range(10)):
             x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
@@ -131,31 +131,33 @@ def main():
         torch.onnx.export(
             estimator,
             (x, mask, mu, t, spks, cond,
-            cache['down_blocks_conv_cache'],
-            cache['down_blocks_kv_cache'],
-            cache['mid_blocks_conv_cache'],
-            cache['mid_blocks_kv_cache'],
-            cache['up_blocks_conv_cache'],
-            cache['up_blocks_kv_cache'],
-            cache['final_blocks_conv_cache']),
+             cache['down_blocks_conv_cache'],
+             cache['down_blocks_kv_cache'],
+             cache['mid_blocks_conv_cache'],
+             cache['mid_blocks_kv_cache'],
+             cache['up_blocks_conv_cache'],
+             cache['up_blocks_kv_cache'],
+             cache['final_blocks_conv_cache']),
             '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
             export_params=True,
             opset_version=18,
             do_constant_folding=True,
-            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache', 'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
-            output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out', 'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
+            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
+                         'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
+            output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
+                          'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
             dynamic_axes={
                 'x': {2: 'seq_len'},
                 'mask': {2: 'seq_len'},
                 'mu': {2: 'seq_len'},
                 'cond': {2: 'seq_len'},
-                'down_blocks_kv_cache': {3: 'seq_len'},
-                'mid_blocks_kv_cache': {3: 'seq_len'},
-                'up_blocks_kv_cache': {3: 'seq_len'},
+                'down_blocks_kv_cache': {3: 'cache_in_len'},
+                'mid_blocks_kv_cache': {3: 'cache_in_len'},
+                'up_blocks_kv_cache': {3: 'cache_in_len'},
                 'estimator_out': {2: 'seq_len'},
-                'down_blocks_kv_cache_out': {3: 'seq_len'},
-                'mid_blocks_kv_cache_out': {3: 'seq_len'},
-                'up_blocks_kv_cache_out': {3: 'seq_len'},
+                'down_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'up_blocks_kv_cache_out': {3: 'cache_out_len'},
             }
         )
 
@@ -165,7 +167,7 @@ def main():
         option.intra_op_num_threads = 1
         providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
         estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-                                                    sess_options=option, providers=providers)
+                                                      sess_options=option, providers=providers)
 
         for _ in tqdm(range(10)):
             x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)

+ 6 - 6
cosyvoice/bin/export_trt.sh

@@ -7,19 +7,19 @@ MODEL_DIR=<YOUR_MODEL_DIR>
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
 
 # cosyvoice export
-$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
-$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
+$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x200,mask:2x1x200,mu:2x80x200,cond:2x80x200 --maxShapes=x:2x80x3000,mask:2x1x3000,mu:2x80x3000,cond:2x80x3000 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
+$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x200,mask:2x1x200,mu:2x80x200,cond:2x80x200 --maxShapes=x:2x80x3000,mask:2x1x3000,mu:2x80x3000,cond:2x80x3000 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw
 
 # cosyvoice2 export with cache
 $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan \
     --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4,down_blocks_kv_cache:1x4x2x0x512x2,mid_blocks_kv_cache:12x4x2x0x512x2,up_blocks_kv_cache:1x4x2x0x512x2 \
-    --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193,down_blocks_kv_cache:1x4x2x193x512x2,mid_blocks_kv_cache:12x4x2x193x512x2,up_blocks_kv_cache:1x4x2x193x512x2 \
-    --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800,down_blocks_kv_cache:1x4x2x200x512x2,mid_blocks_kv_cache:12x4x2x200x512x2,up_blocks_kv_cache:1x4x2x200x512x2 \
+    --optShapes=x:2x80x200,mask:2x1x200,mu:2x80x200,cond:2x80x200,down_blocks_kv_cache:1x4x2x100x512x2,mid_blocks_kv_cache:12x4x2x100x512x2,up_blocks_kv_cache:1x4x2x100x512x2 \
+    --maxShapes=x:2x80x1500,mask:2x1x1500,mu:2x80x1500,cond:2x80x1500,down_blocks_kv_cache:1x4x2x200x512x2,mid_blocks_kv_cache:12x4x2x200x512x2,up_blocks_kv_cache:1x4x2x200x512x2 \
     --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw \
     --outputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw
 $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 \
     --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4,down_blocks_kv_cache:1x4x2x0x512x2,mid_blocks_kv_cache:12x4x2x0x512x2,up_blocks_kv_cache:1x4x2x0x512x2 \
-    --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193,down_blocks_kv_cache:1x4x2x193x512x2,mid_blocks_kv_cache:12x4x2x193x512x2,up_blocks_kv_cache:1x4x2x193x512x2 \
-    --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800,down_blocks_kv_cache:1x4x2x200x512x2,mid_blocks_kv_cache:12x4x2x200x512x2,up_blocks_kv_cache:1x4x2x200x512x2 \
+    --optShapes=x:2x80x200,mask:2x1x200,mu:2x80x200,cond:2x80x200,down_blocks_kv_cache:1x4x2x100x512x2,mid_blocks_kv_cache:12x4x2x100x512x2,up_blocks_kv_cache:1x4x2x100x512x2 \
+    --maxShapes=x:2x80x1500,mask:2x1x1500,mu:2x80x1500,cond:2x80x1500,down_blocks_kv_cache:1x4x2x200x512x2,mid_blocks_kv_cache:12x4x2x200x512x2,up_blocks_kv_cache:1x4x2x200x512x2 \
     --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw \
     --outputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw

+ 2 - 1
cosyvoice/bin/inference.py

@@ -78,6 +78,7 @@ def main():
                            tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
 
+    sample_rate = configs['sample_rate']
     del configs
     os.makedirs(args.result_dir, exist_ok=True)
     fn = os.path.join(args.result_dir, 'wav.scp')
@@ -113,7 +114,7 @@ def main():
             tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-            torchaudio.save(tts_fn, tts_speeches, sample_rate=configs['sample_rate'], backend='soundfile')
+            torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
             f.write('{} {}\n'.format(tts_key, tts_fn))
             f.flush()
     f.close()

+ 43 - 28
cosyvoice/cli/model.py

@@ -36,8 +36,6 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.llm.fp16 = fp16
-        self.flow.fp16 = fp16
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
@@ -85,19 +83,25 @@ class CosyVoiceModel:
     def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
         if not os.path.exists(flow_decoder_estimator_model):
-            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
+            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
         if os.path.getsize(flow_decoder_estimator_model) == 0:
             raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
             self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
-        if self.flow.decoder.estimator_engine is None:
-            raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
+        assert self.flow.decoder.estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
         self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
 
+    def get_trt_kwargs(self):
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
+        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
+        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
+        input_names = ["x", "mask", "mu", "cond"]
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
-        with self.llm_context:
+        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
             if isinstance(text, Generator):
                 assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                 for i in self.llm.inference_bistream(text=text,
@@ -119,14 +123,15 @@ class CosyVoiceModel:
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
-                                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                                                  prompt_token=prompt_token.to(self.device),
-                                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                                  prompt_feat=prompt_feat.to(self.device),
-                                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                                                  embedding=embedding.to(self.device),
-                                                                  flow_cache=self.flow_cache_dict[uuid])
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
+                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      prompt_token=prompt_token.to(self.device),
+                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      prompt_feat=prompt_feat.to(self.device),
+                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      embedding=embedding.to(self.device),
+                                                                      flow_cache=self.flow_cache_dict[uuid])
 
         # mel overlap fade in out
         if self.mel_overlap_dict[uuid].shape[2] != 0:
@@ -289,21 +294,18 @@ class CosyVoice2Model(CosyVoiceModel):
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        self.llm.fp16 = fp16
-        self.flow.fp16 = fp16
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
-        self.token_hop_len = 2 * self.flow.input_frame_rate
+        self.token_hop_len = self.flow.encoder.static_chunk_size
         # flow decoder required_cache_size
-        self.flow_decoder_required_cache_size = self.flow.decoder.estimator.num_decoding_left_chunks * self.flow.input_frame_rate * self.flow.token_mel_ratio
+        self.flow_decoder_required_cache_size = self.flow.decoder.estimator.num_decoding_left_chunks * self.flow.decoder.estimator.static_chunk_size
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
         # speech fade in out
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
-        self.stream_scale_factor = 1
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
@@ -327,6 +329,11 @@ class CosyVoice2Model(CosyVoiceModel):
                          'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
                          'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
                          'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
+        if self.fp16 is True:
+            for cache in [encoder_cache, decoder_cache]:
+                for k, v in cache.items():
+                    if isinstance(v, torch.Tensor):
+                        cache[k] = v.half()
         cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
         return cache
 
@@ -341,16 +348,24 @@ class CosyVoice2Model(CosyVoiceModel):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
+    def get_trt_kwargs(self):
+        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
+        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
+        max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
+        input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache']
+        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
+
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
-        tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
-                                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                                                  prompt_token=prompt_token.to(self.device),
-                                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                                  prompt_feat=prompt_feat.to(self.device),
-                                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                                                  embedding=embedding.to(self.device),
-                                                                  cache=self.flow_cache_dict[uuid],
-                                                                  finalize=finalize)
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
+                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      prompt_token=prompt_token.to(self.device),
+                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      prompt_feat=prompt_feat.to(self.device),
+                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                                      embedding=embedding.to(self.device),
+                                                                      cache=self.flow_cache_dict[uuid],
+                                                                      finalize=finalize)
         self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid])
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:

+ 2 - 2
cosyvoice/dataset/processor.py

@@ -196,8 +196,8 @@ def compute_f0(data, sample_rate, hop_size, mode='train'):
         assert 'text_token' in sample
         waveform = sample['speech']
         _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
-        if sum(_f0 != 0) < 5: # this happens when the algorithm fails
-            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
+        if sum(_f0 != 0) < 5:  # this happens when the algorithm fails
+            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)  # if harvest fails, try dio
         f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
         f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
         sample['pitch_feat'] = f0

+ 33 - 23
cosyvoice/flow/decoder.py

@@ -57,7 +57,7 @@ class CausalConv1d(torch.nn.Conv1d):
         assert stride == 1
         self.causal_padding = kernel_size - 1
 
-    def forward(self, x: torch.Tensor, cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
         if cache.size(2) == 0:
             x = F.pad(x, (self.causal_padding, 0), value=0.0)
         else:
@@ -79,7 +79,7 @@ class CausalBlock1D(Block1D):
             nn.Mish(),
         )
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
         output, cache = self.block[0](x * mask, cache)
         for i in range(1, len(self.block)):
             output = self.block[i](output)
@@ -92,7 +92,9 @@ class CausalResnetBlock1D(ResnetBlock1D):
         self.block1 = CausalBlock1D(dim, dim_out)
         self.block2 = CausalBlock1D(dim_out, dim_out)
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor, block1_cache: torch.Tensor=torch.zeros(0, 0, 0), block2_cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
+                block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
+                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         h, block1_cache = self.block1(x, mask, block1_cache)
         h += self.mlp(time_emb).unsqueeze(-1)
         h, block2_cache = self.block2(h, mask, block2_cache)
@@ -120,7 +122,8 @@ class CausalAttnProcessor2_0(AttnProcessor2_0):
         **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
+                `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
 
         residual = hidden_states
@@ -224,8 +227,10 @@ class CausalAttention(Attention):
         processor: Optional["AttnProcessor2_0"] = None,
         out_dim: int = None,
     ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, qk_norm,
-                                              added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim)
+        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
+                                              cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
+                                              spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
+                                              _from_deprecated_attn_block, processor, out_dim)
         processor = CausalAttnProcessor2_0()
         self.set_processor(processor)
 
@@ -294,8 +299,10 @@ class CausalBasicTransformerBlock(BasicTransformerBlock):
         norm_type: str = "layer_norm",
         final_dropout: bool = False,
     ):
-        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout, cross_attention_dim, activation_fn, num_embeds_ada_norm,
-                                                          attention_bias, only_cross_attention, double_self_attention, upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
+        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
+                                                          cross_attention_dim, activation_fn, num_embeds_ada_norm,
+                                                          attention_bias, only_cross_attention, double_self_attention,
+                                                          upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
         self.attn1 = CausalAttention(
             query_dim=dim,
             heads=num_attention_heads,
@@ -364,9 +371,8 @@ class CausalBasicTransformerBlock(BasicTransformerBlock):
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
             if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
-                raise ValueError(
-                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
-                )
+                raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
+                                 {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
 
             num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
             ff_output = torch.cat(
@@ -794,14 +800,14 @@ class CausalConditionalDecoder(ConditionalDecoder):
         return output * mask
 
     def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
-        down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-        mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-        up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-        final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+                      down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+                      down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
+                      mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+                      mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
+                      up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+                      up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
+                      final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
+                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward pass of the UNet1DConditional model.
 
         Args:
@@ -838,7 +844,8 @@ class CausalConditionalDecoder(ConditionalDecoder):
         up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
         for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
             mask_down = masks[-1]
-            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
+            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
+                resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
             x = rearrange(x, "b c t -> b t c").contiguous()
             attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
             attn_mask = mask_to_bias(attn_mask, x.dtype)
@@ -857,7 +864,8 @@ class CausalConditionalDecoder(ConditionalDecoder):
         mask_mid = masks[-1]
 
         for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
-            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
+            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
+                resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
             x = rearrange(x, "b c t -> b t c").contiguous()
             attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
             attn_mask = mask_to_bias(attn_mask, x.dtype)
@@ -874,7 +882,8 @@ class CausalConditionalDecoder(ConditionalDecoder):
             mask_up = masks.pop()
             skip = hiddens.pop()
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
-            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
+            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
+                resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
             x = rearrange(x, "b c t -> b t c").contiguous()
             attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
             attn_mask = mask_to_bias(attn_mask, x.dtype)
@@ -889,4 +898,5 @@ class CausalConditionalDecoder(ConditionalDecoder):
             x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
         x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
         output = self.final_proj(x * mask_up)
-        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache
+        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
+            up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache

+ 1 - 9
cosyvoice/flow/flow.py

@@ -112,10 +112,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat_len,
                   embedding,
                   flow_cache):
-        if self.fp16 is True:
-            prompt_feat = prompt_feat.half()
-            embedding = embedding.half()
-
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
@@ -146,7 +142,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10,
             prompt_len=mel_len1,
-            flow_cache=flow_cache
+            cache=flow_cache
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
@@ -249,10 +245,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   embedding,
                   cache,
                   finalize):
-        if self.fp16 is True:
-            prompt_feat = prompt_feat.half()
-            embedding = embedding.half()
-
         assert token.shape[0] == 1
         # xvec projection
         embedding = F.normalize(embedding, dim=1)

+ 48 - 17
cosyvoice/flow/flow_matching.py

@@ -133,13 +133,13 @@ class ConditionalCFM(BASECFM):
                 self.estimator.set_input_shape('spks', (2, 80))
                 self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
                 # run trt engine
-                self.estimator.execute_v2([x.contiguous().data_ptr(),
-                                           mask.contiguous().data_ptr(),
-                                           mu.contiguous().data_ptr(),
-                                           t.contiguous().data_ptr(),
-                                           spks.contiguous().data_ptr(),
-                                           cond.contiguous().data_ptr(),
-                                           x.data_ptr()])
+                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                                  mask.contiguous().data_ptr(),
+                                                  mu.contiguous().data_ptr(),
+                                                  t.contiguous().data_ptr(),
+                                                  spks.contiguous().data_ptr(),
+                                                  cond.contiguous().data_ptr(),
+                                                  x.data_ptr()]) is True
             return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
@@ -244,9 +244,9 @@ class CausalConditionalCFM(ConditionalCFM):
         sol = []
 
         # estimator cache for each step
-        down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device)
-        mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x.device)
-        up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x.device)
+        down_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
+        mid_blocks_kv_cache_new = torch.zeros(10, 12, 4, 2, x.size(2), 512, 2).to(x)
+        up_blocks_kv_cache_new = torch.zeros(10, 1, 4, 2, x.size(2), 512, 2).to(x)
 
         # Do not use concat, it may cause memory format changed and trt infer with wrong results!
         x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
@@ -302,12 +302,43 @@ class CausalConditionalCFM(ConditionalCFM):
                 self.estimator.set_input_shape('t', (2,))
                 self.estimator.set_input_shape('spks', (2, 80))
                 self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
+                self.estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
+                self.estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
+                self.estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
+                self.estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
+                self.estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
+                self.estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
                 # run trt engine
-                self.estimator.execute_v2([x.contiguous().data_ptr(),
-                                           mask.contiguous().data_ptr(),
-                                           mu.contiguous().data_ptr(),
-                                           t.contiguous().data_ptr(),
-                                           spks.contiguous().data_ptr(),
-                                           cond.contiguous().data_ptr(),
-                                           x.data_ptr()])
+                down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
+                mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
+                up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
+                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                                  mask.contiguous().data_ptr(),
+                                                  mu.contiguous().data_ptr(),
+                                                  t.contiguous().data_ptr(),
+                                                  spks.contiguous().data_ptr(),
+                                                  cond.contiguous().data_ptr(),
+                                                  cache['down_blocks_conv_cache'].contiguous().data_ptr(),
+                                                  cache['down_blocks_kv_cache'].contiguous().data_ptr(),
+                                                  cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
+                                                  cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
+                                                  cache['up_blocks_conv_cache'].contiguous().data_ptr(),
+                                                  cache['up_blocks_kv_cache'].contiguous().data_ptr(),
+                                                  cache['final_blocks_conv_cache'].contiguous().data_ptr(),
+                                                  x.data_ptr(),
+                                                  cache['down_blocks_conv_cache'].data_ptr(),
+                                                  down_blocks_kv_cache_out.data_ptr(),
+                                                  cache['mid_blocks_conv_cache'].data_ptr(),
+                                                  mid_blocks_kv_cache_out.data_ptr(),
+                                                  cache['up_blocks_conv_cache'].data_ptr(),
+                                                  up_blocks_kv_cache_out.data_ptr(),
+                                                  cache['final_blocks_conv_cache'].data_ptr()]) is True
+                cache = (cache['down_blocks_conv_cache'],
+                         down_blocks_kv_cache_out,
+                         cache['mid_blocks_conv_cache'],
+                         mid_blocks_kv_cache_out,
+                         cache['up_blocks_conv_cache'],
+                         up_blocks_kv_cache_out,
+                         cache['final_blocks_conv_cache'])
         return x, cache

+ 2 - 5
cosyvoice/llm/llm.py

@@ -169,9 +169,6 @@ class TransformerLM(torch.nn.Module):
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
     ) -> Generator[torch.Tensor, None, None]:
-        if self.fp16 is True:
-            embedding = embedding.half()
-
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
@@ -393,8 +390,8 @@ class Qwen2LM(TransformerLM):
                 while True:
                     seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
                     y_pred, cache = self.llm.forward_one_step(lm_input,
-                                                masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
-                                                cache=cache)
+                                                              masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
+                                                              cache=cache)
                     logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                     if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                         top_ids = self.speech_token_size + 2

+ 11 - 11
cosyvoice/transformer/upsample_encoder.py

@@ -56,7 +56,7 @@ class Upsample1D(nn.Module):
         # In this mode, first repeat interpolate, than conv with stride=1
         self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
 
-    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor=torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
         if conv_cache.size(2) == 0:
             outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
@@ -287,11 +287,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                            self.use_dynamic_chunk if streaming is True else False,
-                                            self.use_dynamic_left_chunk if streaming is True else False,
-                                            decoding_chunk_size if streaming is True else 0,
-                                            self.static_chunk_size if streaming is True else 0,
-                                            num_decoding_left_chunks if streaming is True else -1)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         # lookahead + conformer encoder
         xs, _ = self.pre_lookahead_layer(xs)
         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
@@ -305,11 +305,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.up_embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                            self.use_dynamic_chunk if streaming is True else False,
-                                            self.use_dynamic_left_chunk if streaming is True else False,
-                                            decoding_chunk_size if streaming is True else 0,
-                                            self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
-                                            num_decoding_left_chunks if streaming is True else -1)
+                                              self.use_dynamic_chunk if streaming is True else False,
+                                              self.use_dynamic_left_chunk if streaming is True else False,
+                                              decoding_chunk_size if streaming is True else 0,
+                                              self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
+                                              num_decoding_left_chunks if streaming is True else -1)
         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
 
         if self.normalize_before:

+ 4 - 8
cosyvoice/utils/file_utils.py

@@ -47,13 +47,8 @@ def load_wav(wav, target_sr):
     return speech
 
 
-def convert_onnx_to_trt(trt_model, onnx_model, fp16):
+def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
     import tensorrt as trt
-    _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
-    _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
-    _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
-    input_names = ["x", "mask", "mu", "t", "spks", "cond"]
-
     logging.info("Converting onnx to trt...")
     network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
     logger = trt.Logger(trt.Logger.INFO)
@@ -72,8 +67,8 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
                 print(parser.get_error(error))
             raise ValueError('failed to parse {}'.format(onnx_model))
     # set input shapes
-    for i in range(len(input_names)):
-        profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
+    for i in range(len(trt_kwargs['input_names'])):
+        profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
     tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
     # set input and output data type
     for i in range(network.num_inputs):
@@ -87,3 +82,4 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
     # save trt engine
     with open(trt_model, "wb") as f:
         f.write(engine_bytes)
+    logging.info("Succesfully convert onnx to trt...")

+ 2 - 3
cosyvoice/utils/mask.py

@@ -15,7 +15,6 @@
 # limitations under the License.
 
 import torch
-from cosyvoice.utils.file_utils import logging
 '''
 def subsequent_mask(
         size: int,
@@ -198,8 +197,8 @@ def add_optional_chunk_mask(xs: torch.Tensor,
         chunk_masks = masks
     assert chunk_masks.dtype == torch.bool
     if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
-        logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
-        chunk_masks[chunk_masks.sum(dim=-1)==0] = True
+        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
+        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
     return chunk_masks
 
 

+ 4 - 0
cosyvoice/utils/train_utils.py

@@ -286,11 +286,15 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
             # optimizer.step().
             if torch.isfinite(grad_norm):
                 scaler.step(optimizer)
+            else:
+                logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
             scaler.update()
         else:
             grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
             if torch.isfinite(grad_norm):
                 optimizer.step()
+            else:
+                logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
         optimizer.zero_grad()
         scheduler.step()
     info_dict["lr"] = optimizer.param_groups[0]['lr']

+ 7 - 3
examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

@@ -13,6 +13,10 @@ qwen_pretrain_path: ''
 token_frame_rate: 25
 token_mel_ratio: 2
 
+# stream related params
+chunk_size: 1 # streaming inference chunk size, in second
+num_decoding_left_chunks: 2 # streaming inference flow decoder left chunk size, in second
+
 # model params
 # for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
 # for system/third_party class/function, we do not require this.
@@ -56,7 +60,7 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
         input_size: 512
         use_cnn_module: False
         macaron_style: False
-        static_chunk_size: !ref <token_frame_rate> # 试试UpsampleConformerEncoder也是static
+        static_chunk_size: !ref <chunk_size> * <token_frame_rate>
     decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
         in_channels: 240
         n_spks: 1
@@ -79,8 +83,8 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
             num_mid_blocks: 12
             num_heads: 8
             act_fn: 'gelu'
-            static_chunk_size: !ref <token_frame_rate> * <token_mel_ratio> # here we use static_chunk_size because we want to fix kv cache size during inference
-            num_decoding_left_chunks: 2
+            static_chunk_size: !ref <chunk_size> * <token_frame_rate> * <token_mel_ratio> # here we use static_chunk_size because we want to fix kv cache size during inference
+            num_decoding_left_chunks: !ref <num_decoding_left_chunks>
 
 hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
     in_channels: 80

+ 1 - 1
requirements.txt

@@ -13,7 +13,7 @@ inflect==7.3.1
 librosa==0.10.2
 lightning==2.2.4
 matplotlib==3.7.5
-modelscope==1.15.0
+modelscope==1.20.0
 networkx==3.1
 omegaconf==2.3.0
 onnx==1.16.0