|
@@ -13,16 +13,84 @@
|
|
|
# limitations under the License.
|
|
# limitations under the License.
|
|
|
import torch
|
|
import torch
|
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
|
|
|
+import torch.nn.functional as F
|
|
|
from einops import pack, rearrange, repeat
|
|
from einops import pack, rearrange, repeat
|
|
|
|
|
+from cosyvoice.utils.common import mask_to_bias
|
|
|
|
|
+from cosyvoice.utils.mask import add_optional_chunk_mask
|
|
|
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
|
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
|
|
from matcha.models.components.transformer import BasicTransformerBlock
|
|
from matcha.models.components.transformer import BasicTransformerBlock
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Transpose(torch.nn.Module):
    """Module wrapper around ``torch.transpose``.

    Wrapping the call in a module lets a dimension swap be placed inside an
    ``nn.Sequential`` pipeline.

    Args:
        dim0: First of the two dimensions to swap.
        dim1: Second of the two dimensions to swap.
    """

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def extra_repr(self) -> str:
        # Expose the swapped dimensions when the module (or a container
        # holding it) is printed.
        return f"dim0={self.dim0}, dim1={self.dim1}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with dimensions ``dim0`` and ``dim1`` swapped."""
        return torch.transpose(x, self.dim0, self.dim1)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class CausalBlock1D(Block1D):
    """``Block1D`` variant built from a causal convolution, LayerNorm and Mish.

    Args:
        dim: Number of input channels given to the causal convolution.
        dim_out: Number of output channels; also the size LayerNorm normalizes.
    """

    def __init__(self, dim: int, dim_out: int):
        super().__init__(dim, dim_out)
        # Assigned after the parent constructor runs, so this stack is the
        # `self.block` that `forward` below uses.
        layers = [
            CausalConv1d(dim, dim_out, 3),
            # LayerNorm acts on the last dimension, so move the channel axis
            # (dim 1 of the convolution output) to the end and back again.
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish(),
        ]
        self.block = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Run the block on the masked input and mask the result again.

        NOTE(review): `mask` is assumed to broadcast against `x` and against
        the block output — confirm against the caller.
        """
        masked_input = x * mask
        return self.block(masked_input) * mask
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class CausalResnetBlock1D(ResnetBlock1D):
    """``ResnetBlock1D`` whose ``block1`` and ``block2`` are ``CausalBlock1D``.

    Args:
        dim: Input channel count, used for the first causal block.
        dim_out: Output channel count, used by both causal blocks.
        time_emb_dim: Forwarded unchanged to the parent constructor.
        groups: Forwarded unchanged to the parent constructor.
    """

    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super().__init__(dim, dim_out, time_emb_dim, groups)
        # Build the two causal blocks in order (dim -> dim_out, then
        # dim_out -> dim_out) and store them as `block1` / `block2` after the
        # parent constructor has run.
        first_block = CausalBlock1D(dim, dim_out)
        second_block = CausalBlock1D(dim_out, dim_out)
        self.block1 = first_block
        self.block2 = second_block
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class CausalConv1d(torch.nn.Conv1d):
    """1-D convolution that pads only on the left of the last dimension.

    The parent ``Conv1d`` is built with ``padding=0``; ``forward`` left-pads
    the input by ``dilation * (kernel_size - 1)`` zeros before convolving, so
    the output has the same length as the input and position ``t`` of the
    output depends only on input positions ``<= t``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel width.
        stride: Must be 1; any other value is rejected.
        dilation: Spacing between kernel taps.
        groups: Forwarded to ``torch.nn.Conv1d``.
        bias: Forwarded to ``torch.nn.Conv1d``.
        padding_mode: Forwarded to ``torch.nn.Conv1d``. Because the parent is
            built with ``padding=0`` and the left padding is applied with
            ``F.pad`` in its default mode, the causal padding itself is
            always zero-filled.
        device: Forwarded to ``torch.nn.Conv1d``.
        dtype: Forwarded to ``torch.nn.Conv1d``.

    Raises:
        ValueError: If ``stride`` is not 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        device=None,
        dtype=None
    ) -> None:
        # Validate explicitly: an `assert` would be stripped under `python -O`.
        if stride != 1:
            raise ValueError(f"CausalConv1d only supports stride == 1, got {stride}")
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups, bias=bias,
                                           padding_mode=padding_mode,
                                           device=device, dtype=dtype)
        # The kernel spans dilation * (kernel_size - 1) + 1 input positions,
        # so this much left padding keeps the output length equal to the
        # input length. For dilation == 1 this is kernel_size - 1.
        self.causal_padding = (dilation * (kernel_size - 1), 0)

    def forward(self, x: torch.Tensor):
        """Left-pad the last dimension of ``x`` and apply the convolution."""
        x = F.pad(x, self.causal_padding)
        x = super(CausalConv1d, self).forward(x)
        return x
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
class ConditionalDecoder(nn.Module):
|
|
class ConditionalDecoder(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
in_channels,
|
|
in_channels,
|
|
|
out_channels,
|
|
out_channels,
|
|
|
|
|
+ causal=False,
|
|
|
channels=(256, 256),
|
|
channels=(256, 256),
|
|
|
dropout=0.05,
|
|
dropout=0.05,
|
|
|
attention_head_dim=64,
|
|
attention_head_dim=64,
|
|
@@ -39,7 +107,7 @@ class ConditionalDecoder(nn.Module):
|
|
|
channels = tuple(channels)
|
|
channels = tuple(channels)
|
|
|
self.in_channels = in_channels
|
|
self.in_channels = in_channels
|
|
|
self.out_channels = out_channels
|
|
self.out_channels = out_channels
|
|
|
-
|
|
|
|
|
|
|
+ self.causal = causal
|
|
|
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
|
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
|
|
time_embed_dim = channels[0] * 4
|
|
time_embed_dim = channels[0] * 4
|
|
|
self.time_mlp = TimestepEmbedding(
|
|
self.time_mlp = TimestepEmbedding(
|
|
@@ -56,7 +124,7 @@ class ConditionalDecoder(nn.Module):
|
|
|
input_channel = output_channel
|
|
input_channel = output_channel
|
|
|
output_channel = channels[i]
|
|
output_channel = channels[i]
|
|
|
is_last = i == len(channels) - 1
|
|
is_last = i == len(channels) - 1
|
|
|
- resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
|
|
|
|
|
|
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
|
|
transformer_blocks = nn.ModuleList(
|
|
transformer_blocks = nn.ModuleList(
|
|
|
[
|
|
[
|
|
|
BasicTransformerBlock(
|
|
BasicTransformerBlock(
|
|
@@ -70,14 +138,14 @@ class ConditionalDecoder(nn.Module):
|
|
|
]
|
|
]
|
|
|
)
|
|
)
|
|
|
downsample = (
|
|
downsample = (
|
|
|
- Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
|
|
|
|
|
|
+ Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
|
|
)
|
|
)
|
|
|
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
|
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
|
|
|
|
|
|
|
for _ in range(num_mid_blocks):
|
|
for _ in range(num_mid_blocks):
|
|
|
input_channel = channels[-1]
|
|
input_channel = channels[-1]
|
|
|
out_channels = channels[-1]
|
|
out_channels = channels[-1]
|
|
|
- resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
|
|
|
|
|
|
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
|
|
|
|
|
|
|
transformer_blocks = nn.ModuleList(
|
|
transformer_blocks = nn.ModuleList(
|
|
|
[
|
|
[
|
|
@@ -99,7 +167,11 @@ class ConditionalDecoder(nn.Module):
|
|
|
input_channel = channels[i] * 2
|
|
input_channel = channels[i] * 2
|
|
|
output_channel = channels[i + 1]
|
|
output_channel = channels[i + 1]
|
|
|
is_last = i == len(channels) - 2
|
|
is_last = i == len(channels) - 2
|
|
|
- resnet = ResnetBlock1D(
|
|
|
|
|
|
|
+ resnet = CausalResnetBlock1D(
|
|
|
|
|
+ dim=input_channel,
|
|
|
|
|
+ dim_out=output_channel,
|
|
|
|
|
+ time_emb_dim=time_embed_dim,
|
|
|
|
|
+ ) if self.causal else ResnetBlock1D(
|
|
|
dim=input_channel,
|
|
dim=input_channel,
|
|
|
dim_out=output_channel,
|
|
dim_out=output_channel,
|
|
|
time_emb_dim=time_embed_dim,
|
|
time_emb_dim=time_embed_dim,
|
|
@@ -119,10 +191,10 @@ class ConditionalDecoder(nn.Module):
|
|
|
upsample = (
|
|
upsample = (
|
|
|
Upsample1D(output_channel, use_conv_transpose=True)
|
|
Upsample1D(output_channel, use_conv_transpose=True)
|
|
|
if not is_last
|
|
if not is_last
|
|
|
- else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
|
|
|
|
|
|
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
|
|
)
|
|
)
|
|
|
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
|
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
|
|
- self.final_block = Block1D(channels[-1], channels[-1])
|
|
|
|
|
|
|
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
|
|
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
|
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
|
|
self.initialize_weights()
|
|
self.initialize_weights()
|
|
|
|
|
|
|
@@ -175,7 +247,9 @@ class ConditionalDecoder(nn.Module):
|
|
|
mask_down = masks[-1]
|
|
mask_down = masks[-1]
|
|
|
x = resnet(x, mask_down, t)
|
|
x = resnet(x, mask_down, t)
|
|
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
|
|
- attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
|
|
|
|
|
|
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
|
|
|
|
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
|
|
|
|
+ attn_mask = mask_to_bias(attn_mask==1, x.dtype)
|
|
|
for transformer_block in transformer_blocks:
|
|
for transformer_block in transformer_blocks:
|
|
|
x = transformer_block(
|
|
x = transformer_block(
|
|
|
hidden_states=x,
|
|
hidden_states=x,
|
|
@@ -192,7 +266,9 @@ class ConditionalDecoder(nn.Module):
|
|
|
for resnet, transformer_blocks in self.mid_blocks:
|
|
for resnet, transformer_blocks in self.mid_blocks:
|
|
|
x = resnet(x, mask_mid, t)
|
|
x = resnet(x, mask_mid, t)
|
|
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
|
|
- attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
|
|
|
|
|
|
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
|
|
|
|
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
|
|
|
|
+ attn_mask = mask_to_bias(attn_mask==1, x.dtype)
|
|
|
for transformer_block in transformer_blocks:
|
|
for transformer_block in transformer_blocks:
|
|
|
x = transformer_block(
|
|
x = transformer_block(
|
|
|
hidden_states=x,
|
|
hidden_states=x,
|
|
@@ -207,7 +283,9 @@ class ConditionalDecoder(nn.Module):
|
|
|
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
|
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
|
|
x = resnet(x, mask_up, t)
|
|
x = resnet(x, mask_up, t)
|
|
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
|
x = rearrange(x, "b c t -> b t c").contiguous()
|
|
|
- attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
|
|
|
|
|
|
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
|
|
|
|
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
|
|
|
|
+ attn_mask = mask_to_bias(attn_mask==1, x.dtype)
|
|
|
for transformer_block in transformer_blocks:
|
|
for transformer_block in transformer_blocks:
|
|
|
x = transformer_block(
|
|
x = transformer_block(
|
|
|
hidden_states=x,
|
|
hidden_states=x,
|
|
@@ -218,4 +296,4 @@ class ConditionalDecoder(nn.Module):
|
|
|
x = upsample(x * mask_up)
|
|
x = upsample(x * mask_up)
|
|
|
x = self.final_block(x, mask_up)
|
|
x = self.final_block(x, mask_up)
|
|
|
output = self.final_proj(x * mask_up)
|
|
output = self.final_proj(x * mask_up)
|
|
|
- return output * mask
|
|
|
|
|
|
|
+ return output * mask
|