Explorar el Código

add flow unified training

lyuxiang.lx hace 10 meses
padre
commit
fd1a951a6c

+ 19 - 9
cosyvoice/flow/decoder.py

@@ -210,6 +210,7 @@ class CausalAttention(Attention):
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
         cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
@@ -223,7 +224,7 @@ class CausalAttention(Attention):
         processor: Optional["AttnProcessor2_0"] = None,
         out_dim: int = None,
     ):
-        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups,
+        super(CausalAttention, self).__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, cross_attention_norm, cross_attention_norm_num_groups, qk_norm,
                                               added_kv_proj_dim, norm_num_groups, spatial_norm_dim, out_bias, scale_qk, only_cross_attention, eps, rescale_output_factor, residual_connection, _from_deprecated_attn_block, processor, out_dim)
         processor = CausalAttnProcessor2_0()
         self.set_processor(processor)
@@ -505,7 +506,7 @@ class ConditionalDecoder(nn.Module):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
 
-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.
 
         Args:
@@ -540,7 +541,7 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -558,7 +559,7 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -574,7 +575,7 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = (torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) == 1)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
@@ -700,7 +701,7 @@ class CausalConditionalDecoder(ConditionalDecoder):
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
-    def forward(self, x, mask, mu, t, spks=None, cond=None):
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         """Forward pass of the UNet1DConditional model.
 
         Args:
@@ -735,7 +736,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
             mask_down = masks[-1]
             x, _, _ = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -753,7 +757,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
         for resnet, transformer_blocks in self.mid_blocks:
             x, _, _ = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(
@@ -769,7 +776,10 @@ class CausalConditionalDecoder(ConditionalDecoder):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x, _, _ = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            if streaming is True:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, self.num_decoding_left_chunks)
+            else:
+                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
             attn_mask = mask_to_bias(attn_mask, x.dtype)
             for transformer_block in transformer_blocks:
                 x, _ = transformer_block(

+ 6 - 2
cosyvoice/flow/flow.py

@@ -202,6 +202,9 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         feat_len = batch['speech_feat_len'].to(device)
         embedding = batch['embedding'].to(device)
 
+        # NOTE unified training, static_chunk_size > 0 or = 0
+        streaming = True if random.random() < 0.5 else False
+
         # xvec projection
         embedding = F.normalize(embedding, dim=1)
         embedding = self.spk_embed_affine_layer(embedding)
@@ -211,7 +214,7 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode
-        h, h_lengths = self.encoder(token, token_len)
+        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         h = self.encoder_proj(h)
 
         # get conditions
@@ -230,7 +233,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
             mask.unsqueeze(1),
             h.transpose(1, 2).contiguous(),
             embedding,
-            cond=conds
+            cond=conds,
+            streaming=streaming,
         )
         return {'loss': loss}
 

+ 2 - 5
cosyvoice/flow/flow_matching.py

@@ -142,7 +142,7 @@ class ConditionalCFM(BASECFM):
                                            x.data_ptr()])
             return x
 
-    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
+    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
         """Computes diffusion loss
 
         Args:
@@ -179,11 +179,8 @@ class ConditionalCFM(BASECFM):
             spks = spks * cfg_mask.view(-1, 1)
             cond = cond * cfg_mask.view(-1, 1, 1)
 
-        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
+        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
-        if loss.isnan():
-            print(123)
-            pred_new = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         return loss, y
 
 

+ 11 - 10
cosyvoice/transformer/upsample_encoder.py

@@ -255,6 +255,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs_lens: torch.Tensor,
         decoding_chunk_size: int = 0,
         num_decoding_left_chunks: int = -1,
+        streaming: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Embed positions in tensor.
 
@@ -286,11 +287,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size,
-                                              num_decoding_left_chunks)
+                                            self.use_dynamic_chunk if streaming is True else False,
+                                            self.use_dynamic_left_chunk if streaming is True else False,
+                                            decoding_chunk_size if streaming is True else 0,
+                                            self.static_chunk_size if streaming is True else 0,
+                                            num_decoding_left_chunks if streaming is True else -1)
         # lookahead + conformer encoder
         xs, _ = self.pre_lookahead_layer(xs)
         xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
@@ -304,11 +305,11 @@ class UpsampleConformerEncoder(torch.nn.Module):
         xs, pos_emb, masks = self.up_embed(xs, masks)
         mask_pad = masks  # (B, 1, T/subsample_rate)
         chunk_masks = add_optional_chunk_mask(xs, masks,
-                                              self.use_dynamic_chunk,
-                                              self.use_dynamic_left_chunk,
-                                              decoding_chunk_size,
-                                              self.static_chunk_size * self.up_layer.stride,
-                                              num_decoding_left_chunks)
+                                            self.use_dynamic_chunk if streaming is True else False,
+                                            self.use_dynamic_left_chunk if streaming is True else False,
+                                            decoding_chunk_size if streaming is True else 0,
+                                            self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
+                                            num_decoding_left_chunks if streaming is True else -1)
         xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
 
         if self.normalize_before: