浏览代码

remove flow_cache

lyuxiang.lx 11 月之前
父节点
当前提交
68100c267a

+ 1 - 1
README.md

@@ -126,7 +126,7 @@ import torchaudio
 
 **CosyVoice2 Usage**
 ```python
-cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
+cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
 
 # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
 # zero_shot usage

+ 3 - 4
cosyvoice/bin/export_jit.py

@@ -61,8 +61,7 @@ def main():
         model = CosyVoice(args.model_dir)
     except Exception:
         try:
-            # NOTE set use_flow_cache=True when export jit for cache inference
-            model = CosyVoice2(args.model_dir, use_flow_cache=True)
+            model = CosyVoice2(args.model_dir)
         except Exception:
             raise TypeError('no valid model_type!')
 
@@ -93,9 +92,9 @@ def main():
     else:
         # 3. export flow encoder
         flow_encoder = model.model.flow.encoder
-        script = get_optimized_script(flow_encoder, ['forward_chunk'])
+        script = get_optimized_script(flow_encoder)
         script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
-        script = get_optimized_script(flow_encoder.half(), ['forward_chunk'])
+        script = get_optimized_script(flow_encoder.half())
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
         logging.info('successfully export flow_encoder')
 

+ 49 - 126
cosyvoice/bin/export_onnx.py

@@ -62,135 +62,58 @@ def main():
         model = CosyVoice(args.model_dir)
     except Exception:
         try:
-            # NOTE set use_flow_cache=True when export jit for cache inference
-            model = CosyVoice2(args.model_dir, use_flow_cache=True)
+            model = CosyVoice2(args.model_dir)
         except Exception:
             raise TypeError('no valid model_type!')
 
-    if not isinstance(model, CosyVoice2):
-        # 1. export flow decoder estimator
-        estimator = model.model.flow.decoder.estimator
-        estimator.eval()
-
-        device = model.model.device
-        batch_size, seq_len = 2, 256
-        out_channels = model.model.flow.decoder.estimator.out_channels
-        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
-        torch.onnx.export(
-            estimator,
-            (x, mask, mu, t, spks, cond),
-            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-            export_params=True,
-            opset_version=18,
-            do_constant_folding=True,
-            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
-            output_names=['estimator_out'],
-            dynamic_axes={
-                'x': {2: 'seq_len'},
-                'mask': {2: 'seq_len'},
-                'mu': {2: 'seq_len'},
-                'cond': {2: 'seq_len'},
-                'estimator_out': {2: 'seq_len'},
-            }
-        )
-
-        # 2. test computation consistency
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-                                                      sess_options=option, providers=providers)
-
-        for _ in tqdm(range(10)):
-            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
-            output_pytorch = estimator(x, mask, mu, t, spks, cond)
-            ort_inputs = {
-                'x': x.cpu().numpy(),
-                'mask': mask.cpu().numpy(),
-                'mu': mu.cpu().numpy(),
-                't': t.cpu().numpy(),
-                'spks': spks.cpu().numpy(),
-                'cond': cond.cpu().numpy()
-            }
-            output_onnx = estimator_onnx.run(None, ort_inputs)[0]
-            torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
-        logging.info('successfully export estimator')
-    else:
-        # 1. export flow decoder estimator
-        estimator = model.model.flow.decoder.estimator
-        estimator.forward = estimator.forward_chunk
-        estimator.eval()
-
-        device = model.model.device
-        batch_size, seq_len = 2, 256
-        out_channels = model.model.flow.decoder.estimator.out_channels
-        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
-        cache = model.model.init_flow_cache()['decoder_cache']
-        cache.pop('offset')
-        cache = {k: v[0] for k, v in cache.items()}
-        torch.onnx.export(
-            estimator,
-            (x, mask, mu, t, spks, cond,
-             cache['down_blocks_conv_cache'],
-             cache['down_blocks_kv_cache'],
-             cache['mid_blocks_conv_cache'],
-             cache['mid_blocks_kv_cache'],
-             cache['up_blocks_conv_cache'],
-             cache['up_blocks_kv_cache'],
-             cache['final_blocks_conv_cache']),
-            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-            export_params=True,
-            opset_version=18,
-            do_constant_folding=True,
-            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
-                         'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
-            output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
-                          'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
-            dynamic_axes={
-                'x': {2: 'seq_len'},
-                'mask': {2: 'seq_len'},
-                'mu': {2: 'seq_len'},
-                'cond': {2: 'seq_len'},
-                'down_blocks_kv_cache': {3: 'cache_in_len'},
-                'mid_blocks_kv_cache': {3: 'cache_in_len'},
-                'up_blocks_kv_cache': {3: 'cache_in_len'},
-                'estimator_out': {2: 'seq_len'},
-                'down_blocks_kv_cache_out': {3: 'cache_out_len'},
-                'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
-                'up_blocks_kv_cache_out': {3: 'cache_out_len'},
-            }
-        )
-
-        # 2. test computation consistency
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-                                                      sess_options=option, providers=providers)
-
-        for iter in tqdm(range(10)):
-            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
-            cache = model.model.init_flow_cache()['decoder_cache']
-            cache.pop('offset')
-            cache = {k: v[0] for k, v in cache.items()}
-            output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
-            ort_inputs = {
-                'x': x.cpu().numpy(),
-                'mask': mask.cpu().numpy(),
-                'mu': mu.cpu().numpy(),
-                't': t.cpu().numpy(),
-                'spks': spks.cpu().numpy(),
-                'cond': cond.cpu().numpy(),
-            }
-            output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
-            if iter == 0:
-                # NOTE why can not pass first iteration check?
-                continue
-            for i, j in zip(output_pytorch, output_onnx):
-                torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
-        logging.info('successfully export estimator')
+    # 1. export flow decoder estimator
+    estimator = model.model.flow.decoder.estimator
+    estimator.eval()
+
+    device = model.model.device
+    batch_size, seq_len = 2, 256
+    out_channels = model.model.flow.decoder.estimator.out_channels
+    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+    torch.onnx.export(
+        estimator,
+        (x, mask, mu, t, spks, cond),
+        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+        export_params=True,
+        opset_version=18,
+        do_constant_folding=True,
+        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+        output_names=['estimator_out'],
+        dynamic_axes={
+            'x': {2: 'seq_len'},
+            'mask': {2: 'seq_len'},
+            'mu': {2: 'seq_len'},
+            'cond': {2: 'seq_len'},
+            'estimator_out': {2: 'seq_len'},
+        }
+    )
+
+    # 2. test computation consistency
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    option.intra_op_num_threads = 1
+    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                  sess_options=option, providers=providers)
+
+    for _ in tqdm(range(10)):
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+        output_pytorch = estimator(x, mask, mu, t, spks, cond)
+        ort_inputs = {
+            'x': x.cpu().numpy(),
+            'mask': mask.cpu().numpy(),
+            'mu': mu.cpu().numpy(),
+            't': t.cpu().numpy(),
+            'spks': spks.cpu().numpy(),
+            'cond': cond.cpu().numpy()
+        }
+        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+        torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+    logging.info('successfully export estimator')
 
 
 if __name__ == "__main__":

+ 3 - 3
cosyvoice/cli/cosyvoice.py

@@ -140,7 +140,7 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False, trt_concurrent=1):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -162,9 +162,9 @@ class CosyVoice2(CosyVoice):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache, trt_concurrent)
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
         self.model.load('{}/llm.pt'.format(model_dir),
-                        '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))

+ 30 - 70
cosyvoice/cli/model.py

@@ -33,12 +33,14 @@ class CosyVoiceModel:
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool = False):
+                 fp16: bool = False,
+                 trt_concurrent: int = 1):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
+        self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
@@ -85,23 +87,18 @@ class CosyVoiceModel:
 
     def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
         assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
-        if not os.path.exists(flow_decoder_estimator_model):
+        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
             convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
-        if os.path.getsize(flow_decoder_estimator_model) == 0:
-            raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
             estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
         assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
-        if isinstance(self, CosyVoice2Model):
-            self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
-        else:
-            self.flow.decoder.estimator = estimator_engine.create_execution_context()
+        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
 
     def get_trt_kwargs(self):
         min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
-        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200)]
+        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
         max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
         input_names = ["x", "mask", "mu", "cond"]
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
@@ -249,21 +246,21 @@ class CosyVoice2Model(CosyVoiceModel):
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
                  fp16: bool = False,
-                 use_flow_cache: bool = False,
                  trt_concurrent: int = 1):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
+        # NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
+        self.flow.encoder.streaming = True
+        self.flow.decoder.estimator.streaming = True
         self.hift = hift
         self.fp16 = fp16
-        self.use_flow_cache = use_flow_cache
         self.trt_concurrent = trt_concurrent
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
-        # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
+        # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
-        self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -278,56 +275,24 @@ class CosyVoice2Model(CosyVoiceModel):
         # dict used to store session related variable
         self.tts_speech_token_dict = {}
         self.llm_end_dict = {}
-        self.flow_cache_dict = {}
         self.hift_cache_dict = {}
         self.trt_context_dict = {}
 
-    def init_flow_cache(self):
-        encoder_cache = {'offset': 0,
-                         'pre_lookahead_layer_conv2_cache': torch.zeros(1, 512, 2).to(self.device),
-                         'encoders_kv_cache': torch.zeros(6, 1, 8, 0, 64 * 2).to(self.device),
-                         'upsample_offset': 0,
-                         'upsample_conv_cache': torch.zeros(1, 512, 4).to(self.device),
-                         'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
-        decoder_cache = {'offset': 0,
-                         'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
-                         'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
-                         'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
-                         'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
-                         'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
-                         'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
-                         'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
-        if self.fp16 is True:
-            for cache in [encoder_cache, decoder_cache]:
-                for k, v in cache.items():
-                    if isinstance(v, torch.Tensor):
-                        cache[k] = v.half()
-        cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
-        return cache
-
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def get_trt_kwargs(self):
-        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4), (1, 4, 2, 0, 512, 2), (12, 4, 2, 0, 512, 2), (1, 4, 2, 0, 512, 2)]
-        opt_shape = [(2, 80, 200), (2, 1, 200), (2, 80, 200), (2, 80, 200), (1, 4, 2, 100, 512, 2), (12, 4, 2, 100, 512, 2), (1, 4, 2, 100, 512, 2)]
-        max_shape = [(2, 80, 1500), (2, 1, 1500), (2, 80, 1500), (2, 80, 1500), (1, 4, 2, 200, 512, 2), (12, 4, 2, 200, 512, 2), (1, 4, 2, 200, 512, 2)]
-        input_names = ["x", "mask", "mu", "cond", 'down_blocks_kv_cache', 'mid_blocks_kv_cache', 'up_blocks_kv_cache']
-        assert self.use_flow_cache is True, "get_trt_kwargs is set for flow cache mode. If you want to use trt with use_flow_cache=False, please set higher max_shape"
-        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
-
-    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
-            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
-                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-                                                                      prompt_token=prompt_token.to(self.device),
-                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-                                                                      prompt_feat=prompt_feat.to(self.device),
-                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-                                                                      embedding=embedding.to(self.device),
-                                                                      cache=self.flow_cache_dict[uuid],
-                                                                      finalize=finalize)
+            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
@@ -362,7 +327,6 @@ class CosyVoice2Model(CosyVoiceModel):
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None
-            self.flow_cache_dict[this_uuid] = self.init_flow_cache()
             self.trt_context_dict[this_uuid] = self.trt_context_pool.get()
         if source_speech_token.shape[1] == 0:
             p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
@@ -370,27 +334,23 @@ class CosyVoice2Model(CosyVoiceModel):
             p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
         p.start()
         if stream is True:
-            assert self.use_flow_cache is True, "set use_flow_cache=True if you want to use stream inference to avoid OOM"
-            # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
-            flow_prompt_speech_token = flow_prompt_speech_token[:, -int(self.flow_decoder_required_cache_size / self.flow.token_mel_ratio):]
-            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size:]
+            token_offset = 0
+            prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
             while True:
                 time.sleep(0.1)
-                if len(self.tts_speech_token_dict[this_uuid]) >= self.token_hop_len + self.flow.pre_lookahead_len:
-                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
+                this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
                                                      embedding=flow_embedding,
+                                                     token_offset=token_offset,
                                                      uuid=this_uuid,
                                                      finalize=False)
-                    # NOTE in cache inference mode, we only use flow_prompt_speech_token/prompt_speech_feat in first chunk
-                    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
-                    prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
+                    token_offset += this_token_hop_len
                     yield {'tts_speech': this_tts_speech.cpu()}
-                    with self.lock:
-                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][self.token_hop_len:]
-                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < self.token_hop_len + self.flow.pre_lookahead_len:
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                     break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
@@ -399,18 +359,19 @@ class CosyVoice2Model(CosyVoiceModel):
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
+                                             token_offset=token_offset,
                                              uuid=this_uuid,
                                              finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            assert self.use_flow_cache is False, "set use_flow_cache=False for nonstream inference"
             p.join()
             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
+                                             token_offset=0,
                                              uuid=this_uuid,
                                              finalize=True,
                                              speed=speed)
@@ -419,7 +380,6 @@ class CosyVoice2Model(CosyVoiceModel):
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-            self.flow_cache_dict.pop(this_uuid)
             self.trt_context_pool.put(self.trt_context_dict[this_uuid])
             self.trt_context_dict.pop(this_uuid)
         if torch.cuda.is_available():

+ 6 - 7
cosyvoice/dataset/processor.py

@@ -159,7 +159,7 @@ def truncate(data, truncate_length=24576, mode='train'):
 
 def compute_fbank(data,
                   feat_extractor,
-                  token_mel_ratio=2,
+                  token_mel_ratio=0,
                   mode='train'):
     """ Extract fbank
 
@@ -176,12 +176,11 @@ def compute_fbank(data,
         assert 'text_token' in sample
         waveform = sample['speech']
         feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
-
-        # trim to align speech_token and speech_feat
-        token_len = min(feat.shape[0] // token_mel_ratio, sample["speech_token"].shape[0])
-        feat = feat[:token_mel_ratio * token_len]
-        sample["speech_token"] = sample["speech_token"][:token_len]
-
+        if token_mel_ratio != 0:
+            # trim to align speech_token and speech_feat
+            token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
+            feat = feat[:token_mel_ratio * token_len]
+            sample["speech_token"] = sample["speech_token"][:token_len]
         sample['speech_feat'] = feat
         yield sample
 

+ 27 - 432
cosyvoice/flow/decoder.py

@@ -11,16 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Optional, Dict, Any
+from typing import Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import pack, rearrange, repeat
-from diffusers.models.attention_processor import Attention, AttnProcessor2_0, inspect, logger, deprecate
 from cosyvoice.utils.common import mask_to_bias
 from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
-from matcha.models.components.transformer import BasicTransformerBlock, maybe_allow_in_graph
+from matcha.models.components.transformer import BasicTransformerBlock
 
 
 class Transpose(torch.nn.Module):
@@ -29,7 +28,7 @@ class Transpose(torch.nn.Module):
         self.dim0 = dim0
         self.dim1 = dim1
 
-    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = torch.transpose(x, self.dim0, self.dim1)
         return x
 
@@ -57,15 +56,10 @@ class CausalConv1d(torch.nn.Conv1d):
         assert stride == 1
         self.causal_padding = kernel_size - 1
 
-    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
-        if cache.size(2) == 0:
-            x = F.pad(x, (self.causal_padding, 0), value=0.0)
-        else:
-            assert cache.size(2) == self.causal_padding
-            x = torch.concat([cache, x], dim=2)
-        cache = x[:, :, -self.causal_padding:]
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.pad(x, (self.causal_padding, 0), value=0.0)
         x = super(CausalConv1d, self).forward(x)
-        return x, cache
+        return x
 
 
 class CausalBlock1D(Block1D):
@@ -79,11 +73,9 @@ class CausalBlock1D(Block1D):
             nn.Mish(),
         )
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
-        output, cache = self.block[0](x * mask, cache)
-        for i in range(1, len(self.block)):
-            output = self.block[i](output)
-        return output * mask, cache
+    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        output = self.block(x * mask)
+        return output * mask
 
 
 class CausalResnetBlock1D(ResnetBlock1D):
@@ -92,303 +84,6 @@ class CausalResnetBlock1D(ResnetBlock1D):
         self.block1 = CausalBlock1D(dim, dim_out)
         self.block2 = CausalBlock1D(dim_out, dim_out)
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor, time_emb: torch.Tensor,
-                block1_cache: torch.Tensor = torch.zeros(0, 0, 0), block2_cache: torch.Tensor = torch.zeros(0, 0, 0)
-                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        h, block1_cache = self.block1(x, mask, block1_cache)
-        h += self.mlp(time_emb).unsqueeze(-1)
-        h, block2_cache = self.block2(h, mask, block2_cache)
-        output = h + self.res_conv(x * mask)
-        return output, block1_cache, block2_cache
-
-
-class CausalAttnProcessor2_0(AttnProcessor2_0):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        super(CausalAttnProcessor2_0, self).__init__()
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        *args,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, torch.Tensor]:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. \
-                `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            # NOTE do not use attn.prepare_attention_mask as we have already provided the correct attention_mask
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.unsqueeze(dim=1).repeat(1, attn.heads, 1, 1)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key_cache = attn.to_k(encoder_hidden_states)
-        value_cache = attn.to_v(encoder_hidden_states)
-        # NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)
-        if cache.size(0) != 0:
-            key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
-            value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
-        else:
-            key, value = key_cache, value_cache
-        cache = torch.stack([key_cache, value_cache], dim=3)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states, cache
-
-
-@maybe_allow_in_graph
-class CausalAttention(Attention):
-    def __init__(
-        self,
-        query_dim: int,
-        cross_attention_dim: Optional[int] = None,
-        heads: int = 8,
-        dim_head: int = 64,
-        dropout: float = 0.0,
-        bias: bool = False,
-        upcast_attention: bool = False,
-        upcast_softmax: bool = False,
-        cross_attention_norm: Optional[str] = None,
-        cross_attention_norm_num_groups: int = 32,
-        qk_norm: Optional[str] = None,
-        added_kv_proj_dim: Optional[int] = None,
-        norm_num_groups: Optional[int] = None,
-        spatial_norm_dim: Optional[int] = None,
-        out_bias: bool = True,
-        scale_qk: bool = True,
-        only_cross_attention: bool = False,
-        eps: float = 1e-5,
-        rescale_output_factor: float = 1.0,
-        residual_connection: bool = False,
-        _from_deprecated_attn_block: bool = False,
-        processor: Optional["AttnProcessor2_0"] = None,
-        out_dim: int = None,
-    ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax,
-                                              cross_attention_norm, cross_attention_norm_num_groups, qk_norm, added_kv_proj_dim, norm_num_groups,
-                                              spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection,
-                                              _from_deprecated_attn_block, processor, out_dim)
-        processor = CausalAttnProcessor2_0()
-        self.set_processor(processor)
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-        **cross_attention_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        r"""
-        The forward method of the `Attention` class.
-
-        Args:
-            hidden_states (`torch.Tensor`):
-                The hidden states of the query.
-            encoder_hidden_states (`torch.Tensor`, *optional*):
-                The hidden states of the encoder.
-            attention_mask (`torch.Tensor`, *optional*):
-                The attention mask to use. If `None`, no mask is applied.
-            **cross_attention_kwargs:
-                Additional keyword arguments to pass along to the cross attention.
-
-        Returns:
-            `torch.Tensor`: The output of the attention layer.
-        """
-        # The `Attention` class can call different attention processors / attention functions
-        # here we simply pass along all tensors to the selected processor class
-        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
-
-        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
-        if len(unused_kwargs) > 0:
-            logger.warning(
-                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
-            )
-        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
-
-        return self.processor(
-            self,
-            hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
-            cache=cache,
-            **cross_attention_kwargs,
-        )
-
-
-@maybe_allow_in_graph
-class CausalBasicTransformerBlock(BasicTransformerBlock):
-    def __init__(
-        self,
-        dim: int,
-        num_attention_heads: int,
-        attention_head_dim: int,
-        dropout=0.0,
-        cross_attention_dim: Optional[int] = None,
-        activation_fn: str = "geglu",
-        num_embeds_ada_norm: Optional[int] = None,
-        attention_bias: bool = False,
-        only_cross_attention: bool = False,
-        double_self_attention: bool = False,
-        upcast_attention: bool = False,
-        norm_elementwise_affine: bool = True,
-        norm_type: str = "layer_norm",
-        final_dropout: bool = False,
-    ):
-        super(CausalBasicTransformerBlock, self).__init__(dim, num_attention_heads, attention_head_dim, dropout,
-                                                          cross_attention_dim, activation_fn, num_embeds_ada_norm,
-                                                          attention_bias, only_cross_attention, double_self_attention,
-                                                          upcast_attention, norm_elementwise_affine, norm_type, final_dropout)
-        self.attn1 = CausalAttention(
-            query_dim=dim,
-            heads=num_attention_heads,
-            dim_head=attention_head_dim,
-            dropout=dropout,
-            bias=attention_bias,
-            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-            upcast_attention=upcast_attention,
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        timestep: Optional[torch.LongTensor] = None,
-        cross_attention_kwargs: Dict[str, Any] = None,
-        class_labels: Optional[torch.LongTensor] = None,
-        cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Notice that normalization is always applied before the real computation in the following blocks.
-        # 1. Self-Attention
-        if self.use_ada_layer_norm:
-            norm_hidden_states = self.norm1(hidden_states, timestep)
-        elif self.use_ada_layer_norm_zero:
-            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
-                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
-            )
-        else:
-            norm_hidden_states = self.norm1(hidden_states)
-
-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-
-        attn_output, cache = self.attn1(
-            norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
-            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
-            cache=cache,
-            **cross_attention_kwargs,
-        )
-        if self.use_ada_layer_norm_zero:
-            attn_output = gate_msa.unsqueeze(1) * attn_output
-        hidden_states = attn_output + hidden_states
-
-        # 2. Cross-Attention
-        if self.attn2 is not None:
-            norm_hidden_states = (
-                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
-            )
-
-            attn_output = self.attn2(
-                norm_hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=encoder_attention_mask,
-                **cross_attention_kwargs,
-            )
-            hidden_states = attn_output + hidden_states
-
-        # 3. Feed-forward
-        norm_hidden_states = self.norm3(hidden_states)
-
-        if self.use_ada_layer_norm_zero:
-            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
-        if self._chunk_size is not None:
-            # "feed_forward_chunk_size" can be used to save memory
-            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
-                raise ValueError(f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: \
-                                 {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.")
-
-            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
-            ff_output = torch.cat(
-                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
-                dim=self._chunk_dim,
-            )
-        else:
-            ff_output = self.ff(norm_hidden_states)
-
-        if self.use_ada_layer_norm_zero:
-            ff_output = gate_mlp.unsqueeze(1) * ff_output
-
-        hidden_states = ff_output + hidden_states
-
-        return hidden_states, cache
-
 
 class ConditionalDecoder(nn.Module):
     def __init__(
@@ -640,7 +335,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
             resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
-                    CausalBasicTransformerBlock(
+                    BasicTransformerBlock(
                         dim=output_channel,
                         num_attention_heads=num_heads,
                         attention_head_dim=attention_head_dim,
@@ -662,7 +357,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
 
             transformer_blocks = nn.ModuleList(
                 [
-                    CausalBasicTransformerBlock(
+                    BasicTransformerBlock(
                         dim=output_channel,
                         num_attention_heads=num_heads,
                         attention_head_dim=attention_head_dim,
@@ -687,7 +382,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
             )
             transformer_blocks = nn.ModuleList(
                 [
-                    CausalBasicTransformerBlock(
+                    BasicTransformerBlock(
                         dim=output_channel,
                         num_attention_heads=num_heads,
                         attention_head_dim=attention_head_dim,
@@ -724,6 +419,9 @@ class CausalConditionalDecoder(ConditionalDecoder):
         Returns:
             _type_: _description_
         """
+        if hasattr(self, 'streaming'):
+            assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
+            streaming = self.streaming
 
         t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)
@@ -740,36 +438,36 @@ class CausalConditionalDecoder(ConditionalDecoder):
         masks = [mask]
         for resnet, transformer_blocks, downsample in self.down_blocks:
             mask_down = masks[-1]
-            x, _, _ = resnet(x, mask_down, t)
+            x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
             if streaming is True:
-                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
             else:
                 attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x, _ = transformer_block(
+                x = transformer_block(
                     hidden_states=x,
                     attention_mask=attn_mask,
                     timestep=t,
                 )
             x = rearrange(x, "b t c -> b c t").contiguous()
             hiddens.append(x)  # Save hidden states for skip connections
-            x, _ = downsample(x * mask_down)
+            x = downsample(x * mask_down)
             masks.append(mask_down[:, :, ::2])
         masks = masks[:-1]
         mask_mid = masks[-1]
 
         for resnet, transformer_blocks in self.mid_blocks:
-            x, _, _ = resnet(x, mask_mid, t)
+            x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
             if streaming is True:
-                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
             else:
                 attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x, _ = transformer_block(
+                x = transformer_block(
                     hidden_states=x,
                     attention_mask=attn_mask,
                     timestep=t,
@@ -780,124 +478,21 @@ class CausalConditionalDecoder(ConditionalDecoder):
             mask_up = masks.pop()
             skip = hiddens.pop()
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
-            x, _, _ = resnet(x, mask_up, t)
+            x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
             if streaming is True:
-                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
             else:
                 attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
-                x, _ = transformer_block(
+                x = transformer_block(
                     hidden_states=x,
                     attention_mask=attn_mask,
                     timestep=t,
                 )
             x = rearrange(x, "b t c -> b c t").contiguous()
-            x, _ = upsample(x * mask_up)
-        x, _ = self.final_block(x, mask_up)
+            x = upsample(x * mask_up)
+        x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
         return output * mask
-
-    @torch.inference_mode()
-    def forward_chunk(self, x, mask, mu, t, spks=None, cond=None,
-                      down_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-                      down_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-                      mid_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-                      mid_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-                      up_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
-                      up_blocks_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0, 0),
-                      final_blocks_conv_cache: torch.Tensor = torch.zeros(0, 0, 0)
-                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Forward pass of the UNet1DConditional model.
-
-        Args:
-            x (torch.Tensor): shape (batch_size, in_channels, time)
-            mask (_type_): shape (batch_size, 1, time)
-            t (_type_): shape (batch_size)
-            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
-            cond (_type_, optional): placeholder for future use. Defaults to None.
-
-        Raises:
-            ValueError: _description_
-            ValueError: _description_
-
-        Returns:
-            _type_: _description_
-        """
-
-        t = self.time_embeddings(t).to(t.dtype)
-        t = self.time_mlp(t)
-
-        x = pack([x, mu], "b * t")[0]
-
-        if spks is not None:
-            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
-            x = pack([x, spks], "b * t")[0]
-        if cond is not None:
-            x = pack([x, cond], "b * t")[0]
-
-        hiddens = []
-        masks = [mask]
-
-        down_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
-        mid_blocks_kv_cache_new = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x.device)
-        up_blocks_kv_cache_new = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x.device)
-        for index, (resnet, transformer_blocks, downsample) in enumerate(self.down_blocks):
-            mask_down = masks[-1]
-            x, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576] = \
-                resnet(x, mask_down, t, down_blocks_conv_cache[index][:, :320], down_blocks_conv_cache[index][:, 320: 576])
-            x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + down_blocks_kv_cache.size(3), device=x.device).bool()
-            attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for i, transformer_block in enumerate(transformer_blocks):
-                x, down_blocks_kv_cache_new[index, i] = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                    cache=down_blocks_kv_cache[index, i],
-                )
-            x = rearrange(x, "b t c -> b c t").contiguous()
-            hiddens.append(x)  # Save hidden states for skip connections
-            x, down_blocks_conv_cache[index][:, 576:] = downsample(x * mask_down, down_blocks_conv_cache[index][:, 576:])
-            masks.append(mask_down[:, :, ::2])
-        masks = masks[:-1]
-        mask_mid = masks[-1]
-
-        for index, (resnet, transformer_blocks) in enumerate(self.mid_blocks):
-            x, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:] = \
-                resnet(x, mask_mid, t, mid_blocks_conv_cache[index][:, :256], mid_blocks_conv_cache[index][:, 256:])
-            x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + mid_blocks_kv_cache.size(3), device=x.device).bool()
-            attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for i, transformer_block in enumerate(transformer_blocks):
-                x, mid_blocks_kv_cache_new[index, i] = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                    cache=mid_blocks_kv_cache[index, i]
-                )
-            x = rearrange(x, "b t c -> b c t").contiguous()
-
-        for index, (resnet, transformer_blocks, upsample) in enumerate(self.up_blocks):
-            mask_up = masks.pop()
-            skip = hiddens.pop()
-            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
-            x, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768] = \
-                resnet(x, mask_up, t, up_blocks_conv_cache[index][:, :512], up_blocks_conv_cache[index][:, 512: 768])
-            x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.ones(x.size(0), x.size(1), x.size(1) + up_blocks_kv_cache.size(3), device=x.device).bool()
-            attn_mask = mask_to_bias(attn_mask, x.dtype)
-            for i, transformer_block in enumerate(transformer_blocks):
-                x, up_blocks_kv_cache_new[index, i] = transformer_block(
-                    hidden_states=x,
-                    attention_mask=attn_mask,
-                    timestep=t,
-                    cache=up_blocks_kv_cache[index, i]
-                )
-            x = rearrange(x, "b t c -> b c t").contiguous()
-            x, up_blocks_conv_cache[index][:, 768:] = upsample(x * mask_up, up_blocks_conv_cache[index][:, 768:])
-        x, final_blocks_conv_cache = self.final_block(x, mask_up, final_blocks_conv_cache)
-        output = self.final_proj(x * mask_up)
-        return output * mask, down_blocks_conv_cache, down_blocks_kv_cache_new, mid_blocks_conv_cache, mid_blocks_kv_cache_new, \
-            up_blocks_conv_cache, up_blocks_kv_cache_new, final_blocks_conv_cache

+ 4 - 12
cosyvoice/flow/flow.py

@@ -241,7 +241,6 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
-                  cache,
                   finalize):
         assert token.shape[0] == 1
         # xvec projection
@@ -255,16 +254,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
 
         # text encode
         if finalize is True:
-            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, **cache['encoder_cache'])
+            h, h_lengths = self.encoder(token, token_len)
         else:
             token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
-            h, h_lengths, encoder_cache = self.encoder.forward_chunk(token, token_len, context=context, **cache['encoder_cache'])
-        cache['encoder_cache']['offset'] = encoder_cache[0]
-        cache['encoder_cache']['pre_lookahead_layer_conv2_cache'] = encoder_cache[1]
-        cache['encoder_cache']['encoders_kv_cache'] = encoder_cache[2]
-        cache['encoder_cache']['upsample_offset'] = encoder_cache[3]
-        cache['encoder_cache']['upsample_conv_cache'] = encoder_cache[4]
-        cache['encoder_cache']['upsample_kv_cache'] = encoder_cache[5]
+            h, h_lengths = self.encoder(token, token_len, context=context)
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)
 
