Browse Source

fix cache bug

lyuxiang.lx 1 year ago
parent
commit
aea75207dd

+ 2 - 1
cosyvoice/transformer/upsample_encoder.py

@@ -396,6 +396,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         encoders_kv_cache_list = []
         encoders_kv_cache_list = []
         for index, layer in enumerate(self.encoders):
         for index, layer in enumerate(self.encoders):
             xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
             xs, chunk_masks, encoders_kv_cache_new, _ = layer(xs, chunk_masks, pos_emb, mask_pad, encoders_kv_cache[index])
+            encoders_kv_cache_list.append(encoders_kv_cache_new)
         encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
         encoders_kv_cache = torch.stack(encoders_kv_cache_list, dim=0)
 
 
         # upsample
         # upsample
@@ -426,4 +427,4 @@ class UpsampleConformerEncoder(torch.nn.Module):
         # Here we assume the mask is not changed in encoder layers, so just
         # Here we assume the mask is not changed in encoder layers, so just
         # return the masks before encoder layers, and the masks will be used
         # return the masks before encoder layers, and the masks will be used
         # for cross attention with decoder later
         # for cross attention with decoder later
-        return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache_new, upsample_offset, upsample_conv_cache, upsample_kv_cache_new)
+        return xs, masks, (offset, pre_lookahead_layer_conv2_cache, encoders_kv_cache, upsample_offset, upsample_conv_cache, upsample_kv_cache)

+ 18 - 21
examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

@@ -56,7 +56,7 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
         input_size: 512
         input_size: 512
         use_cnn_module: False
         use_cnn_module: False
         macaron_style: False
         macaron_style: False
-        use_dynamic_chunk: True
+        static_chunk_size: !ref <token_frame_rate> # 试试UpsampleConformerEncoder也是static
     decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
     decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
         in_channels: 240
         in_channels: 240
         n_spks: 1
         n_spks: 1
@@ -154,12 +154,9 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     center: False
     center: False
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
 compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
     feat_extractor: !ref <feat_extractor>
     feat_extractor: !ref <feat_extractor>
-# pitch_extractor: !name:torchaudio.functional.compute_kaldi_pitch # TODO need to replace it
-#     sample_rate: !ref <sample_rate>
-#     frame_length: 46.4 # match feat_extractor win_size/sampling_rate
-#     frame_shift: 11.6 # match feat_extractor hop_size/sampling_rate
-# compute_f0: !name:cosyvoice.dataset.processor.compute_f0
-#     pitch_extractor: !ref <pitch_extractor>
+compute_f0: !name:cosyvoice.dataset.processor.compute_f0
+    sample_rate: !ref <sample_rate>
+    hop_size: 480
 parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
 parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
     normalize: True
     normalize: True
 shuffle: !name:cosyvoice.dataset.processor.shuffle
 shuffle: !name:cosyvoice.dataset.processor.shuffle
@@ -186,20 +183,20 @@ data_pipeline: [
     !ref <batch>,
     !ref <batch>,
     !ref <padding>,
     !ref <padding>,
 ]
 ]
-# data_pipeline_gan: [
-#     !ref <parquet_opener>,
-#     !ref <tokenize>,
-#     !ref <filter>,
-#     !ref <resample>,
-#     !ref <truncate>,
-#     !ref <compute_fbank>,
-#     !ref <compute_f0>,
-#     !ref <parse_embedding>,
-#     !ref <shuffle>,
-#     !ref <sort>,
-#     !ref <batch>,
-#     !ref <padding>,
-# ]
+data_pipeline_gan: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <truncate>,
+    !ref <compute_fbank>,
+    !ref <compute_f0>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
 
 
 # llm flow train conf
 # llm flow train conf
 train_conf:
 train_conf: