
Merge remote-tracking branch 'origin/inference_streaming' into inference_streaming

禾息 1 year ago
commit 752103a307

+ 30 - 0
README.md

@@ -4,6 +4,36 @@
 
 For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVoice) and [SenseVoice space](https://www.modelscope.cn/studios/iic/SenseVoice).
 
+## Roadmap
+
+- [x] 2024/07
+
+    - [x] Flow matching training support
+    - [x] WeTextProcessing support when ttsfrd is not available
+    - [x] FastAPI server and client
+
+- [ ] 2024/08
+
+    - [ ] Repetition Aware Sampling (RAS) inference for llm stability
+    - [ ] Streaming inference mode support, including kv cache and sdpa for rtf optimization
+
+- [ ] 2024/09
+
+    - [ ] 50hz llm model which supports 10 languages
+
+- [ ] 2024/10
+
+    - [ ] 50hz llama-based llm model which supports lora finetuning
+
+- [ ] TBD
+
+    - [ ] Support more instruction modes
+    - [ ] Voice conversion
+    - [ ] Music generation
+    - [ ] Training script sample based on Mandarin
+    - [ ] CosyVoice-500M trained with more multi-lingual data
+    - [ ] More...
+
 ## Install
 
 **Clone and install**

+ 2 - 3
cosyvoice/cli/model.py

@@ -159,7 +159,6 @@ class CosyVoiceModel:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
-        p.join()
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
@@ -180,7 +179,7 @@ class CosyVoiceModel:
                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            # p.join()
+            p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
@@ -193,7 +192,7 @@ class CosyVoiceModel:
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            # p.join()
+            p.join()
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
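
With `p.join()` no longer called right after `p.start()`, the LLM thread keeps producing speech tokens while the main thread consumes them hop by hop; both branches now join only once generation has finished. A minimal, self-contained sketch of this deferred-join producer-consumer pattern (stand-in names, not CosyVoice's actual classes):

```python
# Sketch of the deferred-join pattern enabled by this diff. llm_job stands
# in for autoregressive token generation; print() stands in for token2wav.
import threading
import time

tokens, llm_done = [], False

def llm_job():
    global llm_done
    for i in range(10):
        time.sleep(0.01)          # simulate per-token generation latency
        tokens.append(i)
    llm_done = True

p = threading.Thread(target=llm_job)
p.start()                         # joining here would serialize LLM and vocoder

hop, consumed = 3, 0
while True:
    time.sleep(0.01)              # poll, as the real loop does
    if len(tokens) - consumed >= hop:
        chunk = tokens[consumed:consumed + hop]
        consumed += hop
        print('synthesize chunk', chunk)   # stream this audio chunk out
    if llm_done and len(tokens) - consumed < hop:
        break
p.join()                          # safe now: the producer has finished
print('synthesize remaining', tokens[consumed:])
```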

+ 6 - 0
cosyvoice/flow/flow.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import Dict, Optional
 import torch
 import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         # get conditions
         conds = torch.zeros(feat.shape, device=token.device)
+        for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(feat_len)).to(h)
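
The added loop teacher-forces a random mel prefix into the otherwise all-zero condition tensor: with probability 0.5 per utterance, the first frames (up to 30% of its length) are copied from the target features, so the flow model learns to continue from an existing prefix, as chunked streaming synthesis requires. The same logic on dummy tensors (shapes illustrative):

```python
# Standalone sketch of the prefix-conditioning loop added above, on dummy data.
import random
import torch

batch, max_len, n_mels = 2, 100, 80
feat = torch.randn(batch, max_len, n_mels)   # target mel features (B, T, D)
feat_len = torch.tensor([100, 80])           # valid frames per utterance

conds = torch.zeros_like(feat)
for i, j in enumerate(feat_len):
    if random.random() < 0.5:                # ~half the batch gets no prefix
        continue
    index = random.randint(0, int(0.3 * j))  # prefix of up to 30% of the length
    conds[i, :index] = feat[i, :index]       # copy the teacher-forced prefix
conds = conds.transpose(1, 2)                # (B, D, T), as the decoder expects
```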

+ 2 - 2
cosyvoice/flow/flow_matching.py

@@ -82,10 +82,10 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         for step in range(1, len(t_span)):
-            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
+            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.forward_estimator(
+                cfg_dphi_dt = self.estimator(
                     x, mask,
                     torch.zeros_like(mu), t,
                     torch.zeros_like(spks) if spks is not None else None,
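
Reverting `forward_estimator` to the plain `estimator` call restores direct module invocation inside the Euler solver. For context, classifier-free guidance then mixes the conditional and unconditional estimates; a sketch of the standard VoiceBox/Matcha-style combination, which the code just past this excerpt is assumed to follow (illustrative tensors, not the solver's real state):

```python
# Sketch of the usual CFG mixing rule applied to the two estimates above.
import torch

inference_cfg_rate = 0.7                  # matches the yaml configs below
dphi_dt = torch.randn(2, 80, 50)          # conditional vector-field estimate
cfg_dphi_dt = torch.randn(2, 80, 50)      # unconditional (zeroed-condition) estimate
dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
```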

+ 1 - 1
cosyvoice/transformer/encoder.py

@@ -299,7 +299,7 @@ class BaseEncoder(torch.nn.Module):
                rate.
             3. Currently, nn.Sequential is used to stack all the convolution
                layers in subsampling, we need to rewrite it to make it work
-               with cache, which is not prefered.
+               with cache, which is not preferred.
         Args:
             xs (torch.Tensor): (1, max_len, dim)
             chunk_size (int): decoding chunk size

+ 198 - 0
examples/magicdata-read/cosyvoice/conf/cosyvoice.fromscratch.yaml

@@ -0,0 +1,198 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 22050
+text_encoder_input_size: 512
+llm_input_size: 1024
+llm_output_size: 1024
+spk_embed_dim: 192
+
+# model params
+# for all classes/functions included in this repo, we use !<name> or !<new> for initialization, so that users can find every corresponding class/function from this single yaml.
+# for system/third_party class/function, we do not require this.
+llm: !new:cosyvoice.llm.llm.TransformerLM
+    text_encoder_input_size: !ref <text_encoder_input_size>
+    llm_input_size: !ref <llm_input_size>
+    llm_output_size: !ref <llm_output_size>
+    text_token_size: 51866
+    speech_token_size: 4096
+    length_normalized_loss: True
+    lsm_weight: 0
+    spk_embed_dim: !ref <spk_embed_dim>
+    text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+        input_size: !ref <text_encoder_input_size>
+        output_size: 1024
+        attention_heads: 8
+        linear_units: 2048
+        num_blocks: 3
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.0
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        use_cnn_module: False
+        macaron_style: False
+        use_dynamic_chunk: False
+        use_dynamic_left_chunk: False
+        static_chunk_size: 1
+    llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
+        input_size: !ref <llm_input_size>
+        output_size: !ref <llm_output_size>
+        attention_heads: 8
+        linear_units: 2048
+        num_blocks: 7
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.0
+        input_layer: 'linear_legacy'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        static_chunk_size: 1
+
+flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
+    input_size: 512
+    output_size: 80
+    spk_embed_dim: !ref <spk_embed_dim>
+    output_type: 'mel'
+    vocab_size: 4096
+    input_frame_rate: 50
+    only_mask_loss: True
+    encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+        output_size: 512
+        attention_heads: 4
+        linear_units: 1024
+        num_blocks: 3
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.1
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        input_size: 512
+        use_cnn_module: False
+        macaron_style: False
+    length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
+        channels: 80
+        sampling_ratios: [1, 1, 1, 1]
+    decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
+        in_channels: 240
+        n_spks: 1
+        spk_emb_dim: 80
+        cfm_params: !new:omegaconf.DictConfig
+            content:
+                sigma_min: 1e-06
+                solver: 'euler'
+                t_scheduler: 'cosine'
+                training_cfg_rate: 0.2
+                inference_cfg_rate: 0.7
+                reg_loss_type: 'l1'
+        estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
+            in_channels: 320
+            out_channels: 80
+            channels: [256, 256]
+            dropout: 0.0
+            attention_head_dim: 64
+            n_blocks: 4
+            num_mid_blocks: 8
+            num_heads: 8
+            act_fn: 'gelu'
+
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+    in_channels: 80
+    base_channels: 512
+    nb_harmonics: 8
+    sampling_rate: !ref <sample_rate>
+    nsf_alpha: 0.1
+    nsf_sigma: 0.003
+    nsf_voiced_threshold: 10
+    upsample_rates: [8, 8]
+    upsample_kernel_sizes: [16, 16]
+    istft_params:
+        n_fft: 16
+        hop_len: 4
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    source_resblock_kernel_sizes: [7, 11]
+    source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
+    lrelu_slope: 0.1
+    audio_limit: 0.99
+    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+        num_class: 1
+        in_channels: 80
+        cond_channels: 512
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:whisper.tokenizer.get_tokenizer
+    multilingual: True
+    num_languages: 100
+    language: 'en'
+    task: 'transcribe'
+allowed_special: 'all'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+    get_tokenizer: !ref <get_tokenizer>
+    allowed_special: !ref <allowed_special>
+filter: !name:cosyvoice.dataset.processor.filter
+    max_length: 40960
+    min_length: 0
+    token_max_length: 200
+    token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+    resample_rate: !ref <sample_rate>
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: 8000
+    center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+    feat_extractor: !ref <feat_extractor>
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+    normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+    shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+    sort_size: 500  # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+    batch_type: 'dynamic'
+    max_frames_in_batch: 12000
+padding: !name:cosyvoice.dataset.processor.padding
+    use_spk_embedding: False # change to True during sft
+
+# dataset processor pipeline
+data_pipeline: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <compute_fbank>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+
+# train conf
+train_conf:
+    optim: adam
+    optim_conf:
+        lr: 0.002 # change to 0.001 if you want to train flow from scratch
+    scheduler: warmuplr
+    scheduler_conf:
+        warmup_steps: 25000
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 2
+    log_interval: 100
+    save_per_step: -1
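
The `!apply:`, `!new:`, `!name:`, and `!ref` tags are HyperPyYAML syntax: `!apply:`/`!new:` call or instantiate the named function/class at load time, `!name:` binds a callable without calling it, and `!ref <key>` substitutes a value defined earlier in the file. A minimal sketch of materializing this config (assuming `hyperpyyaml` is installed and the repo is on PYTHONPATH; the path is illustrative):

```python
# Sketch: resolving a CosyVoice-style config with HyperPyYAML.
from hyperpyyaml import load_hyperpyyaml

with open('conf/cosyvoice.fromscratch.yaml', 'r') as f:
    configs = load_hyperpyyaml(f)

llm = configs['llm']                 # instantiated TransformerLM (via !new:)
tokenize = configs['tokenize']       # bound callable (via !name:)
train_conf = configs['train_conf']   # plain dict of training hyperparameters
```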

+ 198 - 0
examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml

@@ -0,0 +1,198 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 22050
+text_encoder_input_size: 512
+llm_input_size: 1024
+llm_output_size: 1024
+spk_embed_dim: 192
+
+# model params
+# for all classes/functions included in this repo, we use !<name> or !<new> for initialization, so that users can find every corresponding class/function from this single yaml.
+# for system/third_party class/function, we do not require this.
+llm: !new:cosyvoice.llm.llm.TransformerLM
+    text_encoder_input_size: !ref <text_encoder_input_size>
+    llm_input_size: !ref <llm_input_size>
+    llm_output_size: !ref <llm_output_size>
+    text_token_size: 51866
+    speech_token_size: 4096
+    length_normalized_loss: True
+    lsm_weight: 0
+    spk_embed_dim: !ref <spk_embed_dim>
+    text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+        input_size: !ref <text_encoder_input_size>
+        output_size: 1024
+        attention_heads: 16
+        linear_units: 4096
+        num_blocks: 6
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.0
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        use_cnn_module: False
+        macaron_style: False
+        use_dynamic_chunk: False
+        use_dynamic_left_chunk: False
+        static_chunk_size: 1
+    llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
+        input_size: !ref <llm_input_size>
+        output_size: !ref <llm_output_size>
+        attention_heads: 16
+        linear_units: 4096
+        num_blocks: 14
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.0
+        input_layer: 'linear_legacy'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        static_chunk_size: 1
+
+flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
+    input_size: 512
+    output_size: 80
+    spk_embed_dim: !ref <spk_embed_dim>
+    output_type: 'mel'
+    vocab_size: 4096
+    input_frame_rate: 50
+    only_mask_loss: True
+    encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+        output_size: 512
+        attention_heads: 8
+        linear_units: 2048
+        num_blocks: 6
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.1
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        input_size: 512
+        use_cnn_module: False
+        macaron_style: False
+    length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
+        channels: 80
+        sampling_ratios: [1, 1, 1, 1]
+    decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
+        in_channels: 240
+        n_spks: 1
+        spk_emb_dim: 80
+        cfm_params: !new:omegaconf.DictConfig
+            content:
+                sigma_min: 1e-06
+                solver: 'euler'
+                t_scheduler: 'cosine'
+                training_cfg_rate: 0.2
+                inference_cfg_rate: 0.7
+                reg_loss_type: 'l1'
+        estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
+            in_channels: 320
+            out_channels: 80
+            channels: [256, 256]
+            dropout: 0.0
+            attention_head_dim: 64
+            n_blocks: 4
+            num_mid_blocks: 12
+            num_heads: 8
+            act_fn: 'gelu'
+
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+    in_channels: 80
+    base_channels: 512
+    nb_harmonics: 8
+    sampling_rate: !ref <sample_rate>
+    nsf_alpha: 0.1
+    nsf_sigma: 0.003
+    nsf_voiced_threshold: 10
+    upsample_rates: [8, 8]
+    upsample_kernel_sizes: [16, 16]
+    istft_params:
+        n_fft: 16
+        hop_len: 4
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    source_resblock_kernel_sizes: [7, 11]
+    source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
+    lrelu_slope: 0.1
+    audio_limit: 0.99
+    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+        num_class: 1
+        in_channels: 80
+        cond_channels: 512
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:whisper.tokenizer.get_tokenizer
+    multilingual: True
+    num_languages: 100
+    language: 'en'
+    task: 'transcribe'
+allowed_special: 'all'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+    get_tokenizer: !ref <get_tokenizer>
+    allowed_special: !ref <allowed_special>
+filter: !name:cosyvoice.dataset.processor.filter
+    max_length: 40960
+    min_length: 0
+    token_max_length: 200
+    token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+    resample_rate: !ref <sample_rate>
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: 8000
+    center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+    feat_extractor: !ref <feat_extractor>
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+    normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+    shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+    sort_size: 500  # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+    batch_type: 'dynamic'
+    max_frames_in_batch: 2000
+padding: !name:cosyvoice.dataset.processor.padding
+    use_spk_embedding: False # change to True during sft
+
+# dataset processor pipeline
+data_pipeline: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <compute_fbank>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+
+# train conf
+train_conf:
+    optim: adam
+    optim_conf:
+        lr: 0.001 # change to 1e-5 during sft
+    scheduler: warmuplr # change to constantlr during sft
+    scheduler_conf:
+        warmup_steps: 2500
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 2
+    log_interval: 100
+    save_per_step: -1

+ 42 - 0
examples/magicdata-read/cosyvoice/conf/ds_stage2.json

@@ -0,0 +1,42 @@
+{
+  "train_micro_batch_size_per_gpu": 1,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 100,
+  "gradient_clipping": 5,
+  "fp16": {
+    "enabled": false,
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 16,
+    "loss_scale_window": 256,
+    "hysteresis": 2,
+    "consecutive_hysteresis": false,
+    "min_loss_scale": 1
+  },
+  "bf16": {
+    "enabled": false
+  },
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_optimization": {
+    "stage": 2,
+    "offload_optimizer": {
+      "device": "none",
+      "pin_memory": true
+    },
+    "allgather_partitions": true,
+    "allgather_bucket_size": 5e8,
+    "overlap_comm": false,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 5e8,
+    "contiguous_gradients" : true
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+        "lr": 0.001,
+        "weight_decay": 0.0001,
+        "torch_adam": true,
+        "adam_w_mode": true
+    }
+  }
+}

+ 1 - 0
examples/magicdata-read/cosyvoice/cosyvoice

@@ -0,0 +1 @@
+../../../cosyvoice

+ 97 - 0
examples/magicdata-read/cosyvoice/local/download_and_untar.sh

@@ -0,0 +1,97 @@
+#!/bin/bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
+  echo "          train-clean-100, train-clean-360, train-other-500."
+  exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1
+fi
+
+part_ok=false
+list="dev_set test_set train_set"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1
+fi
+
+if [ -f $data/.$part.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0
+fi
+
+
+# sizes of the archive files in bytes; these are from older releases.
+sizes_old="1035537823 2201936013 52627842921"
+# sizes_new holds the archive file sizes of the final release; some of these
+# are for parts we probably won't download.
+sizes_new="3886385"
+
+if [ -f $data/$part.tar.gz ]; then
+  size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tar.gz
+  else
+    echo "$data/$part.tar.gz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+  if ! which wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1
+  fi
+  full_url=$url/$part.tar.gz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  if ! wget -P $data --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1
+  fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+  echo "$0: error un-tarring archive $data/$part.tar.gz"
+  exit 1
+fi
+
+touch $data/.$part.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+  rm $data/$part.tar.gz
+fi

+ 50 - 0
examples/magicdata-read/cosyvoice/local/prepare_data.py

@@ -0,0 +1,50 @@
+import argparse
+import logging
+import os
+from tqdm import tqdm
+
+
+logger = logging.getLogger()
+
+def main():
+    utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
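+    # TRANS.txt has a header row, then tab-separated <wav>\t<spk>\t<transcript> lines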
+    with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f:
+        lines = f.readlines()[1:]
+        lines = [l.split('\t') for l in lines]
+    for wav, spk, content in tqdm(lines):
+        wav, spk, content = wav.strip(), spk.strip(), content.strip()
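+        # drop MagicData noise tags from the transcript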
+        content = content.replace('[FIL]', '')
+        content = content.replace('[SPK]', '')
+        wav = os.path.join(args.src_dir, spk, wav)
+        if not os.path.exists(wav):
+            continue
+        utt = os.path.basename(wav).replace('.wav', '')
+        utt2wav[utt] = wav
+        utt2text[utt] = content
+        utt2spk[utt] = spk
+        if spk not in spk2utt:
+            spk2utt[spk] = []
+        spk2utt[spk].append(utt)
+
+    with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
+        for k, v in utt2wav.items():
+            f.write('{} {}\n'.format(k, v))
+    with open('{}/text'.format(args.des_dir), 'w') as f:
+        for k, v in utt2text.items():
+            f.write('{} {}\n'.format(k, v))
+    with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
+        for k, v in utt2spk.items():
+            f.write('{} {}\n'.format(k, v))
+    with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
+        for k, v in spk2utt.items():
+            f.write('{} {}\n'.format(k, ' '.join(v)))
+    return
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--src_dir',
+                        type=str)
+    parser.add_argument('--des_dir',
+                        type=str)
+    args = parser.parse_args()
+    main()

+ 3 - 0
examples/magicdata-read/cosyvoice/path.sh

@@ -0,0 +1,3 @@
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH

+ 105 - 0
examples/magicdata-read/cosyvoice/run.sh

@@ -0,0 +1,105 @@
+#!/bin/bash
+# Copyright 2024 Alibaba Inc. All Rights Reserved.
+. ./path.sh || exit 1;
+
+stage=-1
+stop_stage=3
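+# stages: -1 download, 0 data prep, 1 spk embedding, 2 speech token, 3 parquet,
+# 4 inference, 5 llm training; edit stage/stop_stage above to choose what runs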
+
+data_url=www.openslr.org/resources/68
+data_dir=/mnt/hengwu.zty/data/tts/openslr/magicdata-read
+pretrained_model_dir=../../../pretrained_models/CosyVoice-300M
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+  echo "Data Download"
+  for part in dev_set test_set train_set; do
+    local/download_and_untar.sh ${data_dir} ${data_url} ${part}
+  done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+  echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
+  for x in dev test train; do
+    mkdir -p data/$x
+    python local/prepare_data.py --src_dir $data_dir/$x --des_dir data/$x
+  done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
+  for x in dev test train; do
+    tools/extract_embedding.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/campplus.onnx
+  done
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
+  for x in dev test train; do
+    tools/extract_speech_token.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
+  done
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
+  for x in dev test train; do
+    mkdir -p data/$x/parquet
+    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+      --num_processes 10 \
+      --src_dir data/$x \
+      --des_dir data/$x/parquet
+  done
+fi
+
+# inference
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  echo "Run inference. Please make sure utt in tts_text is in prompt_data"
+  for mode in sft zero_shot; do
+    python cosyvoice/bin/inference.py --mode $mode \
+      --gpu 0 \
+      --config conf/cosyvoice.yaml \
+      --prompt_data data/test/parquet/data.list \
+      --prompt_utt2data data/test/parquet/utt2data.list \
+      --tts_text `pwd`/tts_text.json \
+      --llm_model $pretrained_model_dir/llm.pt \
+      --flow_model $pretrained_model_dir/flow.pt \
+      --hifigan_model $pretrained_model_dir/hift.pt \
+      --result_dir `pwd`/exp/cosyvoice/test/$mode
+  done
+fi
+
+# train llm
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+job_id=1986
+dist_backend="nccl"
+num_workers=2
+prefetch=100
+train_engine=torch_ddp
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
+  if [ $train_engine == 'deepspeed' ]; then
+    echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
+  fi
+  cp data/train/parquet/data.list data/train.data.list
+  cp data/dev/parquet/data.list data/dev.data.list
+  for model in llm; do
+    torchrun --nnodes=1 --nproc_per_node=$num_gpus \
+        --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
+      cosyvoice/bin/train.py \
+      --train_engine $train_engine \
+      --config conf/cosyvoice.yaml \
+      --train_data data/train.data.list \
+      --cv_data data/dev.data.list \
+      --model $model \
+      --checkpoint $pretrained_model_dir/$model.pt \
+      --model_dir `pwd`/exp/cosyvoice/$model/$train_engine \
+      --tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \
+      --ddp.dist_backend $dist_backend \
+      --num_workers ${num_workers} \
+      --prefetch ${prefetch} \
+      --pin_memory \
+      --deepspeed_config ./conf/ds_stage2.json \
+      --deepspeed.save_states model+optimizer
+  done
+fi

+ 1 - 0
examples/magicdata-read/cosyvoice/tools

@@ -0,0 +1 @@
+../../../tools

+ 18 - 0
examples/magicdata-read/cosyvoice/tts_text.json

@@ -0,0 +1,18 @@
+{
+  "38_5718_20170915093303": [
+    "我想这出最好歌曲把歌词发到网上请别人帮我作曲急急",
+    "叫他明天早上差五分儿九点去机场"
+  ],
+  "38_5721_20170915091235": [
+    "变温室调到零下两度档",
+    "交谈中请勿轻信汇款信息陌生电话请勿使用外挂软件"
+  ],
+  "38_5733_20170915130323": [
+    "这是老鹰乐队的一首经典歌曲",
+    "我急用这段音乐我自己找到一段但是有现场杂音"
+  ],
+  "38_5836_20170916221414": [
+    "给我播一个陶喆的专辑",
+    "这套餐好贵呀我发这么多短信贵死了"
+  ]
+}