@@ -274,14 +267,13 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
-        feat, cache['decoder_cache'] = self.decoder(
+        feat, _ = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
             spks=embedding,
             cond=conds,
             n_timesteps=10,
-            cache=cache['decoder_cache']
         )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
-        return feat.float(), cache
+        return feat.float(), None

+ 23 - 145
cosyvoice/flow/flow_matching.py

@@ -126,21 +126,26 @@ class ConditionalCFM(BASECFM):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator(x, mask, mu, t, spks, cond)
         else:
-            with self.lock:
-                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
-                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
-                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
-                self.estimator.set_input_shape('t', (2,))
-                self.estimator.set_input_shape('spks', (2, 80))
-                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
-                # run trt engine
-                assert self.estimator.execute_v2([x.contiguous().data_ptr(),
-                                                  mask.contiguous().data_ptr(),
-                                                  mu.contiguous().data_ptr(),
-                                                  t.contiguous().data_ptr(),
-                                                  spks.contiguous().data_ptr(),
-                                                  cond.contiguous().data_ptr(),
-                                                  x.data_ptr()]) is True
+            estimator, trt_engine = self.estimator.acquire_estimator()
+            estimator.set_input_shape('x', (2, 80, x.size(2)))
+            estimator.set_input_shape('mask', (2, 1, x.size(2)))
+            estimator.set_input_shape('mu', (2, 80, x.size(2)))
+            estimator.set_input_shape('t', (2,))
+            estimator.set_input_shape('spks', (2, 80))
+            estimator.set_input_shape('cond', (2, 80, x.size(2)))
+            data_ptrs = [x.contiguous().data_ptr(),
+                         mask.contiguous().data_ptr(),
+                         mu.contiguous().data_ptr(),
+                         t.contiguous().data_ptr(),
+                         spks.contiguous().data_ptr(),
+                         cond.contiguous().data_ptr(),
+                         x.data_ptr()]
+            for i, j in enumerate(data_ptrs):
+                estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
+            # run trt engine
+            assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
+            torch.cuda.current_stream().synchronize()
+            self.estimator.release_estimator(estimator)
             return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
