1
0
Selaa lähdekoodia

Merge branch 'inference_streaming' into flow_tensorrt

Xiang Lyu 1 vuosi sitten
vanhempi
commit
ee988420f3

+ 30 - 0
README.md

@@ -4,6 +4,36 @@
 
 For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVoice) and [SenseVoice space](https://www.modelscope.cn/studios/iic/SenseVoice).
 
+## Roadmap
+
+- [x] 2024/07
+
+    - [x] Flow matching training support
+    - [x] WeTextProcessing support when ttsfrd is not avaliable
+    - [x] Fastapi server and client
+
+- [ ] 2024/08
+
+    - [ ] Repetition Aware Sampling(RAS) inference for llm stability
+    - [ ] Streaming inference mode support, including kv cache and sdpa for rtf optimization
+
+- [ ] 2024/09
+
+    - [ ] 50hz llm model which supports 10 language
+
+- [ ] 2024/10
+
+    - [ ] 50hz llama based llm model which supports lora finetune
+
+- [ ] TBD
+
+    - [ ] Support more instruction mode
+    - [ ] Voice conversion
+    - [ ] Music generation
+    - [ ] Training script sample based on Mandarin
+    - [ ] CosyVoice-500M trained with more multi-lingual data
+    - [ ] More...
+
 ## Install
 
 **Clone and install**

+ 0 - 1
cosyvoice/cli/cosyvoice.py

@@ -43,7 +43,6 @@ class CosyVoice:
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                     '{}/llm.llm.fp16.zip'.format(model_dir))
-
         if load_trt:
             self.model.load_trt(model_dir, use_fp16)
             

+ 2 - 3
cosyvoice/cli/model.py

@@ -137,7 +137,6 @@ class CosyVoiceModel:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
-        p.join()
         if stream is True:
             token_hop_len = self.token_min_hop_len
             while True:
@@ -158,7 +157,7 @@ class CosyVoiceModel:
                     token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                 if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            # p.join()
+            p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
@@ -171,7 +170,7 @@ class CosyVoiceModel:
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
-            # p.join()
+            p.join()
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,

+ 6 - 0
cosyvoice/flow/flow.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import Dict, Optional
 import torch
 import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
 
         # get conditions
         conds = torch.zeros(feat.shape, device=token.device)
+        for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
         conds = conds.transpose(1, 2)
 
         mask = (~make_pad_mask(feat_len)).to(h)

+ 2 - 2
cosyvoice/flow/flow_matching.py

@@ -78,10 +78,10 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         for step in range(1, len(t_span)):
-            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
+            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.forward_estimator(
+                cfg_dphi_dt = self.estimator(
                     x, mask,
                     torch.zeros_like(mu), t,
                     torch.zeros_like(spks) if spks is not None else None,

+ 1 - 1
cosyvoice/transformer/encoder.py

@@ -299,7 +299,7 @@ class BaseEncoder(torch.nn.Module):
                rate.
             3. Currently, nn.Sequential is used to stack all the convolution
                layers in subsampling, we need to rewrite it to make it work
-               with cache, which is not prefered.
+               with cache, which is not preferred.
         Args:
             xs (torch.Tensor): (1, max_len, dim)
             chunk_size (int): decoding chunk size

+ 198 - 0
examples/magicdata-read/cosyvoice/conf/cosyvoice.fromscratch.yaml

@@ -0,0 +1,198 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 22050
+text_encoder_input_size: 512
+llm_input_size: 1024
+llm_output_size: 1024
+spk_embed_dim: 192
+
+# model params
+# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
+# for system/third_party class/function, we do not require this.
+llm: !new:cosyvoice.llm.llm.TransformerLM
+    text_encoder_input_size: !ref <text_encoder_input_size>
+    llm_input_size: !ref <llm_input_size>
+    llm_output_size: !ref <llm_output_size>
+    text_token_size: 51866
+    speech_token_size: 4096
+    length_normalized_loss: True
+    lsm_weight: 0
+    spk_embed_dim: !ref <spk_embed_dim>
+    text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+        input_size: !ref <text_encoder_input_size>
+        output_size: 1024
+        attention_heads: 8
+        linear_units: 2048
+        num_blocks: 3
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.0
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        use_cnn_module: False
+        macaron_style: False
+        use_dynamic_chunk: False
+        use_dynamic_left_chunk: False
+        static_chunk_size: 1
+    llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
+        input_size: !ref <llm_input_size>
+        output_size: !ref <llm_output_size>
+        attention_heads: 8
+        linear_units: 2048
+        num_blocks: 7
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.0
+        input_layer: 'linear_legacy'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        static_chunk_size: 1
+
+flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
+    input_size: 512
+    output_size: 80
+    spk_embed_dim: !ref <spk_embed_dim>
+    output_type: 'mel'
+    vocab_size: 4096
+    input_frame_rate: 50
+    only_mask_loss: True
+    encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+        output_size: 512
+        attention_heads: 4
+        linear_units: 1024
+        num_blocks: 3
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.1
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        input_size: 512
+        use_cnn_module: False
+        macaron_style: False
+    length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
+        channels: 80
+        sampling_ratios: [1, 1, 1, 1]
+    decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
+        in_channels: 240
+        n_spks: 1
+        spk_emb_dim: 80
+        cfm_params: !new:omegaconf.DictConfig
+            content:
+                sigma_min: 1e-06
+                solver: 'euler'
+                t_scheduler: 'cosine'
+                training_cfg_rate: 0.2
+                inference_cfg_rate: 0.7
+                reg_loss_type: 'l1'
+        estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
+            in_channels: 320
+            out_channels: 80
+            channels: [256, 256]
+            dropout: 0.0
+            attention_head_dim: 64
+            n_blocks: 4
+            num_mid_blocks: 8
+            num_heads: 8
+            act_fn: 'gelu'
+
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+    in_channels: 80
+    base_channels: 512
+    nb_harmonics: 8
+    sampling_rate: !ref <sample_rate>
+    nsf_alpha: 0.1
+    nsf_sigma: 0.003
+    nsf_voiced_threshold: 10
+    upsample_rates: [8, 8]
+    upsample_kernel_sizes: [16, 16]
+    istft_params:
+        n_fft: 16
+        hop_len: 4
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    source_resblock_kernel_sizes: [7, 11]
+    source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
+    lrelu_slope: 0.1
+    audio_limit: 0.99
+    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+        num_class: 1
+        in_channels: 80
+        cond_channels: 512
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:whisper.tokenizer.get_tokenizer
+    multilingual: True
+    num_languages: 100
+    language: 'en'
+    task: 'transcribe'
+allowed_special: 'all'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+    get_tokenizer: !ref <get_tokenizer>
+    allowed_special: !ref <allowed_special>
+filter: !name:cosyvoice.dataset.processor.filter
+    max_length: 40960
+    min_length: 0
+    token_max_length: 200
+    token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+    resample_rate: !ref <sample_rate>
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: 8000
+    center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+    feat_extractor: !ref <feat_extractor>
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+    normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+    shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+    sort_size: 500  # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+    batch_type: 'dynamic'
+    max_frames_in_batch: 12000
+padding: !name:cosyvoice.dataset.processor.padding
+    use_spk_embedding: False # change to True during sft
+
+# dataset processor pipeline
+data_pipeline: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <compute_fbank>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+
+# train conf
+train_conf:
+    optim: adam
+    optim_conf:
+        lr: 0.002 # change to 0.001 if you want to train flow from scratch
+    scheduler: warmuplr
+    scheduler_conf:
+        warmup_steps: 25000
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 2
+    log_interval: 100
+    save_per_step: -1

+ 198 - 0
examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml

@@ -0,0 +1,198 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 22050
+text_encoder_input_size: 512
+llm_input_size: 1024
+llm_output_size: 1024
+spk_embed_dim: 192
+
+# model params
+# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
+# for system/third_party class/function, we do not require this.
+llm: !new:cosyvoice.llm.llm.TransformerLM
+    text_encoder_input_size: !ref <text_encoder_input_size>
+    llm_input_size: !ref <llm_input_size>
+    llm_output_size: !ref <llm_output_size>
+    text_token_size: 51866
+    speech_token_size: 4096
+    length_normalized_loss: True
+    lsm_weight: 0
+    spk_embed_dim: !ref <spk_embed_dim>
+    text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+        input_size: !ref <text_encoder_input_size>
+        output_size: 1024
+        attention_heads: 16
+        linear_units: 4096
+        num_blocks: 6
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.0
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        use_cnn_module: False
+        macaron_style: False
+        use_dynamic_chunk: False
+        use_dynamic_left_chunk: False
+        static_chunk_size: 1
+    llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
+        input_size: !ref <llm_input_size>
+        output_size: !ref <llm_output_size>
+        attention_heads: 16
+        linear_units: 4096
+        num_blocks: 14
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.0
+        input_layer: 'linear_legacy'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        static_chunk_size: 1
+
+flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
+    input_size: 512
+    output_size: 80
+    spk_embed_dim: !ref <spk_embed_dim>
+    output_type: 'mel'
+    vocab_size: 4096
+    input_frame_rate: 50
+    only_mask_loss: True
+    encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
+        output_size: 512
+        attention_heads: 8
+        linear_units: 2048
+        num_blocks: 6
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.1
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        input_size: 512
+        use_cnn_module: False
+        macaron_style: False
+    length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
+        channels: 80
+        sampling_ratios: [1, 1, 1, 1]
+    decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
+        in_channels: 240
+        n_spks: 1
+        spk_emb_dim: 80
+        cfm_params: !new:omegaconf.DictConfig
+            content:
+                sigma_min: 1e-06
+                solver: 'euler'
+                t_scheduler: 'cosine'
+                training_cfg_rate: 0.2
+                inference_cfg_rate: 0.7
+                reg_loss_type: 'l1'
+        estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
+            in_channels: 320
+            out_channels: 80
+            channels: [256, 256]
+            dropout: 0.0
+            attention_head_dim: 64
+            n_blocks: 4
+            num_mid_blocks: 12
+            num_heads: 8
+            act_fn: 'gelu'
+
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+    in_channels: 80
+    base_channels: 512
+    nb_harmonics: 8
+    sampling_rate: !ref <sample_rate>
+    nsf_alpha: 0.1
+    nsf_sigma: 0.003
+    nsf_voiced_threshold: 10
+    upsample_rates: [8, 8]
+    upsample_kernel_sizes: [16, 16]
+    istft_params:
+        n_fft: 16
+        hop_len: 4
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    source_resblock_kernel_sizes: [7, 11]
+    source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
+    lrelu_slope: 0.1
+    audio_limit: 0.99
+    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+        num_class: 1
+        in_channels: 80
+        cond_channels: 512
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:whisper.tokenizer.get_tokenizer
+    multilingual: True
+    num_languages: 100
+    language: 'en'
+    task: 'transcribe'
+allowed_special: 'all'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+    get_tokenizer: !ref <get_tokenizer>
+    allowed_special: !ref <allowed_special>
+filter: !name:cosyvoice.dataset.processor.filter
+    max_length: 40960
+    min_length: 0
+    token_max_length: 200
+    token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+    resample_rate: !ref <sample_rate>
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1024
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 256
+    win_size: 1024
+    fmin: 0
+    fmax: 8000
+    center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+    feat_extractor: !ref <feat_extractor>
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+    normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+    shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+    sort_size: 500  # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+    batch_type: 'dynamic'
+    max_frames_in_batch: 2000
+padding: !name:cosyvoice.dataset.processor.padding
+    use_spk_embedding: False # change to True during sft
+
+# dataset processor pipeline
+data_pipeline: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <compute_fbank>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+
+# train conf
+train_conf:
+    optim: adam
+    optim_conf:
+        lr: 0.001 # change to 1e-5 during sft
+    scheduler: warmuplr # change to constantlr during sft
+    scheduler_conf:
+        warmup_steps: 2500
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 2
+    log_interval: 100
+    save_per_step: -1

+ 42 - 0
examples/magicdata-read/cosyvoice/conf/ds_stage2.json

@@ -0,0 +1,42 @@
+{
+  "train_micro_batch_size_per_gpu": 1,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 100,
+  "gradient_clipping": 5,
+  "fp16": {
+    "enabled": false,
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 16,
+    "loss_scale_window": 256,
+    "hysteresis": 2,
+    "consecutive_hysteresis": false,
+    "min_loss_scale": 1
+  },
+  "bf16": {
+    "enabled": false
+  },
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_optimization": {
+    "stage": 2,
+    "offload_optimizer": {
+      "device": "none",
+      "pin_memory": true
+    },
+    "allgather_partitions": true,
+    "allgather_bucket_size": 5e8,
+    "overlap_comm": false,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 5e8,
+    "contiguous_gradients" : true
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+        "lr": 0.001,
+        "weight_decay": 0.0001,
+        "torch_adam": true,
+        "adam_w_mode": true
+    }
+  }
+}

