lyuxiang.lx 8 months ago
parent
commit
dd2d926147
2 changed files with 23 additions and 56 deletions
  1. 22 56
      cosyvoice/flow/DiT/dit.py
  2. 1 0
      cosyvoice/flow/DiT/modules.py

+ 22 - 56
cosyvoice/flow/DiT/dit_model.py → cosyvoice/flow/DiT/dit.py

@@ -1,3 +1,4 @@
+
 """
 """
 ein notation:
 ein notation:
 b - batch
 b - batch
@@ -14,9 +15,8 @@ from torch import nn
 import torch.nn.functional as F
 import torch.nn.functional as F
 from einops import repeat
 from einops import repeat
 from x_transformers.x_transformers import RotaryEmbedding
 from x_transformers.x_transformers import RotaryEmbedding
-from funasr.models.transformer.utils.mask import causal_block_mask
-
-from cosyvoice.flow.DiT.dit_modules import (
+from cosyvoice.utils.mask import add_optional_chunk_mask
+from cosyvoice.flow.DiT.modules import (
     TimestepEmbedding,
     TimestepEmbedding,
     ConvNeXtV2Block,
     ConvNeXtV2Block,
     CausalConvPositionEmbedding,
     CausalConvPositionEmbedding,
@@ -115,7 +115,8 @@ class DiT(nn.Module):
         mu_dim=None,
         mu_dim=None,
         long_skip_connection=False,
         long_skip_connection=False,
         spk_dim=None,
         spk_dim=None,
-        **kwargs
+        static_chunk_size=50,
+        num_decoding_left_chunks=2
     ):
     ):
         super().__init__()
         super().__init__()
 
 
@@ -136,50 +137,20 @@ class DiT(nn.Module):
 
 
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
         self.proj_out = nn.Linear(dim, mel_dim)
-        self.causal_mask_type = kwargs.get("causal_mask_type", None)
-
-    def build_mix_causal_mask(self, attn_mask, rand=None, ratio=None):
-        b, _, _, t = attn_mask.shape
-        if rand is None:
-            rand = torch.rand((b, 1, 1, 1), device=attn_mask.device, dtype=torch.float32)
-        mixed_mask = attn_mask.clone()
-        for item in self.causal_mask_type:
-            prob_min, prob_max = item["prob_min"], item["prob_max"]
-            _ratio = 1
-            if "ratio" in item:
-                _ratio = item["ratio"]
-            if ratio is not None:
-                _ratio = ratio
-            block_size = item["block_size"] * _ratio
-            if block_size <= 0:
-                causal_mask = attn_mask
-            else:
-                causal_mask = causal_block_mask(
-                    t, block_size, attn_mask.device, torch.float32
-                ).unsqueeze(0).unsqueeze(1)  # 1,1,T,T
-            flag = (prob_min <= rand) & (rand < prob_max)
-            mixed_mask = mixed_mask * (~flag) + (causal_mask * attn_mask) * flag
-
-        return mixed_mask
-
-    def forward(
-        self,
-        x: float["b n d"],  # nosied input audio
-        cond: float["b n d"],  # masked cond audio
-        mu: int["b nt d"],  # mu
-        spks: float["b 1 d"],  # spk xvec
-        time: float["b"] | float[""],  # time step
-        return_hidden: bool = False,
-        mask: bool["b 1 n"] | None = None,
-        mask_rand: float["b 1 1"] = None,  # for mask flag type
-        **kwargs,
-    ):
+        self.static_chunk_size = static_chunk_size
+        self.num_decoding_left_chunks = num_decoding_left_chunks
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+        x = x.transpose(1, 2)
+        mu = mu.transpose(1, 2)
+        cond = cond.transpose(1, 2)
+        spks = spks.unsqueeze(dim=1)
         batch, seq_len = x.shape[0], x.shape[1]
         batch, seq_len = x.shape[0], x.shape[1]
-        if time.ndim == 0:
-            time = time.repeat(batch)
+        if t.ndim == 0:
+            t = t.repeat(batch)
 
 
         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
-        t = self.time_embed(time)
+        t = self.time_embed(t)
         x = self.input_embed(x, cond, mu, spks.squeeze(1))
         x = self.input_embed(x, cond, mu, spks.squeeze(1))
 
 
         rope = self.rotary_embed.forward_from_seq_len(seq_len)
         rope = self.rotary_embed.forward_from_seq_len(seq_len)
@@ -187,22 +158,17 @@ class DiT(nn.Module):
         if self.long_skip_connection is not None:
         if self.long_skip_connection is not None:
             residual = x
             residual = x
 
 
-        mask = mask.unsqueeze(1)  # B,1,1,T
-        if self.causal_mask_type is not None:
-            mask = self.build_mix_causal_mask(mask, rand=mask_rand.unsqueeze(-1))
+        if streaming is True:
+            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
+        else:
+            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
 
 
         for block in self.transformer_blocks:
         for block in self.transformer_blocks:
-            # mask-out padded values for amp training
-            x = x * mask[:, 0, -1, :].unsqueeze(-1)
-            x = block(x, t, mask=mask.bool(), rope=rope)
+            x = block(x, t, mask=attn_mask.bool(), rope=rope)
 
 
         if self.long_skip_connection is not None:
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
 
 
         x = self.norm_out(x, t)
         x = self.norm_out(x, t)
-        output = self.proj_out(x)
-
-        if return_hidden:
-            return output, None
-
+        output = self.proj_out(x).transpose(1, 2)
         return output
         return output

+ 1 - 0
cosyvoice/flow/DiT/dit_modules.py → cosyvoice/flow/DiT/modules.py

@@ -1,3 +1,4 @@
+
 """
 """
 ein notation:
 ein notation:
 b - batch
 b - batch