@@ -191,7 +196,7 @@ class CausalConditionalCFM(ConditionalCFM):
         self.rand_noise = torch.randn([1, 80, 50 * 300])
 
     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, cache={}):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
         """Forward diffusion
 
         Args:
@@ -210,136 +215,9 @@ class CausalConditionalCFM(ConditionalCFM):
                 shape: (batch_size, n_feats, mel_timesteps)
         """
 
-        offset = cache.pop('offset')
-        z = self.rand_noise[:, :, :mu.size(2) + offset].to(mu.device).to(mu.dtype) * temperature
-        z = z[:, :, offset:]
-        offset += mu.size(2)
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
         # fix prompt and overlap part mu and z
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        mel, cache = self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, cache=cache)
-        cache['offset'] = offset
-        return mel, cache
-
-    def solve_euler(self, x, t_span, mu, mask, spks, cond, cache):
-        """
-        Fixed euler solver for ODEs.
-        Args:
-            x (torch.Tensor): random noise
-            t_span (torch.Tensor): n_timesteps interpolated
-                shape: (n_timesteps + 1,)
-            mu (torch.Tensor): output of encoder
-                shape: (batch_size, n_feats, mel_timesteps)
-            mask (torch.Tensor): output_mask
-                shape: (batch_size, 1, mel_timesteps)
-            spks (torch.Tensor, optional): speaker ids. Defaults to None.
-                shape: (batch_size, spk_emb_dim)
-            cond: Not used but kept for future purposes
-        """
-        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
-        t = t.unsqueeze(dim=0)
-
-        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
-        # Or in future might add like a return_all_steps flag
-        sol = []
-
-        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
-        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
-        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
-        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
-        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        flow_cache_size = cache['down_blocks_kv_cache'].shape[4]
-        for step in range(1, len(t_span)):
-            # Classifier-Free Guidance inference introduced in VoiceBox
-            x_in[:] = x
-            mask_in[:] = mask
-            mu_in[0] = mu
-            t_in[:] = t.unsqueeze(0)
-            spks_in[0] = spks
-            cond_in[0] = cond
-            cache_step = {k: v[step - 1] for k, v in cache.items()}
-            dphi_dt, cache_step = self.forward_estimator(
-                x_in, mask_in,
-                mu_in, t_in,
-                spks_in,
-                cond_in,
-                cache_step
-            )
-            # NOTE if smaller than flow_cache_size, means last chunk, no need to cache
-            if flow_cache_size != 0 and x_in.shape[2] >= flow_cache_size:
-                cache['down_blocks_conv_cache'][step - 1] = cache_step[0]
-                cache['down_blocks_kv_cache'][step - 1] = cache_step[1][:, :, :, -flow_cache_size:]
-                cache['mid_blocks_conv_cache'][step - 1] = cache_step[2]
-                cache['mid_blocks_kv_cache'][step - 1] = cache_step[3][:, :, :, -flow_cache_size:]
-                cache['up_blocks_conv_cache'][step - 1] = cache_step[4]
-                cache['up_blocks_kv_cache'][step - 1] = cache_step[5][:, :, :, -flow_cache_size:]
-                cache['final_blocks_conv_cache'][step - 1] = cache_step[6]
-            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
-            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
-            x = x + dt * dphi_dt
-            t = t + dt
-            sol.append(x)
-            if step < len(t_span) - 1:
-                dt = t_span[step + 1] - t
-        return sol[-1].float(), cache
-
-    def forward_estimator(self, x, mask, mu, t, spks, cond, cache):
-        if isinstance(self.estimator, torch.nn.Module):
-            x, cache1, cache2, cache3, cache4, cache5, cache6, cache7 = self.estimator.forward_chunk(x, mask, mu, t, spks, cond, **cache)
-            cache = (cache1, cache2, cache3, cache4, cache5, cache6, cache7)
-        else:
-            estimator, trt_engine = self.estimator.acquire_estimator()
-            estimator.set_input_shape('x', (2, 80, x.size(2)))
-            estimator.set_input_shape('mask', (2, 1, x.size(2)))
-            estimator.set_input_shape('mu', (2, 80, x.size(2)))
-            estimator.set_input_shape('t', (2,))
-            estimator.set_input_shape('spks', (2, 80))
-            estimator.set_input_shape('cond', (2, 80, x.size(2)))
-            estimator.set_input_shape('down_blocks_conv_cache', cache['down_blocks_conv_cache'].shape)
-            estimator.set_input_shape('down_blocks_kv_cache', cache['down_blocks_kv_cache'].shape)
-            estimator.set_input_shape('mid_blocks_conv_cache', cache['mid_blocks_conv_cache'].shape)
-            estimator.set_input_shape('mid_blocks_kv_cache', cache['mid_blocks_kv_cache'].shape)
-            estimator.set_input_shape('up_blocks_conv_cache', cache['up_blocks_conv_cache'].shape)
-            estimator.set_input_shape('up_blocks_kv_cache', cache['up_blocks_kv_cache'].shape)
-            estimator.set_input_shape('final_blocks_conv_cache', cache['final_blocks_conv_cache'].shape)
-            down_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
-            mid_blocks_kv_cache_out = torch.zeros(12, 4, 2, x.size(2), 512, 2).to(x)
-            up_blocks_kv_cache_out = torch.zeros(1, 4, 2, x.size(2), 512, 2).to(x)
-            data_ptrs = [x.contiguous().data_ptr(),
-                         mask.contiguous().data_ptr(),
-                         mu.contiguous().data_ptr(),
-                         t.contiguous().data_ptr(),
-                         spks.contiguous().data_ptr(),
-                         cond.contiguous().data_ptr(),
-                         cache['down_blocks_conv_cache'].contiguous().data_ptr(),
-                         cache['down_blocks_kv_cache'].contiguous().data_ptr(),
-                         cache['mid_blocks_conv_cache'].contiguous().data_ptr(),
-                         cache['mid_blocks_kv_cache'].contiguous().data_ptr(),
-                         cache['up_blocks_conv_cache'].contiguous().data_ptr(),
-                         cache['up_blocks_kv_cache'].contiguous().data_ptr(),
-                         cache['final_blocks_conv_cache'].contiguous().data_ptr(),
-                         x.data_ptr(),
-                         cache['down_blocks_conv_cache'].data_ptr(),
-                         down_blocks_kv_cache_out.data_ptr(),
-                         cache['mid_blocks_conv_cache'].data_ptr(),
-                         mid_blocks_kv_cache_out.data_ptr(),
-                         cache['up_blocks_conv_cache'].data_ptr(),
-                         up_blocks_kv_cache_out.data_ptr(),
-                         cache['final_blocks_conv_cache'].data_ptr()]
-            for i, j in enumerate(data_ptrs):
-                estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
-            # run trt engine
-            assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
-            torch.cuda.current_stream().synchronize()
-            self.estimator.release_estimator(estimator)
-            cache = (cache['down_blocks_conv_cache'],
-                     down_blocks_kv_cache_out,
-                     cache['mid_blocks_conv_cache'],
-                     mid_blocks_kv_cache_out,
-                     cache['up_blocks_conv_cache'],
-                     up_blocks_kv_cache_out,
-                     cache['final_blocks_conv_cache'])
-        return x, cache
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None

