1
0
Selaa lähdekoodia

add cosyvoice2

lyuxiang.lx 9 kuukautta sitten
vanhempi
commit
3e381002d7

+ 2 - 1
.gitignore

@@ -48,4 +48,5 @@ compile_commands.json
 *.pt
 pretrained_models/*
 *_pb2_grpc.py
-*_pb2.py
+*_pb2.py
+*.tar

+ 33 - 1
cosyvoice/cli/cosyvoice.py

@@ -18,7 +18,7 @@ from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
 from cosyvoice.utils.file_utils import logging
 
 
@@ -118,3 +118,35 @@ class CosyVoice:
             logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
             yield model_output
             start_time = time.time()
+
+class CosyVoice2(CosyVoice):
+
+    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
+        instruct = True if '-Instruct' in model_dir else False
+        self.model_dir = model_dir
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
+            configs = load_hyperpyyaml(f)
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          instruct,
+                                          configs['allowed_special'])
+        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+            load_jit = False
+            fp16 = False
+            logging.warning('cpu do not support fp16 and jit, force set to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        if load_jit:
+            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
+                                '{}/llm.llm.fp16.zip'.format(model_dir),
+                                '{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_onnx:
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+        del configs

+ 175 - 3
cosyvoice/cli/model.py

@@ -57,15 +57,15 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
         self.llm.to(self.device).eval()
         if self.fp16 is True:
             self.llm.half()
-        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
         # in case hift_model is a hifigan model
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
-        self.hift.load_state_dict(hift_state_dict, strict=False)
+        self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
 
     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
@@ -254,3 +254,175 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
+
+
+class CosyVoice2Model:
+
+    def __init__(self,
+                 llm: torch.nn.Module,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module,
+                 fp16: bool):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.llm = llm
+        self.flow = flow
+        self.hift = hift
+        self.fp16 = fp16
+        self.token_min_hop_len = 1 * self.flow.input_frame_rate
+        self.token_max_hop_len = 2 * self.flow.input_frame_rate
+        self.token_right_context = self.flow.encoder.pre_lookahead_layer.pre_lookahead_len
+        # hift cache
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        # speech fade in out
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        # rtf and decoding related
+        self.stream_scale_factor = 1
+        assert self.stream_scale_factor == 1, 'fix stream_scale_factor to 1 as we haven\'t implement cache in flow matching yet, this constraint will be loosen in the future'
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.hift_cache_dict = {}
+
+    def load(self, llm_model, flow_model, hift_model):
+        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
+        self.llm.to(self.device).eval()
+        if self.fp16 is True:
+            self.llm.half()
+        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
+        self.flow.to(self.device).eval()
+        # in case hift_model is a hifigan model
+        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
+        self.hift.load_state_dict(hift_state_dict, strict=True)
+        self.hift.to(self.device).eval()
+
+    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
+        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
+        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
+        self.llm.text_encoder = llm_text_encoder
+        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
+        self.llm.llm = llm_llm
+        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
+        self.flow.encoder = flow_encoder
+
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
+
+    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
+        if self.fp16 is True:
+            llm_embedding = llm_embedding.half()
+        with self.llm_context:
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device)):
+                self.tts_speech_token_dict[uuid].append(i)
+        self.llm_end_dict[uuid] = True
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
+        tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_token=prompt_token.to(self.device),
+                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                                  prompt_feat=prompt_feat.to(self.device),
+                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                                  embedding=embedding.to(self.device),
+                                                  finalize=finalize)
+        tts_mel = tts_mel[:, :, token_offset * self.flow.encoder.up_layer.stride:]
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.hift_cache_dict[this_uuid] = None
+        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
+        p.start()
+        if stream is True:
+            token_hop_len, token_offset = self.token_min_hop_len, 0
+            self.flow.encoder.static_chunk_size = self.token_min_hop_len
+            self.flow.decoder.estimator.static_chunk_size = self.token_min_hop_len * self.flow.encoder.up_layer.stride
+            while True:
+                time.sleep(0.1)
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= token_hop_len + self.token_right_context:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + token_hop_len + self.token_right_context]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     token_offset=token_offset,
+                                                     finalize=False)
+                    token_offset += token_hop_len
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < token_hop_len + self.token_right_context:
+                    break
+            p.join()
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             token_offset=token_offset,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            p.join()
+            self.flow.encoder.static_chunk_size = 0
+            self.flow.decoder.estimator.static_chunk_size = 0
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)

+ 89 - 11
cosyvoice/flow/decoder.py

@@ -13,16 +13,84 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import pack, rearrange, repeat
+from cosyvoice.utils.common import mask_to_bias
+from cosyvoice.utils.mask import add_optional_chunk_mask
 from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
 from matcha.models.components.transformer import BasicTransformerBlock
 
 
+class Transpose(torch.nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: torch.Tensor):
+        x = torch.transpose(x, self.dim0, self.dim1)
+        return x
+
+
+class CausalBlock1D(Block1D):
+    def __init__(self, dim: int, dim_out: int):
+        super(CausalBlock1D, self).__init__(dim, dim_out)
+        self.block = torch.nn.Sequential(
+            CausalConv1d(dim, dim_out, 3),
+            Transpose(1, 2),
+            nn.LayerNorm(dim_out),
+            Transpose(1, 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class CausalResnetBlock1D(ResnetBlock1D):
+    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int=8):
+        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
+        self.block1 = CausalBlock1D(dim, dim_out)
+        self.block2 = CausalBlock1D(dim_out, dim_out)
+
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+            kernel_size, stride,
+            padding=0, dilation=dilation,
+            groups=groups, bias=bias,
+            padding_mode=padding_mode,
+            device=device, dtype=dtype
+        )
+        assert stride == 1
+        self.causal_padding = (kernel_size - 1, 0)
+
+    def forward(self, x: torch.Tensor):
+        x = F.pad(x, self.causal_padding)
+        x = super(CausalConv1d, self).forward(x)
+        return x
+
+
 class ConditionalDecoder(nn.Module):
     def __init__(
         self,
         in_channels,
         out_channels,
+        causal=False,
         channels=(256, 256),
         dropout=0.05,
         attention_head_dim=64,
@@ -39,7 +107,7 @@ class ConditionalDecoder(nn.Module):
         channels = tuple(channels)
         self.in_channels = in_channels
         self.out_channels = out_channels
-
+        self.causal = causal
         self.time_embeddings = SinusoidalPosEmb(in_channels)
         time_embed_dim = channels[0] * 4
         self.time_mlp = TimestepEmbedding(
@@ -56,7 +124,7 @@ class ConditionalDecoder(nn.Module):
             input_channel = output_channel
             output_channel = channels[i]
             is_last = i == len(channels) - 1
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
             transformer_blocks = nn.ModuleList(
                 [
                     BasicTransformerBlock(
@@ -70,14 +138,14 @@ class ConditionalDecoder(nn.Module):
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
 
         for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
-            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
 
             transformer_blocks = nn.ModuleList(
                 [
@@ -99,7 +167,11 @@ class ConditionalDecoder(nn.Module):
             input_channel = channels[i] * 2
             output_channel = channels[i + 1]
             is_last = i == len(channels) - 2
-            resnet = ResnetBlock1D(
+            resnet = CausalResnetBlock1D(
+                dim=input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            ) if self.causal else ResnetBlock1D(
                 dim=input_channel,
                 dim_out=output_channel,
                 time_emb_dim=time_embed_dim,
@@ -119,10 +191,10 @@ class ConditionalDecoder(nn.Module):
             upsample = (
                 Upsample1D(output_channel, use_conv_transpose=True)
                 if not is_last
-                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
             )
             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
-        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
@@ -175,7 +247,9 @@ class ConditionalDecoder(nn.Module):
             mask_down = masks[-1]
             x = resnet(x, mask_down, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+            attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask==1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -192,7 +266,9 @@ class ConditionalDecoder(nn.Module):
         for resnet, transformer_blocks in self.mid_blocks:
             x = resnet(x, mask_mid, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+            attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask==1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -207,7 +283,9 @@ class ConditionalDecoder(nn.Module):
             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
             x = resnet(x, mask_up, t)
             x = rearrange(x, "b c t -> b t c").contiguous()
-            attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+            attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
+            attn_mask = mask_to_bias(attn_mask==1, x.dtype)
             for transformer_block in transformer_blocks:
                 x = transformer_block(
                     hidden_states=x,
@@ -218,4 +296,4 @@ class ConditionalDecoder(nn.Module):
             x = upsample(x * mask_up)
         x = self.final_block(x, mask_up)
         output = self.final_proj(x * mask_up)
-        return output * mask
+        return output * mask

+ 80 - 0
cosyvoice/flow/flow.py

@@ -146,3 +146,83 @@ class MaskedDiffWithXvec(torch.nn.Module):
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
         return feat, flow_cache
+
+
+class CausalMaskedDiffWithXvec(torch.nn.Module):
+    def __init__(self,
+                 input_size: int = 512,
+                 output_size: int = 80,
+                 spk_embed_dim: int = 192,
+                 output_type: str = "mel",
+                 vocab_size: int = 4096,
+                 input_frame_rate: int = 50,
+                 only_mask_loss: bool = True,
+                 encoder: torch.nn.Module = None,
+                 decoder: torch.nn.Module = None,
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.decoder_conf = decoder_conf
+        self.mel_feat_conf = mel_feat_conf
+        self.vocab_size = vocab_size
+        self.output_type = output_type
+        self.input_frame_rate = input_frame_rate
+        logging.info(f"input frame rate={self.input_frame_rate}")
+        self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+        self.encoder = encoder
+        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+        self.decoder = decoder
+        self.only_mask_loss = only_mask_loss
+
+    @torch.inference_mode()
+    def inference(self,
+                  token,
+                  token_len,
+                  prompt_token,
+                  prompt_token_len,
+                  prompt_feat,
+                  prompt_feat_len,
+                  embedding,
+                  finalize):
+        assert token.shape[0] == 1
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
+        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        h, h_lengths = self.encoder(token, token_len)
+        if finalize is False:
+            h = h[:, :-self.encoder.pre_lookahead_layer.pre_lookahead_len * self.encoder.up_layer.stride]
+        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] -  prompt_feat.shape[1]
+        h = self.encoder_proj(h)
+
+        # get conditions
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+        feat, _ = self.decoder(
+            mu=h.transpose(1, 2).contiguous(),
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=10
+        )
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
+        return feat, None

+ 51 - 9
cosyvoice/flow/flow_matching.py

@@ -89,17 +89,25 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         for step in range(1, len(t_span)):
-            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.forward_estimator(
-                    x, mask,
-                    torch.zeros_like(mu), t,
-                    torch.zeros_like(spks) if spks is not None else None,
-                    torch.zeros_like(cond)
-                )
-                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
-                           self.inference_cfg_rate * cfg_dphi_dt)
+                x_in = torch.concat([x, x], dim=0)
+                mask_in = torch.concat([mask, mask], dim=0)
+                mu_in = torch.concat([mu, torch.zeros_like(mu).to(x.device)], dim=0)
+                t_in = torch.concat([t, t], dim=0)
+                spks_in = torch.concat([spks, torch.zeros_like(spks).to(x.device)], dim=0) if spks is not None else None
+                cond_in = torch.concat([cond, torch.zeros_like(cond).to(x.device)], dim=0) if cond is not None else None
+            else:
+                x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
+            dphi_dt = self.forward_estimator(
+                x_in, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in
+            )
+            if self.inference_cfg_rate > 0:
+                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
+                dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
             x = x + dt * dphi_dt
             t = t + dt
             sol.append(x)
@@ -163,3 +171,37 @@ class ConditionalCFM(BASECFM):
         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y
+
+
+class CausalConditionalCFM(ConditionalCFM):
+    def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
+        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
+        self.rand_noise = torch.randn([1, 80, 50 * 300])
+
+    @torch.inference_mode()
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+
+        z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
+        z[:] = 0
+        # fix prompt and overlap part mu and z
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+        if self.t_scheduler == 'cosine':
+            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None

+ 36 - 0
cosyvoice/tokenizer/tokenizer.py

@@ -2,6 +2,8 @@ import base64
 import os
 from functools import lru_cache
 from typing import Optional
+import torch
+from transformers import AutoTokenizer
 from whisper.tokenizer import Tokenizer
 
 import tiktoken
@@ -234,3 +236,37 @@ def get_tokenizer(
     return Tokenizer(
         encoding=encoding, num_languages=num_languages, language=language, task=task
     )
+
+
+class QwenTokenizer():
+    def __init__(self, token_path, skip_special_tokens=True):
+        special_tokens = {
+            'eos_token': '<|endoftext|>',
+            'pad_token': '<|endoftext|>',
+            'additional_special_tokens': [
+                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                '[breath]', '<strong>', '</strong>', '[noise]',
+                '[laughter]', '[cough]', '[clucking]', '[accent]',
+                '[quick_breath]',
+            ]
+        }
+        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+        self.tokenizer.add_special_tokens(special_tokens)
+        self.skip_special_tokens = skip_special_tokens
+
+    def encode(self, text, **kwargs):
+        tokens = self.tokenizer([text], return_tensors="pt")
+        tokens = tokens["input_ids"][0].cpu().tolist()
+        return tokens
+
+    def decode(self, tokens):
+        tokens = torch.tensor(tokens, dtype=torch.int64)
+        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
+        return text
+
+@lru_cache(maxsize=None)
+def get_qwen_tokenizer(
+    token_path: str,
+    skip_special_tokens: bool
+) -> QwenTokenizer:
+    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)

+ 7 - 7
cosyvoice/transformer/encoder_layer.py

@@ -49,8 +49,8 @@ class TransformerEncoderLayer(nn.Module):
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
-        self.norm1 = nn.LayerNorm(size, eps=1e-5)
-        self.norm2 = nn.LayerNorm(size, eps=1e-5)
+        self.norm1 = nn.LayerNorm(size, eps=1e-12)
+        self.norm2 = nn.LayerNorm(size, eps=1e-12)
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
@@ -142,17 +142,17 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = feed_forward
         self.feed_forward_macaron = feed_forward_macaron
         self.conv_module = conv_module
-        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
+        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
         if feed_forward_macaron is not None:
-            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
             self.ff_scale = 0.5
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
-            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
+            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
             self.norm_final = nn.LayerNorm(
-                size, eps=1e-5)  # for the final output of the block
+                size, eps=1e-12)  # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before

+ 11 - 0
cosyvoice/utils/common.py

@@ -153,3 +153,14 @@ def set_all_random_seed(seed):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    assert mask.dtype == torch.bool
+    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+    mask = mask.to(dtype)
+    # attention mask bias
+    # NOTE(Mddct): torch.finfo jit issues
+    #     chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * torch.finfo(dtype).min
+    return mask