+ 1 - 0
examples/magicdata-read/cosyvoice/cosyvoice

@@ -0,0 +1 @@
+../../../cosyvoice

+ 97 - 0
examples/magicdata-read/cosyvoice/local/download_and_untar.sh

@@ -0,0 +1,97 @@
+#!/bin/bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
+  echo "          train-clean-100, train-clean-360, train-other-500."
+  exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1
+fi
+
+part_ok=false
+list="dev_set test_set train_set"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1
+fi
+
+if [ -f $data/.$part.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0
+fi
+
+
+# sizes of the archive files in bytes.  This is some older versions.
+sizes_old="1035537823 2201936013 52627842921"
+# sizes_new is the archive file sizes of the final release.  Some of these sizes are of
+# things we probably won't download.
+sizes_new="3886385"
+
+if [ -f $data/$part.tar.gz ]; then
+  size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tar.gz
+  else
+    echo "$data/$part.tar.gz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+  if ! which wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1
+  fi
+  full_url=$url/$part.tar.gz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  if ! wget -P $data --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1
+  fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+  echo "$0: error un-tarring archive $data/$part.tar.gz"
+  exit 1
+fi
+
+touch $data/.$part.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+  rm $data/$part.tar.gz
+fi

+ 50 - 0
examples/magicdata-read/cosyvoice/local/prepare_data.py

@@ -0,0 +1,50 @@
+import argparse
+import logging
+import os
+from tqdm import tqdm
+
+
+logger = logging.getLogger()
+
+def main():
+    utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
+    with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f:
+        lines = f.readlines()[1:]
+        lines = [l.split('\t') for l in lines]
+    for wav, spk, content in tqdm(lines):
+        wav, spk, content = wav.strip(), spk.strip(), content.strip()
+        content = content.replace('[FIL]', '')
+        content = content.replace('[SPK]', '')
+        wav = os.path.join(args.src_dir, spk, wav)
+        if not os.path.exists(wav):
+            continue
+        utt = os.path.basename(wav).replace('.wav', '')
+        utt2wav[utt] = wav
+        utt2text[utt] = content
+        utt2spk[utt] = spk
+        if spk not in spk2utt:
+            spk2utt[spk] = []
+        spk2utt[spk].append(utt)
+
+    with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
+        for k, v in utt2wav.items():
+            f.write('{} {}\n'.format(k, v))
+    with open('{}/text'.format(args.des_dir), 'w') as f:
+        for k, v in utt2text.items():
+            f.write('{} {}\n'.format(k, v))
+    with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
+        for k, v in utt2spk.items():
+            f.write('{} {}\n'.format(k, v))
+    with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
+        for k, v in spk2utt.items():
+            f.write('{} {}\n'.format(k, ' '.join(v)))
+    return
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--src_dir',
+                        type=str)
+    parser.add_argument('--des_dir',
+                        type=str)
+    args = parser.parse_args()
+    main()