+ 169 - 1
cosyvoice/hifigan/generator.py

@@ -223,6 +223,172 @@ class SourceModuleHnNSF(torch.nn.Module):
         return sine_merge, noise, uv
 
 
+class SineGen2(torch.nn.Module):
+    """ Definition of sine generator
+    SineGen(samp_rate, harmonic_num = 0,
+            sine_amp = 0.1, noise_std = 0.003,
+            voiced_threshold = 0,
+            flag_for_pulse=False)
+    samp_rate: sampling rate in Hz
+    harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine-wavefrom (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_thoreshold: F0 threshold for U/V classification (default 0)
+    flag_for_pulse: this SinGen is used inside PulseGen (default False)
+    Note: when flag_for_pulse is True, the first time step of a voiced
+        segment is always sin(np.pi) or cos(0)
+    """
+
+    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
+                 sine_amp=0.1, noise_std=0.003,
+                 voiced_threshold=0,
+                 flag_for_pulse=False):
+        super(SineGen2, self).__init__()
+        self.sine_amp = sine_amp
+        self.noise_std = noise_std
+        self.harmonic_num = harmonic_num
+        self.dim = self.harmonic_num + 1
+        self.sampling_rate = samp_rate
+        self.voiced_threshold = voiced_threshold
+        self.flag_for_pulse = flag_for_pulse
+        self.upsample_scale = upsample_scale
+
+    def _f02uv(self, f0):
+        # generate uv signal
+        uv = (f0 > self.voiced_threshold).type(torch.float32)
+        return uv
+
+    def _f02sine(self, f0_values):
+        """ f0_values: (batchsize, length, dim)
+            where dim indicates fundamental tone and overtones
+        """
+        # convert to F0 in rad. The interger part n can be ignored
+        # because 2 * np.pi * n doesn't affect phase
+        rad_values = (f0_values / self.sampling_rate) % 1
+
+        # initial phase noise (no noise for fundamental component)
+        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
+        rand_ini[:, 0] = 0
+        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+        # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+        if not self.flag_for_pulse:
+            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
+                                                         scale_factor=1 / self.upsample_scale,
+                                                         mode="linear").transpose(1, 2)
+
+            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+            sines = torch.sin(phase)
+        else:
+            # If necessary, make sure that the first time step of every
+            # voiced segments is sin(pi) or cos(0)
+            # This is used for pulse-train generation
+
+            # identify the last time step in unvoiced segments
+            uv = self._f02uv(f0_values)
+            uv_1 = torch.roll(uv, shifts=-1, dims=1)
+            uv_1[:, -1, :] = 1
+            u_loc = (uv < 1) * (uv_1 > 0)
+
+            # get the instantanouse phase
+            tmp_cumsum = torch.cumsum(rad_values, dim=1)
+            # different batch needs to be processed differently
+            for idx in range(f0_values.shape[0]):
+                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                # stores the accumulation of i.phase within
+                # each voiced segments
+                tmp_cumsum[idx, :, :] = 0
+                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+            # rad_values - tmp_cumsum: remove the accumulation of i.phase
+            # within the previous voiced segment.
+            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+            # get the sines
+            sines = torch.cos(i_phase * 2 * np.pi)
+        return sines
+
+    def forward(self, f0):
+        """ sine_tensor, uv = forward(f0)
+        input F0: tensor(batchsize=1, length, dim=1)
+                  f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
+        """
+        # fundamental component
+        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
+
+        # generate sine waveforms
+        sine_waves = self._f02sine(fn) * self.sine_amp
+
+        # generate uv signal
+        uv = self._f02uv(f0)
+
+        # noise: for unvoiced should be similar to sine_amp
+        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
+        # .       for voiced regions is self.noise_std
+        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+        noise = noise_amp * torch.randn_like(sine_waves)
+
+        # first: set the unvoiced part to 0 by uv
+        # then: additive noise
+        sine_waves = sine_waves * uv + noise
+        return sine_waves, uv, noise
+
+
+class SourceModuleHnNSF2(torch.nn.Module):
+    """ SourceModule for hn-nsf
+    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0)
+    sampling_rate: sampling_rate in Hz
+    harmonic_num: number of harmonic above F0 (default: 0)
+    sine_amp: amplitude of sine source signal (default: 0.1)
+    add_noise_std: std of additive Gaussian noise (default: 0.003)
+        note that amplitude of noise in unvoiced is decided
+        by sine_amp
+    voiced_threshold: threhold to set U/V given F0 (default: 0)
+    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+    F0_sampled (batchsize, length, 1)
+    Sine_source (batchsize, length, 1)
+    noise_source (batchsize, length 1)
+    uv (batchsize, length, 1)
+    """
+
+    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0):
+        super(SourceModuleHnNSF2, self).__init__()
+
+        self.sine_amp = sine_amp
+        self.noise_std = add_noise_std
+
+        # to produce sine waveforms
+        self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
+                                  sine_amp, add_noise_std, voiced_threshod)
+
+        # to merge source harmonics into a single excitation
+        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+        self.l_tanh = torch.nn.Tanh()
+
+    def forward(self, x):
+        """
+        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+        F0_sampled (batchsize, length, 1)
+        Sine_source (batchsize, length, 1)
+        noise_source (batchsize, length 1)
+        """
+        # source for harmonic branch
+        with torch.no_grad():
+            sine_wavs, uv, _ = self.l_sin_gen(x)
+        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+        # source for noise branch, in the same shape as uv
+        noise = torch.randn_like(uv) * self.sine_amp / 3
+        return sine_merge, noise, uv
+
+
 class HiFTGenerator(nn.Module):
     """
     HiFTNet Generator: Neural Source Filter + ISTFTNet
@@ -259,7 +425,9 @@ class HiFTGenerator(nn.Module):
 
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
-        self.m_source = SourceModuleHnNSF(
+        # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
+        this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
+        self.m_source = this_SourceModuleHnNSF(
             sampling_rate=sampling_rate,
             upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
             harmonic_num=nb_harmonics,

+ 17 - 115
cosyvoice/transformer/upsample_encoder.py

@@ -56,16 +56,11 @@ class Upsample1D(nn.Module):
         # In this mode, first repeat interpolate, than conv with stride=1
         self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
 
-    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor, conv_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
-        if conv_cache.size(2) == 0:
-            outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
-        else:
-            assert conv_cache.size(2) == self.stride * 2
-            outputs = torch.concat([conv_cache, outputs], dim=2)
-        conv_cache_new = outputs[:, :, -self.stride * 2:]
+        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
         outputs = self.conv(outputs)
-        return outputs, input_lengths * self.stride, conv_cache_new
+        return outputs, input_lengths * self.stride
 
 
 class PreLookaheadLayer(nn.Module):
@@ -83,7 +78,7 @@ class PreLookaheadLayer(nn.Module):
             kernel_size=3, stride=1, padding=0,
         )
 
-    def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0), conv2_cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
         """
         inputs: (batch_size, seq_len, channels)
         """
@@ -93,22 +88,18 @@ class PreLookaheadLayer(nn.Module):
         if context.size(2) == 0:
             outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
         else:
+            assert self.training is False, 'you have passed context, make sure that you are running inference mode'
             assert context.size(2) == self.pre_lookahead_len
             outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
         outputs = F.leaky_relu(self.conv1(outputs))
         # outputs
-        if conv2_cache.size(2) == 0:
-            outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
-        else:
-            assert conv2_cache.size(2) == self.conv2.kernel_size[0] - 1
-            outputs = torch.concat([conv2_cache, outputs], dim=2)
-        conv2_cache_new = outputs[:, :, -(self.conv2.kernel_size[0] - 1):]
+        outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
         outputs = self.conv2(outputs)
         outputs = outputs.transpose(1, 2).contiguous()
 
         # residual connection
         outputs = outputs + inputs
-        return outputs, conv2_cache_new
+        return outputs
 
 
 class UpsampleConformerEncoder(torch.nn.Module):
@@ -253,6 +244,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         self,
         xs: torch.Tensor,
         xs_lens: torch.Tensor,
+        context: torch.Tensor = torch.zeros(0, 0, 0),
         decoding_chunk_size: int = 0,
         num_decoding_left_chunks: int = -1,
         streaming: bool = False,
@@ -280,20 +272,27 @@ class UpsampleConformerEncoder(torch.nn.Module):
             checkpointing API because `__call__` attaches all the hooks of the module.
             https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
         """
+        if hasattr(self, 'streaming'):
+            assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
+            streaming = self.streaming
         T = xs.size(1)
         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
         if self.global_cmvn is not None:
             xs = self.global_cmvn(xs)
         xs, pos_emb, masks = self.embed(xs, masks)
+        if context.size(1) != 0:
+            assert self.training is False, 'you have passed context, make sure that you are running inference mode'
+            context_masks = torch.ones(1, 1, context.size(1)).to(masks)
+            context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
         # lookahead + conformer encoder
-        xs, _ = self.pre_lookahead_layer(xs)
+        xs = self.pre_lookahead_layer(xs, context=context)
         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
 
         # upsample + conformer encoder
         xs = xs.transpose(1, 2).contiguous()
-        xs, xs_lens, _ = self.up_layer(xs, xs_lens)
+        xs, xs_lens = self.up_layer(xs, xs_lens)
         xs = xs.transpose(1, 2).contiguous()
         T = xs.size(1)
         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
@@ -322,100 +321,3 @@ class UpsampleConformerEncoder(torch.nn.Module):
         for layer in self.up_encoders:
             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs
-
-    @torch.jit.export
-    def forward_chunk(
-        self,
-        xs: torch.Tensor,
-        xs_lens: torch.Tensor,
-        offset: int = 0,
-        context: torch.Tensor = torch.zeros(0, 0, 0),
-        pre_lookahead_layer_conv2_cache: torch.Tensor = torch.zeros(0, 0, 0),
-        encoders_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0),
-        upsample_offset: int = 0,
-        upsample_conv_cache: torch.Tensor = torch.zeros(0, 0, 0),
-        upsample_kv_cache: torch.Tensor = torch.zeros(0, 0, 0, 0, 0)
-    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]]:
-        """Embed positions in tensor.
-
-        Args:
-            xs: padded input tensor (B, T, D)
-            xs_lens: input length (B)
-            decoding_chunk_size: decoding chunk size for dynamic chunk
-                0: default for training, use random dynamic chunk.
-                <0: for decoding, use full chunk.
-                >0: for decoding, use fixed chunk size as set.
-            num_decoding_left_chunks: number of left chunks, this is for decoding,
-            the chunk size is decoding_chunk_size.
-                >=0: use num_decoding_left_chunks
-                <0: use all left chunks
-        Returns:
-            encoder output tensor xs, and subsampled masks
-            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
-            masks: torch.Tensor batch padding mask after subsample
-                (B, 1, T' ~= T/subsample_rate)
-        NOTE(xcsong):
-            We pass the `__call__` method of the modules instead of `forward` to the
-            checkpointing API because `__call__` attaches all the hooks of the module.
-            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
-        """
-        assert xs.size(0) == 1
-        # tmp_masks is just for interface compatibility
-        tmp_masks = torch.ones(1,
-                               xs.size(1),
-                               device=xs.device,
-                               dtype=torch.bool)
-        tmp_masks = tmp_masks.unsqueeze(1)
-        if self.global_cmvn is not None:
-            xs = self.global_cmvn(xs)
-        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
-        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
-        offset += xs.size(1)
-        tmp_masks = torch.ones(1,
-                               context.size(1),
-                               device=context.device,
-                               dtype=torch.bool)
-        tmp_masks = tmp_masks.unsqueeze(1)
-        if context.size(1) != 0:
-            context, _, _ = self.embed(context, tmp_masks, offset)
-
-        # lookahead + conformer encoder
-        xs, pre_lookahead_layer_conv2_cache = self.pre_lookahead_layer(xs, context, pre_lookahead_layer_conv2_cache)
-        # NOTE in cache mode we do not need to call add_optional_chunk_mask
-        chunk_masks = torch.ones((1, xs.size(1), offset), dtype=torch.bool, device=xs.device)
-        mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
-        encoders_kv_cache_list = []
-        for index, layer in enumerate(self.encoders):
-            xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
-            encoders_kv_cache_list.append(encoders_kv_cache_new)
-        encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
-
-        # upsample
-        xs = xs.transpose(1, 2).contiguous()
-        xs, xs_lens, upsample_conv_cache = self.up_layer(xs, xs_lens, upsample_conv_cache)
-        xs = xs.transpose(1, 2).contiguous()
-
-        # tmp_masks is just for interface compatibility
-        tmp_masks = torch.ones(1,
-                               xs.size(1),
-                               device=xs.device,
-                               dtype=torch.bool)
-        tmp_masks = tmp_masks.unsqueeze(1)
-        xs, pos_emb, masks = self.up_embed(xs, tmp_masks, upsample_offset)
-        upsample_offset += xs.size(1)
-
-        # conformer encoder
-        chunk_masks = torch.ones((1, xs.size(1), upsample_offset), dtype=torch.bool, device=xs.device)
-        mask_pad = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
-        upsample_kv_cache_list = []
-        for index, layer in enumerate(self.up_encoders):
-            xs, chunk_masks, upsample_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, upsample_kv_cache[index])
-            upsample_kv_cache_list.append(upsample_kv_cache_new)
-        upsample_kv_cache = torch.stack(upsample_kv_cache_list, dim=0)
-
-        if self.normalize_before:
-            xs = self.after_norm(xs)
-        # Here we assume the mask is not changed in encoder layers, so just
-        # return the masks before encoder layers, and the masks will be used
-        # for cross attention with decoder later
-        return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)