+ 3 - 0
examples/magicdata-read/cosyvoice/path.sh

@@ -0,0 +1,3 @@
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH

+ 105 - 0
examples/magicdata-read/cosyvoice/run.sh

@@ -0,0 +1,105 @@
+#!/bin/bash
+# Copyright 2024 Alibaba Inc. All Rights Reserved.
+. ./path.sh || exit 1;
+
+stage=-1
+stop_stage=3
+
+data_url=www.openslr.org/resources/68
+data_dir=/mnt/hengwu.zty/data/tts/openslr/magicdata-read
+pretrained_model_dir=../../../pretrained_models/CosyVoice-300M
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+  echo "Data Download"
+  for part in dev_set test_set train_set; do
+    local/download_and_untar.sh ${data_dir} ${data_url} ${part}
+  done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+  echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
+  for x in dev test train; do
+    mkdir -p data/$x
+    python local/prepare_data.py --src_dir $data_dir/$x --des_dir data/$x
+  done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
+  for x in dev test train; do
+    tools/extract_embedding.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/campplus.onnx
+  done
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
+  for x in dev test train; do
+    tools/extract_speech_token.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
+  done
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
+  for x in dev test train; do
+    mkdir -p data/$x/parquet
+    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+      --num_processes 10 \
+      --src_dir data/$x \
+      --des_dir data/$x/parquet
+  done
+fi
+
+# inference
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  echo "Run inference. Please make sure utt in tts_text is in prompt_data"
+  for mode in sft zero_shot; do
+    python cosyvoice/bin/inference.py --mode $mode \
+      --gpu 0 \
+      --config conf/cosyvoice.yaml \
+      --prompt_data data/test/parquet/data.list \
+      --prompt_utt2data data/test/parquet/utt2data.list \
+      --tts_text `pwd`/tts_text.json \
+      --llm_model $pretrained_model_dir/llm.pt \
+      --flow_model $pretrained_model_dir/flow.pt \
+      --hifigan_model $pretrained_model_dir/hift.pt \
+      --result_dir `pwd`/exp/cosyvoice/test/$mode
+  done
+fi
+
+# train llm
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+job_id=1986
+dist_backend="nccl"
+num_workers=2
+prefetch=100
+train_engine=torch_ddp
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
+  if [ $train_engine == 'deepspeed' ]; then
+    echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
+  fi
+  cp data/train/parquet/data.list data/train.data.list
+  cp data/dev/parquet/data.list data/dev.data.list
+  for model in llm; do
+    torchrun --nnodes=1 --nproc_per_node=$num_gpus \
+        --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
+      cosyvoice/bin/train.py \
+      --train_engine $train_engine \
+      --config conf/cosyvoice.yaml \
+      --train_data data/train.data.list \
+      --cv_data data/dev.data.list \
+      --model $model \
+      --checkpoint $pretrained_model_dir/$model.pt \
+      --model_dir `pwd`/exp/cosyvoice/$model/$train_engine \
+      --tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \
+      --ddp.dist_backend $dist_backend \
+      --num_workers ${num_workers} \
+      --prefetch ${prefetch} \
+      --pin_memory \
+      --deepspeed_config ./conf/ds_stage2.json \
+      --deepspeed.save_states model+optimizer
+  done
+fi

+ 1 - 0
examples/magicdata-read/cosyvoice/tools

@@ -0,0 +1 @@
+../../../tools

+ 18 - 0
examples/magicdata-read/cosyvoice/tts_text.json

@@ -0,0 +1,18 @@
+{
+  "38_5718_20170915093303": [
+    "我想这出最好歌曲把歌词发到网上请别人帮我作曲急急",
+    "叫他明天早上差五分儿九点去机场"
+  ],
+  "38_5721_20170915091235": [
+    "变温室调到零下两度档",
+    "交谈中请勿轻信汇款信息陌生电话请勿使用外挂软件"
+  ],
+  "38_5733_20170915130323": [
+    "这是老鹰乐队的一首经典歌曲",
+    "我急用这段音乐我自己找到一段但是有现场杂音"
+  ],
+  "38_5836_20170916221414": [
+    "给我播一个陶喆的专辑",
+    "这套餐好贵呀我发这么多短信贵死了"
+  ]
+}