+ 1 - 1
cosyvoice/utils/file_utils.py

@@ -56,7 +56,7 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
     network = builder.create_network(network_flags)
     parser = trt.OnnxParser(network, logger)
     config = builder.create_builder_config()
-    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31)  # 1GB
     if fp16:
         config.set_flag(trt.BuilderFlag.FP16)
     profile = builder.create_optimization_profile()

+ 35 - 4
cosyvoice/utils/mask.py

@@ -86,7 +86,7 @@ def subsequent_mask(
     return mask
 
 
-def subsequent_chunk_mask(
+def subsequent_chunk_mask_deprecated(
         size: int,
         chunk_size: int,
         num_left_chunks: int = -1,
@@ -124,6 +124,40 @@ def subsequent_chunk_mask(
     return ret
 
 
+def subsequent_chunk_mask(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+       this is for streaming encoder
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
+    pos_idx = torch.arange(size, device=device)
+    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
+    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+    return ret
+
+
 def add_optional_chunk_mask(xs: torch.Tensor,
                             masks: torch.Tensor,
                             use_dynamic_chunk: bool,
@@ -196,9 +230,6 @@ def add_optional_chunk_mask(xs: torch.Tensor,
     else:
         chunk_masks = masks
     assert chunk_masks.dtype == torch.bool
-    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
-        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
-        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
     return chunk_masks
 
 

+ 0 - 37
test1.py

@@ -1,37 +0,0 @@
-import sys
-sys.path.append('third_party/Matcha-TTS')
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
-from cosyvoice.utils.file_utils import load_wav
-import torchaudio # type: ignore
-
-cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
-
-# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
-# zero_shot usage
-prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
-for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
-    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-
-# save zero_shot spk for future usage
-assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True
-for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
-    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-cosyvoice.save_spkinfo()
-
-# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
-for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
-    torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-
-# instruct usage
-for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
-    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-
-# bistream usage, you can use generator as input, this is useful when using text llm model as input
-# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
-def text_generator():
-    yield '收到好友从远方寄来的生日礼物,'
-    yield '那份意外的惊喜与深深的祝福'
-    yield '让我心中充满了甜蜜的快乐,'
-    yield '笑容如花儿般绽放。'
-for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
-    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)