
Merge pull request #1670 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
Xiang Lyu 4 months ago
parent
commit
0c50894d49

+ 1 - 1
.github/workflows/lint.yml

@@ -52,5 +52,5 @@ jobs:
          set -eux
          pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
          flake8 --version
-          flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
+          flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F722,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
          if [ $? != 0 ]; then exit 1; fi

+ 45 - 107
README.md

@@ -2,7 +2,7 @@

 ## 👉🏻 CosyVoice 👈🏻

-**CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/abs/2505.17589); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
+**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/abs/2505.17589); [Modelscope](https://www.modelscope.cn/studios/FunAudioLLM/Fun-CosyVoice3-0.5B); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)

 **CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)

@@ -10,45 +10,43 @@

 ## Highlight🔥

-**CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities.
-### Multilingual
-- **Supported Language**: Chinese, English, Japanese, Korean, Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
-- **Crosslingual & Mixlingual**:Support zero-shot voice cloning for cross-lingual and code-switching scenarios.
-### Ultra-Low Latency
-- **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
-- **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
-### High Accuracy
-- **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
-- **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
-### Strong Stability
-- **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
-- **Cross-language Synthesis**: Marked improvements compared to version 1.0.
-### Natural Experience
-- **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
-- **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
+**Fun-CosyVoice 3.0** is an advanced text-to-speech (TTS) system based on large language models (LLMs), surpassing its predecessor (CosyVoice 2.0) in content consistency, speaker similarity, and prosody naturalness. It is designed for zero-shot multilingual speech synthesis in the wild.
+### Key Features
+- **Language Coverage**: Covers 9 common languages (Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian) and 18+ Chinese dialects/accents (Guangdong, Minnan, Sichuan, Dongbei, Shan3xi, Shan1xi, Shanghai, Tianjin, Shan1dong, Ningxia, Gansu, etc.), and supports both multilingual and cross-lingual zero-shot voice cloning.
+- **Content Consistency & Naturalness**: Achieves state-of-the-art performance in content consistency, speaker similarity, and prosody naturalness.
+- **Pronunciation Inpainting**: Supports pronunciation inpainting with Chinese Pinyin and English CMU phonemes, providing more controllability and making it suitable for production use.
+- **Text Normalization**: Supports reading numbers, special symbols, and various text formats without a traditional frontend module.
+- **Bi-Streaming**: Supports both streaming text input and streaming audio output, achieving latency as low as 150ms while maintaining high-quality audio output.
+- **Instruct Support**: Supports various instructions controlling language, dialect, emotion, speed, volume, etc. (see the sketch after this list).
+
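
To make the instruct capability concrete, here is a minimal, hedged sketch using `inference_instruct2` from this PR's refactored API; the model directory and prompt wav path are assumptions, and `example.py` remains the authoritative reference:

``` python
import sys
sys.path.append('third_party/Matcha-TTS')
import torchaudio
from cosyvoice.cli.cosyvoice import AutoModel

# AutoModel picks CosyVoice/CosyVoice2/CosyVoice3 based on the yaml found in model_dir
cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')
# instruct synthesis: ask for Sichuan dialect; the prompt is now a wav file path,
# loaded and resampled inside the frontend
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐。', '用四川话说这句话', './asset/zero_shot_prompt.wav', stream=False)):
    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```
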

 ## Roadmap

+- [x] 2025/12
+
+    - [x] release Fun-CosyVoice3-0.5B-2512 base model, RL model, and their training/inference scripts
+    - [x] release Fun-CosyVoice3-0.5B modelscope gradio space
+
 - [x] 2025/08

     - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support

 - [x] 2025/07

-    - [x] release cosyvoice 3.0 eval set
+    - [x] release Fun-CosyVoice 3.0 eval set

 - [x] 2025/05

-    - [x] add cosyvoice 2.0 vllm support
+    - [x] add CosyVoice2-0.5B vllm support

 - [x] 2024/12

-    - [x] 25hz cosyvoice 2.0 released
+    - [x] 25hz CosyVoice2-0.5B released

 - [x] 2024/09

-    - [x] 25hz cosyvoice base model
-    - [x] 25hz cosyvoice voice conversion model
+    - [x] 25hz CosyVoice-300M base model
+    - [x] 25hz CosyVoice-300M voice conversion function

 - [x] 2024/08

@@ -61,6 +59,25 @@
     - [x] WeTextProcessing support when ttsfrd is not available
     - [x] Fastapi server and client

+## Evaluation
+| Model | CER (%) ↓ (test-zh) | WER (%) ↓ (test-en) | CER (%) ↓ (test-hard) |
+|-----|------------------|------------------|------------------|
+| Human | 1.26 | 2.14 | - |
+| F5-TTS | 1.53 | 2.00 | 8.67 |
+| SparkTTS | 1.20 | 1.98 | - |
+| Seed-TTS | 1.12 | 2.25 | 7.59 |
+| CosyVoice2 | 1.45 | 2.57 | 6.83 |
+| FireRedTTS-2 | 1.14 | 1.95 | - |
+| IndexTTS2 | 1.01 | 1.52 | 7.12 |
+| VibeVoice | 1.16 | 3.04 | - |
+| HiggsAudio | 1.79 | 2.44 | - |
+| MiniMax-Speech | 0.83 | 1.65 | - |
+| VoxPCM | 0.93 | 1.85 | 8.87 |
+| GLM-TTS | 1.03 | - | - |
+| GLM-TTS_RL | 0.89 | - | - |
+| Fun-CosyVoice3-0.5B-2512 | 1.21 | 2.24 | 6.71 |
+| Fun-CosyVoice3-0.5B-2512_RL | 0.81 | 1.68 | 5.44 |
+

 ## Install

@@ -91,11 +108,12 @@

 ### Model download

-We strongly recommend that you download our pretrained `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
+We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B`, `CosyVoice2-0.5B`, `CosyVoice-300M`, `CosyVoice-300M-SFT`, and `CosyVoice-300M-Instruct` models and the `CosyVoice-ttsfrd` resource.

 ``` python
 # SDK model download
 from modelscope import snapshot_download
+snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
 snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
 snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
 snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
@@ -103,16 +121,6 @@ snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/Co
 snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
 ```

-``` sh
-# Download models via git; make sure git lfs is installed
-mkdir -p pretrained_models
-git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
-git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
-git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
-git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
-git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
-```
-
 Optionally, you can unzip `ttsfrd` resource and install `ttsfrd` package for better text normalization performance.

 Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use wetext by default.
@@ -126,50 +134,10 @@ pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl

 ### Basic Usage

-We strongly recommend using `CosyVoice2-0.5B` for better performance.
-Follow the code below for detailed usage of each model.
-
-``` python
-import sys
-sys.path.append('third_party/Matcha-TTS')
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
-from cosyvoice.utils.file_utils import load_wav
-import torchaudio
-```
-
-#### CosyVoice2 Usage
-```python
-cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, load_vllm=False, fp16=False)
-
-# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
-# zero_shot usage
-prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
-for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
-    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-
-# save zero_shot spk for future usage
-assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', prompt_speech_16k, 'my_zero_shot_spk') is True
-for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk', stream=False)):
-    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-cosyvoice.save_spkinfo()
-
-# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
-for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
-    torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-
-# instruct usage
-for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
-    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-
-# bistream usage, you can use generator as input, this is useful when using text llm model as input
-# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
-def text_generator():
-    yield '收到好友从远方寄来的生日礼物,'
-    yield '那份意外的惊喜与深深的祝福'
-    yield '让我心中充满了甜蜜的快乐,'
-    yield '笑容如花儿般绽放。'
-for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
-    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+We strongly recommend using `Fun-CosyVoice3-0.5B` for better performance.
+Follow the code in `example.py` for detailed usage of each model.
+```sh
+python example.py
 ```

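For orientation, a minimal zero-shot sketch against the refactored API, where the prompt is now passed as a wav file path and loaded inside the frontend; paths are assumptions, and `example.py` is the authoritative reference:

``` python
import sys
sys.path.append('third_party/Matcha-TTS')
import torchaudio
from cosyvoice.cli.cosyvoice import AutoModel

cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')

# zero-shot voice cloning; prompt_wav is a file path, resampled internally
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

# text-in streaming still accepts a generator as tts_text
def text_generator():
    yield '收到好友从远方寄来的生日礼物,'
    yield '笑容如花儿般绽放。'
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=True)):
    torchaudio.save('zero_shot_stream_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```
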
 #### CosyVoice2 vllm Usage
@@ -184,36 +152,6 @@ pip install vllm==v0.9.0 transformers==4.51.3 -i https://mirrors.aliyun.com/pypi
 python vllm_example.py
 ```

-#### CosyVoice Usage
-```python
-cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
-# sft usage
-print(cosyvoice.list_available_spks())
-# change stream=True for chunk stream inference
-for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
-    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-
-cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
-# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
-prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
-for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
-    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-# cross_lingual usage
-prompt_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
-for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
-    torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-# vc usage
-prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
-source_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
-for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
-    torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-
-cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
-# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
-for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
-    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
-```
-
 #### Start web demo

 You can use our web demo page to get familiar with CosyVoice quickly.

BIN
asset/dingding.png


BIN
asset/zero_shot_prompt.wav


+ 14 - 16
cosyvoice/bin/export_jit.py

@@ -23,8 +23,10 @@ import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.cli.cosyvoice import AutoModel
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
 from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type


 def get_args():
@@ -57,15 +59,17 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)

-    try:
-        model = CosyVoice(args.model_dir)
-    except Exception:
-        try:
-            model = CosyVoice2(args.model_dir)
-        except Exception:
-            raise TypeError('no valid model_type!')
+    model = AutoModel(model_dir=args.model_dir)
 
 
-    if not isinstance(model, CosyVoice2):
+    if get_model_type(model.model) == CosyVoiceModel:
+        # 1. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder)
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half())
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
+    elif get_model_type(model.model) == CosyVoice2Model:
         # 1. export llm text_encoder
         llm_text_encoder = model.model.llm.text_encoder
         script = get_optimized_script(llm_text_encoder)
@@ -90,13 +94,7 @@ def main():
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
         logging.info('successfully export flow_encoder')
     else:
-        # 3. export flow encoder
-        flow_encoder = model.model.flow.encoder
-        script = get_optimized_script(flow_encoder)
-        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
-        script = get_optimized_script(flow_encoder.half())
-        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
-        logging.info('successfully export flow_encoder')
+        raise ValueError('unsupported model type')


 if __name__ == '__main__':

+ 2 - 8
cosyvoice/bin/export_onnx.py

@@ -27,7 +27,7 @@ from tqdm import tqdm
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.cli.cosyvoice import AutoModel
 from cosyvoice.utils.file_utils import logging


@@ -58,13 +58,7 @@ def main():
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')

-    try:
-        model = CosyVoice(args.model_dir)
-    except Exception:
-        try:
-            model = CosyVoice2(args.model_dir)
-        except Exception:
-            raise TypeError('no valid model_type!')
+    model = AutoModel(model_dir=args.model_dir)

     # 1. export flow decoder estimator
     estimator = model.model.flow.decoder.estimator

+ 0 - 126
cosyvoice/bin/inference_deprecated.py

@@ -1,126 +0,0 @@
-# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import argparse
-import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
-import os
-import torch
-from torch.utils.data import DataLoader
-import torchaudio
-from hyperpyyaml import load_hyperpyyaml
-from tqdm import tqdm
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
-from cosyvoice.dataset.dataset import Dataset
-
-
-def get_args():
-    parser = argparse.ArgumentParser(description='inference with your model')
-    parser.add_argument('--config', required=True, help='config file')
-    parser.add_argument('--prompt_data', required=True, help='prompt data file')
-    parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
-    parser.add_argument('--tts_text', required=True, help='tts input file')
-    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
-    parser.add_argument('--llm_model', required=True, help='llm model file')
-    parser.add_argument('--flow_model', required=True, help='flow model file')
-    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
-    parser.add_argument('--gpu',
-                        type=int,
-                        default=-1,
-                        help='gpu id for this rank, -1 for cpu')
-    parser.add_argument('--mode',
-                        default='sft',
-                        choices=['sft', 'zero_shot'],
-                        help='inference mode')
-    parser.add_argument('--result_dir', required=True, help='asr result file')
-    args = parser.parse_args()
-    print(args)
-    return args
-
-
-def main():
-    args = get_args()
-    logging.basicConfig(level=logging.DEBUG,
-                        format='%(asctime)s %(levelname)s %(message)s')
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
-
-    # Init cosyvoice models from configs
-    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
-    device = torch.device('cuda' if use_cuda else 'cpu')
-    try:
-        with open(args.config, 'r') as f:
-            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
-        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
-    except Exception:
-        try:
-            with open(args.config, 'r') as f:
-                configs = load_hyperpyyaml(f)
-            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
-        except Exception:
-            raise TypeError('no valid model_type!')
-
-    model.load(args.llm_model, args.flow_model, args.hifigan_model)
-
-    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
-                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
-    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
-
-    sample_rate = configs['sample_rate']
-    del configs
-    os.makedirs(args.result_dir, exist_ok=True)
-    fn = os.path.join(args.result_dir, 'wav.scp')
-    f = open(fn, 'w')
-    with torch.no_grad():
-        for _, batch in tqdm(enumerate(test_data_loader)):
-            utts = batch["utts"]
-            assert len(utts) == 1, "inference mode only support batchsize 1"
-            text_token = batch["text_token"].to(device)
-            text_token_len = batch["text_token_len"].to(device)
-            tts_index = batch["tts_index"]
-            tts_text_token = batch["tts_text_token"].to(device)
-            tts_text_token_len = batch["tts_text_token_len"].to(device)
-            speech_token = batch["speech_token"].to(device)
-            speech_token_len = batch["speech_token_len"].to(device)
-            speech_feat = batch["speech_feat"].to(device)
-            speech_feat_len = batch["speech_feat_len"].to(device)
-            utt_embedding = batch["utt_embedding"].to(device)
-            spk_embedding = batch["spk_embedding"].to(device)
-            if args.mode == 'sft':
-                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
-            else:
-                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
-                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
-                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
-            tts_speeches = []
-            for model_output in model.tts(**model_input):
-                tts_speeches.append(model_output['tts_speech'])
-            tts_speeches = torch.concat(tts_speeches, dim=1)
-            tts_key = '{}_{}'.format(utts[0], tts_index[0])
-            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-            torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
-            f.write('{} {}\n'.format(tts_key, tts_fn))
-            f.flush()
-    f.close()
-    logging.info('Result wav.scp saved in {}'.format(fn))
-
-
-if __name__ == '__main__':
-    logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
-    main()

+ 67 - 23
cosyvoice/cli/cosyvoice.py

@@ -19,7 +19,7 @@ from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 import torch
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.class_utils import get_model_type

@@ -27,7 +27,6 @@ from cosyvoice.utils.class_utils import get_model_type
 class CosyVoice:

     def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
-        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
         if not os.path.exists(model_dir):
@@ -37,7 +36,7 @@ class CosyVoice:
             raise ValueError('{} not found!'.format(hyper_yaml_path))
         with open(hyper_yaml_path, 'r') as f:
             configs = load_hyperpyyaml(f)
-        assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
+        assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
         self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                           configs['feat_extractor'],
                                           '{}/campplus.onnx'.format(model_dir),
@@ -67,9 +66,9 @@ class CosyVoice:
         spks = list(self.frontend.spk2info.keys())
         return spks

-    def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
+    def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
         assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
-        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
+        model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
         del model_input['text']
         del model_input['text_len']
         self.frontend.spk2info[zero_shot_spk_id] = model_input
@@ -89,12 +88,12 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()

-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
-            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -103,9 +102,9 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()

-    def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+    def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -116,8 +115,6 @@ class CosyVoice:

     def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
         assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
-        if self.instruct is False:
-            raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
@@ -129,8 +126,8 @@ class CosyVoice:
                 yield model_output
                 start_time = time.time()

-    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
-        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
+    def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
+        model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
         start_time = time.time()
         for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
             speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
@@ -142,7 +139,6 @@ class CosyVoice:
 class CosyVoice2(CosyVoice):

     def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
-        self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
         if not os.path.exists(model_dir):
@@ -160,9 +156,9 @@ class CosyVoice2(CosyVoice):
                                           '{}/spk2info.pt'.format(model_dir),
                                           configs['allowed_special'])
         self.sample_rate = configs['sample_rate']
-        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
-            load_jit, load_trt, fp16 = False, False, False
-            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
+            load_jit, load_trt, load_vllm, fp16 = False, False, False, False
+            logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
         self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
@@ -178,13 +174,9 @@ class CosyVoice2(CosyVoice):
                                 self.fp16)
         del configs

-    def inference_instruct(self, *args, **kwargs):
-        raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
-
-    def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
-        assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
+    def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
-            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+            model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@@ -192,3 +184,55 @@ class CosyVoice2(CosyVoice):
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
+
+
+class CosyVoice3(CosyVoice2):
+
+    def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
+        self.model_dir = model_dir
+        self.fp16 = fp16
+        if not os.path.exists(model_dir):
+            model_dir = snapshot_download(model_dir)
+        hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
+        if not os.path.exists(hyper_yaml_path):
+            raise ValueError('{} not found!'.format(hyper_yaml_path))
+        with open(hyper_yaml_path, 'r') as f:
+            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+        assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
+        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+                                          configs['feat_extractor'],
+                                          '{}/campplus.onnx'.format(model_dir),
+                                          '{}/speech_tokenizer_v3.onnx'.format(model_dir),
+                                          '{}/spk2info.pt'.format(model_dir),
+                                          configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
+            load_trt, fp16 = False, False
+            logging.warning('no cuda device, set load_trt/fp16 to False')
+        self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model.load('{}/llm.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir),
+                        '{}/hift.pt'.format(model_dir))
+        if load_vllm:
+            self.model.load_vllm('{}/vllm'.format(model_dir))
+        if load_trt:
+            if self.fp16 is True:
+                logging.warning('DiT tensorRT fp16 engine has some performance issues, use with caution!')
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
+                                self.fp16)
+        del configs
+
+
+def AutoModel(**kwargs):
+    if not os.path.exists(kwargs['model_dir']):
+        kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
+    if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
+        return CosyVoice(**kwargs)
+    elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
+        return CosyVoice2(**kwargs)
+    elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
+        return CosyVoice3(**kwargs)
+    else:
+        raise TypeError('No valid model type found!')

+ 23 - 19
cosyvoice/cli/frontend.py

@@ -32,7 +32,7 @@ except ImportError:
     from wetext import Normalizer as ZhNormalizer
     from wetext import Normalizer as EnNormalizer
     use_ttsfrd = False
-from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.file_utils import logging, load_wav
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


@@ -89,7 +89,8 @@ class CosyVoiceFrontEnd:
             for i in range(text_token.shape[1]):
                 yield text_token[:, i: i + 1]

-    def _extract_speech_token(self, speech):
+    def _extract_speech_token(self, prompt_wav):
+        speech = load_wav(prompt_wav, 16000)
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
         speech_token = self.speech_tokenizer_session.run(None,
@@ -101,7 +102,8 @@ class CosyVoiceFrontEnd:
         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
         return speech_token, speech_token_len

-    def _extract_spk_embedding(self, speech):
+    def _extract_spk_embedding(self, prompt_wav):
+        speech = load_wav(prompt_wav, 16000)
         feat = kaldi.fbank(speech,
                            num_mel_bins=80,
                            dither=0,
@@ -112,7 +114,8 @@ class CosyVoiceFrontEnd:
         embedding = torch.tensor([embedding]).to(self.device)
         return embedding

-    def _extract_speech_feat(self, speech):
+    def _extract_speech_feat(self, prompt_wav):
+        speech = load_wav(prompt_wav, 24000)
         speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
         speech_feat = speech_feat.unsqueeze(dim=0)
         speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
@@ -122,6 +125,9 @@ class CosyVoiceFrontEnd:
         if isinstance(text, Generator):
             logging.info('get tts_text generator, will skip text_normalize!')
             return [text]
+        # NOTE skip text_frontend when ssml symbol in text
+        if '<|' in text and '|>' in text:
+            text_frontend = False
         if text_frontend is False or text == '':
             return [text] if split is True else text
         text = text.strip()
@@ -154,19 +160,18 @@ class CosyVoiceFrontEnd:
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input

-    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
         if zero_shot_spk_id == '':
             prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
-            prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-            speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+            speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
+            speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
             if resample_rate == 24000:
                 # cosyvoice2, force speech_feat % speech_token = 2
                 token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
                 speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
                 speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
-            embedding = self._extract_spk_embedding(prompt_speech_16k)
+            embedding = self._extract_spk_embedding(prompt_wav)
             model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                            'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                            'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
@@ -178,8 +183,8 @@ class CosyVoiceFrontEnd:
         model_input['text_len'] = tts_text_token_len
         return model_input

-    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
-        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+    def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
         # in cross lingual mode, we remove prompt in llm
         del model_input['prompt_text']
         del model_input['prompt_text_len']
@@ -191,22 +196,21 @@ class CosyVoiceFrontEnd:
         model_input = self.frontend_sft(tts_text, spk_id)
         # in instruct mode, we remove spk_embedding in llm due to information leakage
         del model_input['llm_embedding']
-        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
+        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
         model_input['prompt_text'] = instruct_text_token
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input

-    def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
-        model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+    def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
+        model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
         del model_input['llm_prompt_speech_token']
         del model_input['llm_prompt_speech_token_len']
         return model_input

-    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
-        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
-        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
-        embedding = self._extract_spk_embedding(prompt_speech_16k)
+    def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
+        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
+        embedding = self._extract_spk_embedding(prompt_wav)
         source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
         model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                        'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,

+ 52 - 8
cosyvoice/cli/model.py

@@ -38,9 +38,6 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
@@ -129,7 +126,7 @@ class CosyVoiceModel:

     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
-            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
+            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
                                                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                       prompt_token=prompt_token.to(self.device),
                                                                       prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
@@ -249,9 +246,6 @@ class CosyVoice2Model(CosyVoiceModel):
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
         # hift cache
@@ -284,7 +278,7 @@ class CosyVoice2Model(CosyVoiceModel):

     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16):
-            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+            tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
                                              token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                              prompt_token=prompt_token.to(self.device),
                                              prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
@@ -384,3 +378,53 @@ class CosyVoice2Model(CosyVoiceModel):
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.current_stream().synchronize()
+
+
+class CosyVoice3Model(CosyVoice2Model):
+
+    def __init__(self,
+                 llm: torch.nn.Module,
+                 flow: torch.nn.Module,
+                 hift: torch.nn.Module,
+                 fp16: bool = False):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.llm = llm
+        self.flow = flow
+        self.hift = hift
+        self.fp16 = fp16
+        # NOTE must match training static_chunk_size
+        self.token_hop_len = 25
+        # rtf and decoding related
+        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
+        self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token_dict = {}
+        self.llm_end_dict = {}
+        self.hift_cache_dict = {}
+
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             streaming=stream,
+                                             finalize=finalize)
+            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
+            # append mel cache
+            if self.hift_cache_dict[uuid] is not None:
+                hift_cache_mel = self.hift_cache_dict[uuid]['mel']
+                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+                self.hift_cache_dict[uuid]['mel'] = tts_mel
+            else:
+                self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
+            if speed != 1.0:
+                assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
+            tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
+            self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
+        return tts_speech
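
A minimal, self-contained sketch (not the repo's code; `fake_vocoder` is a hypothetical stand-in for `self.hift.inference`) of the per-session cache pattern used by `CosyVoice3Model.token2wav` above: the full mel generated so far is kept per uuid, re-vocoded on every chunk, and `speech_offset` trims the audio that was already emitted.

```python
import torch

hift_cache = {}  # uuid -> {'mel': Tensor, 'speech_offset': int}

def fake_vocoder(mel):  # hypothetical stand-in for self.hift.inference
    return mel.repeat_interleave(4, dim=2).sum(dim=1)  # (B, T*4) pseudo-audio

def emit_chunk(uuid, new_mel):
    if uuid in hift_cache:  # append the new mel to the cached history
        mel = torch.cat([hift_cache[uuid]['mel'], new_mel], dim=2)
        hift_cache[uuid]['mel'] = mel
    else:
        hift_cache[uuid] = {'mel': new_mel, 'speech_offset': 0}
        mel = new_mel
    speech = fake_vocoder(mel)
    speech = speech[:, hift_cache[uuid]['speech_offset']:]  # drop audio already sent
    hift_cache[uuid]['speech_offset'] += speech.shape[1]
    return speech

chunks = [emit_chunk('demo', torch.randn(1, 80, 50)) for _ in range(3)]
assert sum(c.shape[1] for c in chunks) == 3 * 50 * 4  # no samples duplicated or lost
```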

+ 9 - 0
cosyvoice/dataset/processor.py

@@ -242,6 +242,10 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
     for sample in data:
         assert 'text' in sample
         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
+        if 'instruct' in sample:
+            sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
+        else:
+            sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
         yield sample
 
 
@@ -390,6 +394,9 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
         text_token = [torch.tensor(sample[i]['text_token']) for i in order]
         text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
         text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
+        instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
+        instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
+        instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
         utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
         spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
         batch = {
@@ -403,6 +410,8 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
             "text": text,
             "text_token": text_token,
             "text_token_len": text_token_len,
+            "instruct_token": instruct_token,
+            "instruct_token_len": instruct_token_len,
             "utt_embedding": utt_embedding,
             "spk_embedding": spk_embedding,
         }
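
A quick, hedged illustration of the new instruct batching (`toy_encode` is a hypothetical stand-in for the real tokenizer): samples without an 'instruct' field fall back to encoding the empty string, which yields a zero-length row that `pad_sequence` fills with padding.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def toy_encode(text):  # hypothetical stand-in for tokenizer.encode
    return [ord(c) % 100 for c in text]

samples = [{'text': 'hi', 'instruct': 'speak softly'}, {'text': 'hello'}]
for s in samples:
    s['instruct_token'] = toy_encode(s.get('instruct', ''))

instruct_token = [torch.tensor(s['instruct_token'], dtype=torch.long) for s in samples]
instruct_token_len = torch.tensor([t.size(0) for t in instruct_token], dtype=torch.int32)
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
print(instruct_token.shape, instruct_token_len)  # torch.Size([2, 12]) tensor([12, 0], dtype=torch.int32)
```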

+ 176 - 0
cosyvoice/flow/DiT/dit.py

@@ -0,0 +1,176 @@
+
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from einops import repeat
+from x_transformers.x_transformers import RotaryEmbedding
+from cosyvoice.utils.mask import add_optional_chunk_mask
+from cosyvoice.flow.DiT.modules import (
+    TimestepEmbedding,
+    ConvNeXtV2Block,
+    CausalConvPositionEmbedding,
+    DiTBlock,
+    AdaLayerNormZero_Final,
+    precompute_freqs_cis,
+    get_pos_embed_indices,
+)
+
+
+# Text embedding
+
+
+class TextEmbedding(nn.Module):
+    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
+        super().__init__()
+        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
+
+        if conv_layers > 0:
+            self.extra_modeling = True
+            self.precompute_max_pos = 4096  # ~44s of 24khz audio
+            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
+            self.text_blocks = nn.Sequential(
+                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
+            )
+        else:
+            self.extra_modeling = False
+
+    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
+        batch, text_len = text.shape[0], text.shape[1]
+        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
+        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        text = F.pad(text, (0, seq_len - text_len), value=0)
+
+        if drop_text:  # cfg for text
+            text = torch.zeros_like(text)
+
+        text = self.text_embed(text)  # b n -> b n d
+
+        # possible extra modeling
+        if self.extra_modeling:
+            # sinus pos emb
+            batch_start = torch.zeros((batch,), dtype=torch.long)
+            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
+            text_pos_embed = self.freqs_cis[pos_idx]
+            text = text + text_pos_embed
+
+            # convnextv2 blocks
+            text = self.text_blocks(text)
+
+        return text
+
+
+# noised input audio and context mixing embedding
+
+
+class InputEmbedding(nn.Module):
+    def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
+        super().__init__()
+        spk_dim = 0 if spk_dim is None else spk_dim
+        self.spk_dim = spk_dim
+        self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
+        self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
+
+    def forward(
+            self,
+            x: float["b n d"],
+            cond: float["b n d"],
+            text_embed: float["b n d"],
+            spks: float["b d"],
+    ):
+        to_cat = [x, cond, text_embed]
+        if self.spk_dim > 0:
+            spks = repeat(spks, "b c -> b t c", t=x.shape[1])
+            to_cat.append(spks)
+
+        x = self.proj(torch.cat(to_cat, dim=-1))
+        x = self.conv_pos_embed(x) + x
+        return x
+
+
+# Transformer backbone using DiT blocks
+
+
+class DiT(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=8,
+        heads=8,
+        dim_head=64,
+        dropout=0.1,
+        ff_mult=4,
+        mel_dim=80,
+        mu_dim=None,
+        long_skip_connection=False,
+        spk_dim=None,
+        out_channels=None,
+        static_chunk_size=50,
+        num_decoding_left_chunks=2
+    ):
+        super().__init__()
+
+        self.time_embed = TimestepEmbedding(dim)
+        if mu_dim is None:
+            mu_dim = mel_dim
+        self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
+
+        self.rotary_embed = RotaryEmbedding(dim_head)
+
+        self.dim = dim
+        self.depth = depth
+
+        self.transformer_blocks = nn.ModuleList(
+            [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
+        )
+        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
+
+        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
+        self.proj_out = nn.Linear(dim, mel_dim)
+        self.out_channels = out_channels
+        self.static_chunk_size = static_chunk_size
+        self.num_decoding_left_chunks = num_decoding_left_chunks
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
+        x = x.transpose(1, 2)
+        mu = mu.transpose(1, 2)
+        cond = cond.transpose(1, 2)
+        spks = spks.unsqueeze(dim=1)
+        batch, seq_len = x.shape[0], x.shape[1]
+        if t.ndim == 0:
+            t = t.repeat(batch)
+
+        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
+        t = self.time_embed(t)
+        x = self.input_embed(x, cond, mu, spks.squeeze(1))
+
+        rope = self.rotary_embed.forward_from_seq_len(seq_len)
+
+        if self.long_skip_connection is not None:
+            residual = x
+
+        if streaming is True:
+            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
+        else:
+            attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
+
+        for block in self.transformer_blocks:
+            x = block(x, t, mask=attn_mask.bool(), rope=rope)
+
+        if self.long_skip_connection is not None:
+            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
+
+        x = self.norm_out(x, t)
+        output = self.proj_out(x).transpose(1, 2)
+        return output
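
A hedged smoke test for the DiT backbone above (assumes the repo, `x_transformers`, and `einops` are importable; the dimensions are toy values, not a shipped Fun-CosyVoice3 config):

```python
import torch
from cosyvoice.flow.DiT.dit import DiT

net = DiT(dim=128, depth=2, heads=4, dim_head=32, mel_dim=80, mu_dim=80, spk_dim=80)
B, T = 2, 100
x = torch.randn(B, 80, T)     # noised mel
mu = torch.randn(B, 80, T)    # token features already upsampled to the mel rate
cond = torch.randn(B, 80, T)  # prompt-mel condition (zeros past the prompt region)
spks = torch.randn(B, 80)     # projected x-vector
mask = torch.ones(B, 1, T)    # (b, 1, n) padding mask, as passed by the CFM decoder
t = torch.rand(B)             # flow-matching timestep
out = net(x, mask, mu, t, spks=spks, cond=cond, streaming=False)
print(out.shape)  # torch.Size([2, 80, 100])
```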

+ 616 - 0
cosyvoice/flow/DiT/modules.py

@@ -0,0 +1,616 @@
+
+"""
+ein notation:
+b - batch
+n - sequence
+nt - text sequence
+nw - raw wave length
+d - dimension
+"""
+
+from __future__ import annotations
+from typing import Optional
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+
+from x_transformers.x_transformers import apply_rotary_pos_emb
+
+
+# raw wav to mel spec
+class MelSpec(nn.Module):
+    def __init__(
+        self,
+        filter_length=1024,
+        hop_length=256,
+        win_length=1024,
+        n_mel_channels=100,
+        target_sample_rate=24_000,
+        normalize=False,
+        power=1,
+        norm=None,
+        center=True,
+    ):
+        super().__init__()
+        self.n_mel_channels = n_mel_channels
+
+        self.mel_stft = torchaudio.transforms.MelSpectrogram(
+            sample_rate=target_sample_rate,
+            n_fft=filter_length,
+            win_length=win_length,
+            hop_length=hop_length,
+            n_mels=n_mel_channels,
+            power=power,
+            center=center,
+            normalized=normalize,
+            norm=norm,
+        )
+
+        self.register_buffer("dummy", torch.tensor(0), persistent=False)
+
+    def forward(self, inp):
+        if len(inp.shape) == 3:
+            inp = inp.squeeze(1)  # 'b 1 nw -> b nw'
+
+        assert len(inp.shape) == 2
+
+        if self.dummy.device != inp.device:
+            self.to(inp.device)
+
+        mel = self.mel_stft(inp)
+        mel = mel.clamp(min=1e-5).log()
+        return mel
+
+
+# sinusoidal position embedding
+
+
+class SinusPositionEmbedding(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x, scale=1000):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+# convolutional position embedding
+
+
+class ConvPositionEmbedding(nn.Module):
+    def __init__(self, dim, kernel_size=31, groups=16):
+        super().__init__()
+        assert kernel_size % 2 != 0
+        self.conv1d = nn.Sequential(
+            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
+            nn.Mish(),
+            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
+            nn.Mish(),
+        )
+
+    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
+        if mask is not None:
+            mask = mask[..., None]
+            x = x.masked_fill(~mask, 0.0)
+
+        x = x.permute(0, 2, 1)
+        x = self.conv1d(x)
+        out = x.permute(0, 2, 1)
+
+        if mask is not None:
+            out = out.masked_fill(~mask, 0.0)
+
+        return out
+
+
+class CausalConvPositionEmbedding(nn.Module):
+    def __init__(self, dim, kernel_size=31, groups=16):
+        super().__init__()
+        assert kernel_size % 2 != 0
+        self.kernel_size = kernel_size
+        self.conv1 = nn.Sequential(
+            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
+            nn.Mish(),
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
+            nn.Mish(),
+        )
+
+    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
+        if mask is not None:
+            mask = mask[..., None]
+            x = x.masked_fill(~mask, 0.0)
+
+        x = x.permute(0, 2, 1)
+        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
+        x = self.conv1(x)
+        x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
+        x = self.conv2(x)
+        out = x.permute(0, 2, 1)
+
+        if mask is not None:
+            out = out.masked_fill(~mask, 0.0)
+
+        return out
+
+
+# rotary positional embedding related
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
+    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+    # has some connection to NTK literature
+    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
+    theta *= theta_rescale_factor ** (dim / (dim - 2))
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cos = torch.cos(freqs)  # real part
+    freqs_sin = torch.sin(freqs)  # imaginary part
+    return torch.cat([freqs_cos, freqs_sin], dim=-1)
+
+
+def get_pos_embed_indices(start, length, max_pos, scale=1.0):
+    # length = length if isinstance(length, int) else length.max()
+    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
+    pos = (
+        start.unsqueeze(1)
+        + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
+    )
+    # avoid extra long error.
+    pos = torch.where(pos < max_pos, pos, max_pos - 1)
+    return pos
+
+
+# Global Response Normalization layer (Instance Normalization ?)
+
+
+class GRN(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
+        self.beta = nn.Parameter(torch.zeros(1, 1, dim))
+
+    def forward(self, x):
+        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
+        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+        return self.gamma * (x * Nx) + self.beta + x
+
+
+# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
+# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
+
+
+class ConvNeXtV2Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        intermediate_dim: int,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        padding = (dilation * (7 - 1)) // 2
+        self.dwconv = nn.Conv1d(
+            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
+        )  # depthwise conv
+        self.norm = nn.LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.grn = GRN(intermediate_dim)
+        self.pwconv2 = nn.Linear(intermediate_dim, dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        residual = x
+        x = x.transpose(1, 2)  # b n d -> b d n
+        x = self.dwconv(x)
+        x = x.transpose(1, 2)  # b d n -> b n d
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.grn(x)
+        x = self.pwconv2(x)
+        return residual + x
+
+
+# AdaLayerNormZero
+# return with modulated x for attn input, and params for later mlp modulation
+
+
+class AdaLayerNormZero(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(dim, dim * 6)
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x, emb=None):
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
+
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+# AdaLayerNormZero for final layer
+# return only with modulated x for attn input, cuz no more mlp modulation
+
+
+class AdaLayerNormZero_Final(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(dim, dim * 2)
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+    def forward(self, x, emb):
+        emb = self.linear(self.silu(emb))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
+# FeedForward
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        activation = nn.GELU(approximate=approximate)
+        project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
+        self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+    def forward(self, x):
+        return self.ff(x)
+
+
+# Attention with possible joint part
+# modified from diffusers/src/diffusers/models/attention_processor.py
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        processor: JointAttnProcessor | AttnProcessor,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        context_dim: Optional[int] = None,  # if not None -> joint attention
+        context_pre_only=None,
+    ):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
+
+        self.processor = processor
+
+        self.dim = dim
+        self.heads = heads
+        self.inner_dim = dim_head * heads
+        self.dropout = dropout
+
+        self.context_dim = context_dim
+        self.context_pre_only = context_pre_only
+
+        self.to_q = nn.Linear(dim, self.inner_dim)
+        self.to_k = nn.Linear(dim, self.inner_dim)
+        self.to_v = nn.Linear(dim, self.inner_dim)
+
+        if self.context_dim is not None:
+            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
+            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
+            if self.context_pre_only is not None:
+                self.to_q_c = nn.Linear(context_dim, self.inner_dim)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, dim))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if self.context_pre_only is not None and not self.context_pre_only:
+            self.to_out_c = nn.Linear(self.inner_dim, dim)
+
+    def forward(
+        self,
+        x: float["b n d"],  # noised input x  # noqa: F722
+        c: float["b n d"] = None,  # context c  # noqa: F722
+        mask: bool["b n"] | None = None,  # noqa: F722
+        rope=None,  # rotary position embedding for x
+        c_rope=None,  # rotary position embedding for c
+    ) -> torch.Tensor:
+        if c is not None:
+            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
+        else:
+            return self.processor(self, x, mask=mask, rope=rope)
+
+
+# Attention processor
+
+
+class AttnProcessor:
+    def __init__(self):
+        pass
+
+    def __call__(
+        self,
+        attn: Attention,
+        x: float["b n d"],  # noised input x  # noqa: F722
+        mask: bool["b n"] | None = None,  # noqa: F722
+        rope=None,  # rotary position embedding
+    ) -> torch.FloatTensor:
+        batch_size = x.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(x)
+        key = attn.to_k(x)
+        value = attn.to_v(x)
+
+        # apply rotary position embedding
+        if rope is not None:
+            freqs, xpos_scale = rope
+            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
+
+            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
+            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
+
+        # attention
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # mask. e.g. inference got a batch with different target durations, mask out the padding
+        if mask is not None:
+            attn_mask = mask
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
+                attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+        else:
+            attn_mask = None
+
+        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        x = x.to(query.dtype)
+
+        # linear proj
+        x = attn.to_out[0](x)
+        # dropout
+        x = attn.to_out[1](x)
+
+        if mask is not None:
+            if mask.dim() == 2:
+                mask = mask.unsqueeze(-1)
+            else:
+                mask = mask[:, 0, -1].unsqueeze(-1)
+            x = x.masked_fill(~mask, 0.0)
+
+        return x
+
+
+# Joint Attention processor for MM-DiT
+# modified from diffusers/src/diffusers/models/attention_processor.py
+
+
+class JointAttnProcessor:
+    def __init__(self):
+        pass
+
+    def __call__(
+        self,
+        attn: Attention,
+        x: float["b n d"],  # noised input x  # noqa: F722
+        c: float["b nt d"] = None,  # context c, here text # noqa: F722
+        mask: bool["b n"] | None = None,  # noqa: F722
+        rope=None,  # rotary position embedding for x
+        c_rope=None,  # rotary position embedding for c
+    ) -> torch.FloatTensor:
+        residual = x
+
+        batch_size = c.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(x)
+        key = attn.to_k(x)
+        value = attn.to_v(x)
+
+        # `context` projections.
+        c_query = attn.to_q_c(c)
+        c_key = attn.to_k_c(c)
+        c_value = attn.to_v_c(c)
+
+        # apply rope for context and noised input independently
+        if rope is not None:
+            freqs, xpos_scale = rope
+            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
+            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
+            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
+        if c_rope is not None:
+            freqs, xpos_scale = c_rope
+            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
+            c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
+            c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
+
+        # attention
+        query = torch.cat([query, c_query], dim=1)
+        key = torch.cat([key, c_key], dim=1)
+        value = torch.cat([value, c_value], dim=1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # mask. e.g. inference got a batch with different target durations, mask out the padding
+        if mask is not None:
+            attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
+            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
+            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+        else:
+            attn_mask = None
+
+        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        x = x.to(query.dtype)
+
+        # Split the attention outputs.
+        x, c = (
+            x[:, : residual.shape[1]],
+            x[:, residual.shape[1]:],
+        )
+
+        # linear proj
+        x = attn.to_out[0](x)
+        # dropout
+        x = attn.to_out[1](x)
+        if not attn.context_pre_only:
+            c = attn.to_out_c(c)
+
+        if mask is not None:
+            mask = mask.unsqueeze(-1)
+            x = x.masked_fill(~mask, 0.0)
+            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)
+
+        return x, c
+
+
+# DiT Block
+
+
+class DiTBlock(nn.Module):
+    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
+        super().__init__()
+
+        self.attn_norm = AdaLayerNormZero(dim)
+        self.attn = Attention(
+            processor=AttnProcessor(),
+            dim=dim,
+            heads=heads,
+            dim_head=dim_head,
+            dropout=dropout,
+        )
+
+        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
+
+    def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
+        # pre-norm & modulation for attention input
+        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
+
+        # attention
+        attn_output = self.attn(x=norm, mask=mask, rope=rope)
+
+        # process attention output for input x
+        x = x + gate_msa.unsqueeze(1) * attn_output
+
+        ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        ff_output = self.ff(ff_norm)
+        x = x + gate_mlp.unsqueeze(1) * ff_output
+
+        return x
+
+
+# MMDiT Block https://arxiv.org/abs/2403.03206
+
+
+class MMDiTBlock(nn.Module):
+    r"""
+    modified from diffusers/src/diffusers/models/attention.py
+
+    notes.
+    _c: context related. text, cond, etc. (left part in sd3 fig2.b)
+    _x: noised input related. (right part)
+    context_pre_only: last layer only do prenorm + modulation cuz no more ffn
+    """
+
+    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
+        super().__init__()
+
+        self.context_pre_only = context_pre_only
+
+        self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
+        self.attn_norm_x = AdaLayerNormZero(dim)
+        self.attn = Attention(
+            processor=JointAttnProcessor(),
+            dim=dim,
+            heads=heads,
+            dim_head=dim_head,
+            dropout=dropout,
+            context_dim=dim,
+            context_pre_only=context_pre_only,
+        )
+
+        if not context_pre_only:
+            self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+            self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
+        else:
+            self.ff_norm_c = None
+            self.ff_c = None
+        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
+
+    def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: noised input, c: context, t: time embedding
+        # pre-norm & modulation for attention input
+        if self.context_pre_only:
+            norm_c = self.attn_norm_c(c, t)
+        else:
+            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
+        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
+
+        # attention
+        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
+
+        # process attention output for context c
+        if self.context_pre_only:
+            c = None
+        else:  # if not last layer
+            c = c + c_gate_msa.unsqueeze(1) * c_attn_output
+
+            norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+            c_ff_output = self.ff_c(norm_c)
+            c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
+
+        # process attention output for input x
+        x = x + x_gate_msa.unsqueeze(1) * x_attn_output
+
+        norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
+        x_ff_output = self.ff_x(norm_x)
+        x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
+
+        return c, x
+
+
+# time step conditioning embedding
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(self, dim, freq_embed_dim=256):
+        super().__init__()
+        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
+        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+
+    def forward(self, timestep: float["b"]):  # noqa: F821
+        time_hidden = self.time_embed(timestep)
+        time_hidden = time_hidden.to(timestep.dtype)
+        time = self.time_mlp(time_hidden)  # b d
+        return time
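
A toy check of the AdaLN-Zero pattern implemented above: one conditioning vector yields shift/scale/gate terms for the block. The gates are zero-initialized here only to demonstrate the identity-at-init property the DiT paper's AdaLN-Zero aims for; the `AdaLayerNormZero` module above uses a normally initialized Linear.

```python
import torch
from torch import nn

dim, B, T = 16, 2, 5
silu, lin = nn.SiLU(), nn.Linear(dim, dim * 6)
nn.init.zeros_(lin.weight); nn.init.zeros_(lin.bias)  # illustrative zero-init
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

x, emb = torch.randn(B, T, dim), torch.randn(B, dim)
shift, scale, gate, *_ = torch.chunk(lin(silu(emb)), 6, dim=1)
h = norm(x) * (1 + scale[:, None]) + shift[:, None]  # modulated attention input
out = x + gate.unsqueeze(1) * h                      # gate == 0 -> identity at init
assert torch.equal(out, x)
```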

+ 159 - 8
cosyvoice/flow/flow.py

@@ -37,14 +37,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
                                        'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                  'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                        'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
-                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
-                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
-                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
         self.decoder_conf = decoder_conf
-        self.mel_feat_conf = mel_feat_conf
         self.vocab_size = vocab_size
         self.output_type = output_type
         self.input_frame_rate = input_frame_rate
@@ -165,14 +162,11 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
                                        'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                  'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                        'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
-                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
-                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
-                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
         self.decoder_conf = decoder_conf
-        self.mel_feat_conf = mel_feat_conf
         self.vocab_size = vocab_size
         self.output_type = output_type
         self.input_frame_rate = input_frame_rate
@@ -279,3 +273,160 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
         return feat.float(), None
+
+
+class CausalMaskedDiffWithDiT(torch.nn.Module):
+    def __init__(self,
+                 input_size: int = 512,
+                 output_size: int = 80,
+                 spk_embed_dim: int = 192,
+                 output_type: str = "mel",
+                 vocab_size: int = 4096,
+                 input_frame_rate: int = 50,
+                 only_mask_loss: bool = True,
+                 token_mel_ratio: int = 2,
+                 pre_lookahead_len: int = 3,
+                 pre_lookahead_layer: torch.nn.Module = None,
+                 decoder: torch.nn.Module = None,
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.decoder_conf = decoder_conf
+        self.vocab_size = vocab_size
+        self.output_type = output_type
+        self.input_frame_rate = input_frame_rate
+        logging.info(f"input frame rate={self.input_frame_rate}")
+        self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+        self.pre_lookahead_len = pre_lookahead_len
+        self.pre_lookahead_layer = pre_lookahead_layer
+        self.decoder = decoder
+        self.only_mask_loss = only_mask_loss
+        self.token_mel_ratio = token_mel_ratio
+
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        token = batch['speech_token'].to(device)
+        token_len = batch['speech_token_len'].to(device)
+        feat = batch['speech_feat'].to(device)
+        feat_len = batch['speech_feat_len'].to(device)
+        embedding = batch['embedding'].to(device)
+
+        # NOTE unified training, static_chunk_size > 0 or = 0
+        streaming = True if random.random() < 0.5 else False
+
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # token encode (pre-lookahead, then upsample to the mel frame rate, as in inference below)
+        h = self.pre_lookahead_layer(token)
+        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
+
+        # get conditions
+        conds = torch.zeros(feat.shape, device=token.device)
+        for i, j in enumerate(feat_len):
+            if random.random() < 0.5:
+                continue
+            index = random.randint(0, int(0.3 * j))
+            conds[i, :index] = feat[i, :index]
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(feat_len)).to(h)
+        loss, _ = self.decoder.compute_loss(
+            feat.transpose(1, 2).contiguous(),
+            mask.unsqueeze(1),
+            h.transpose(1, 2).contiguous(),
+            embedding,
+            cond=conds,
+            streaming=streaming,
+        )
+        return {'loss': loss}
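
The condition-building loop above can be read in isolation; a toy version (random tensors, same logic): with probability 0.5 each sample keeps up to 30% of its own mel frames as a pseudo-prompt, and is otherwise conditioned on zeros.

```python
import random
import torch

feat = torch.randn(4, 200, 80)          # (batch, mel_frames, mel_bins)
feat_len = [200, 180, 160, 140]
conds = torch.zeros_like(feat)
for i, j in enumerate(feat_len):
    if random.random() < 0.5:
        continue                        # no prompt conditioning for this sample
    index = random.randint(0, int(0.3 * j))
    conds[i, :index] = feat[i, :index]  # leak a short mel prefix as the prompt
```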
+
+    @torch.inference_mode()
+    def inference(self,
+                  token,
+                  token_len,
+                  prompt_token,
+                  prompt_token_len,
+                  prompt_feat,
+                  prompt_feat_len,
+                  embedding,
+                  streaming,
+                  finalize):
+        assert token.shape[0] == 1
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        if finalize is True:
+            h = self.pre_lookahead_layer(token)
+        else:
+            h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
+        h = h.repeat_interleave(self.token_mel_ratio, dim=1)
+        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
+
+        # get conditions
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
+        conds[:, :mel_len1] = prompt_feat
+        conds = conds.transpose(1, 2)
+
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+        feat, _ = self.decoder(
+            mu=h.transpose(1, 2).contiguous(),
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=10,
+            streaming=streaming
+        )
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
+        return feat.float(), None
+
+
+if __name__ == '__main__':
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    from hyperpyyaml import load_hyperpyyaml
+    with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
+        configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
+    model = configs['flow']
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model.to(device)
+    model.eval()
+    max_len = 10 * model.decoder.estimator.static_chunk_size
+    chunk_size = model.decoder.estimator.static_chunk_size
+    context_size = model.pre_lookahead_layer.pre_lookahead_len
+    token = torch.randint(0, 6561, size=(1, max_len)).to(device)
+    token_len = torch.tensor([max_len]).to(device)
+    prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
+    prompt_token_len = torch.tensor([chunk_size]).to(device)
+    prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
+    prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
+    prompt_embedding = torch.rand(1, 192).to(device)
+    pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
+    for i in range(0, max_len, chunk_size):
+        finalize = True if i + chunk_size + context_size >= max_len else False
+        pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
+                                        prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
+        pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
+        print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())
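
The length bookkeeping in `inference` above reduces to simple arithmetic; a sketch with the class default `token_mel_ratio=2` and toy lengths:

```python
token_mel_ratio = 2                                        # class default above
prompt_token_len, token_len = 25, 100
prompt_feat_len = prompt_token_len * token_mel_ratio       # 50 prompt mel frames
total_mel = (prompt_token_len + token_len) * token_mel_ratio
generated = total_mel - prompt_feat_len                    # frames actually returned
assert generated == token_len * token_mel_ratio == 200
```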

+ 7 - 6
cosyvoice/flow/flow_matching.py

@@ -91,12 +91,13 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         # Do not use concat; it may change the memory format and make trt infer wrong results!
-        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
-        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
-        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
-        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        # NOTE when flow runs in amp mode, x.dtype is float32, which causes nan in trt fp16 inference, so set dtype=spks.dtype
+        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
+        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
+        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
+        t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
+        spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
+        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
             x_in[:] = x
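
A minimal sketch (toy shapes, CPU-friendly) of why the buffers above follow `spks.dtype`: under AMP the solver state `x` can remain float32, and filling a pre-allocated buffer casts it to the engine's dtype instead of feeding fp32 into an fp16 TRT plan.

```python
import torch

dtype = torch.float16 if torch.cuda.is_available() else torch.float32
x = torch.randn(1, 80, 120, dtype=torch.float32)        # AMP may leave x in fp32
spks = torch.randn(1, 80, dtype=dtype)

x_in = torch.zeros(2, 80, x.size(2), dtype=spks.dtype)  # row 0: cond, row 1: uncond
x_in[:] = x                                             # implicit cast to spks.dtype
assert x_in.dtype == spks.dtype
```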

+ 45 - 0
cosyvoice/hifigan/f0_predictor.py

@@ -17,6 +17,7 @@ try:
     from torch.nn.utils.parametrizations import weight_norm
 except ImportError:
     from torch.nn.utils import weight_norm
+from cosyvoice.transformer.convolution import CausalConv1d
 
 
 class ConvRNNF0Predictor(nn.Module):
@@ -56,3 +57,47 @@ class ConvRNNF0Predictor(nn.Module):
         x = self.condnet(x)
         x = x.transpose(1, 2)
         return torch.abs(self.classifier(x).squeeze(-1))
+
+
+class CausalConvRNNF0Predictor(nn.Module):
+    def __init__(self,
+                 num_class: int = 1,
+                 in_channels: int = 80,
+                 cond_channels: int = 512
+                 ):
+        super().__init__()
+
+        self.num_class = num_class
+        self.condnet = nn.Sequential(
+            weight_norm(
+                CausalConv1d(in_channels, cond_channels, kernel_size=4, causal_type='right')
+            ),
+            nn.ELU(),
+            weight_norm(
+                CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
+            ),
+            nn.ELU(),
+            weight_norm(
+                CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
+            ),
+            nn.ELU(),
+            weight_norm(
+                CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
+            ),
+            nn.ELU(),
+            weight_norm(
+                CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
+            ),
+            nn.ELU(),
+        )
+        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
+
+    def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Tensor:
+        if finalize is True:
+            x = self.condnet[0](x)
+        else:
+            x = self.condnet[0](x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:])
+        for i in range(1, len(self.condnet)):
+            x = self.condnet[i](x)
+        x = x.transpose(1, 2)
+        return torch.abs(self.classifier(x).squeeze(-1))
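
A concept sketch of the left-causal padding that `CausalConvRNNF0Predictor` relies on, written with plain `F.pad` + `Conv1d` (the repo's `CausalConv1d` wraps this and, per the first layer above, also has a `causal_type='right'` lookahead mode):

```python
import torch
import torch.nn.functional as F
from torch import nn

conv = nn.Conv1d(80, 80, kernel_size=3, padding=0)
x = torch.randn(1, 80, 50)
y = conv(F.pad(x, (2, 0)))         # pad kernel_size - 1 frames on the left only
assert y.shape[-1] == x.shape[-1]  # same length, no dependence on future frames
```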

+ 240 - 76
cosyvoice/hifigan/generator.py

@@ -28,7 +28,7 @@ try:
 except ImportError:
     from torch.nn.utils import weight_norm
 from torch.distributions.uniform import Uniform
-
+from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
 from cosyvoice.transformer.activation import Snake
 from cosyvoice.utils.common import get_padding
 from cosyvoice.utils.common import init_weights
@@ -50,8 +50,10 @@
         channels: int = 512,
         kernel_size: int = 3,
         dilations: List[int] = [1, 3, 5],
+        causal: bool = False,
     ):
         super(ResBlock, self).__init__()
+        self.causal = causal
         self.convs1 = nn.ModuleList()
         self.convs2 = nn.ModuleList()
 
@@ -64,7 +66,14 @@
                         kernel_size,
                         1,
                         dilation=dilation,
-                        padding=get_padding(kernel_size, dilation)
+                        padding=get_padding(kernel_size, dilation)) if causal is False else
+                    CausalConv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation,
+                        causal_type='left'
                     )
                 )
             )
@@ -76,7 +85,14 @@
                         kernel_size,
                         1,
                         dilation=1,
-                        padding=get_padding(kernel_size, 1)
+                        padding=get_padding(kernel_size, 1)) if causal is False else
+                    CausalConv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        causal_type='left'
                     )
                 )
             )
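
A hedged sketch of what the `causal` switch above changes (hypothetical helper, not the repo's `CausalConv1d`): symmetric padding keeps the length but looks both ways, while left-only padding keeps the length without peeking at future frames.

```python
import torch
import torch.nn.functional as F
from torch import nn

def make_conv(channels, k, dilation, causal):
    pad = 0 if causal else (k - 1) * dilation // 2   # get_padding equivalent
    conv = nn.Conv1d(channels, channels, k, dilation=dilation, padding=pad)
    if not causal:
        return conv
    return lambda x: conv(F.pad(x, ((k - 1) * dilation, 0)))  # left-pad only

x = torch.randn(1, 8, 40)
print(make_conv(8, 3, 5, causal=False)(x).shape)  # torch.Size([1, 8, 40])
print(make_conv(8, 3, 5, causal=True)(x).shape)   # torch.Size([1, 8, 40])
```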
@@ -139,11 +155,13 @@
 
     @torch.no_grad()
     def forward(self, f0):
+        """ sine_tensor, uv = forward(f0)
+        input F0: tensor(batchsize=1, dim=1, length)
+                  f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
         """
-        :param f0: [B, 1, sample_len], Hz
-        :return: [B, 1, sample_len]
-        """
-
+        f0 = f0.transpose(1, 2)
         F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
         for i in range(self.harmonic_num + 1):
             F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
@@ -168,59 +186,7 @@
         # first: set the unvoiced part to 0 by uv
         # then: additive noise
         sine_waves = sine_waves * uv + noise
-        return sine_waves, uv, noise
-
-
-class SourceModuleHnNSF(torch.nn.Module):
-    """ SourceModule for hn-nsf
-    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
-                 add_noise_std=0.003, voiced_threshod=0)
-    sampling_rate: sampling_rate in Hz
-    harmonic_num: number of harmonic above F0 (default: 0)
-    sine_amp: amplitude of sine source signal (default: 0.1)
-    add_noise_std: std of additive Gaussian noise (default: 0.003)
-        note that amplitude of noise in unvoiced is decided
-        by sine_amp
-    voiced_threshold: threhold to set U/V given F0 (default: 0)
-    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
-    F0_sampled (batchsize, length, 1)
-    Sine_source (batchsize, length, 1)
-    noise_source (batchsize, length 1)
-    uv (batchsize, length, 1)
-    """
-
-    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
-                 add_noise_std=0.003, voiced_threshod=0):
-        super(SourceModuleHnNSF, self).__init__()
-
-        self.sine_amp = sine_amp
-        self.noise_std = add_noise_std
-
-        # to produce sine waveforms
-        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
-                                 sine_amp, add_noise_std, voiced_threshod)
-
-        # to merge source harmonics into a single excitation
-        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
-        self.l_tanh = torch.nn.Tanh()
-
-    def forward(self, x):
-        """
-        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
-        F0_sampled (batchsize, length, 1)
-        Sine_source (batchsize, length, 1)
-        noise_source (batchsize, length 1)
-        """
-        # source for harmonic branch
-        with torch.no_grad():
-            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
-            sine_wavs = sine_wavs.transpose(1, 2)
-            uv = uv.transpose(1, 2)
-        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
-
-        # source for noise branch, in the same shape as uv
-        noise = torch.randn_like(uv) * self.sine_amp / 3
-        return sine_merge, noise, uv
+        return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise
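
A toy rendering of the voiced/unvoiced mixing that the refactored `SineGen.forward` keeps (`sine_waves * uv + noise`), with amplitudes borrowed from SineGen2's defaults (sine_amp=0.1, noise_std=0.003) and toy shapes:

```python
import torch

uv = (torch.rand(1, 100, 1) > 0.3).float()                  # 1 = voiced frame
sine = torch.sin(torch.linspace(0, 60, 100)).view(1, 100, 1)
noise_amp = uv * 0.003 + (1 - uv) * 0.1 / 3                 # tiny noise when voiced
noise = noise_amp * torch.randn(1, 100, 1)
out = sine * uv + noise                                     # unvoiced -> pure noise
```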
 
 
 
 
 class SineGen2(torch.nn.Module):
@@ -242,7 +208,8 @@
     def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                  sine_amp=0.1, noise_std=0.003,
                  voiced_threshold=0,
-                 flag_for_pulse=False):
+                 flag_for_pulse=False,
+                 causal=False):
         super(SineGen2, self).__init__()
         self.sine_amp = sine_amp
         self.noise_std = noise_std
@@ -252,6 +219,11 @@
         self.voiced_threshold = voiced_threshold
         self.flag_for_pulse = flag_for_pulse
         self.upsample_scale = upsample_scale
+        self.causal = causal
+        if causal is True:
+            # NOTE pre-sampled initial phases and noise keep causal (streaming) inference deterministic across chunks
+            self.rand_ini = torch.rand(1, 9)
+            self.rand_ini[:, 0] = 0
+            self.sine_waves = torch.rand(1, 300 * 24000, 9)
 
 
     def _f02uv(self, f0):
         # generate uv signal
         # generate uv signal
@@ -267,9 +239,12 @@ class SineGen2(torch.nn.Module):
         rad_values = (f0_values / self.sampling_rate) % 1
         rad_values = (f0_values / self.sampling_rate) % 1
 
 
         # initial phase noise (no noise for fundamental component)
         # initial phase noise (no noise for fundamental component)
-        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
-        rand_ini[:, 0] = 0
-        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+        if self.training is False and self.causal is True:
+            rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
+        else:
+            rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
+            rand_ini[:, 0] = 0
+            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
 
 
         # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
         # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
         if not self.flag_for_pulse:
         if not self.flag_for_pulse:
@@ -279,7 +254,7 @@ class SineGen2(torch.nn.Module):
 
 
             phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
             phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
             phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
             phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
-                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+                                                    scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
             sines = torch.sin(phase)
             sines = torch.sin(phase)
         else:
         else:
             # If necessary, make sure that the first time step of every
             # If necessary, make sure that the first time step of every
@@ -331,7 +306,10 @@ class SineGen2(torch.nn.Module):
         #        std = self.sine_amp/3 -> max value ~ self.sine_amp
         #        std = self.sine_amp/3 -> max value ~ self.sine_amp
         #        for voiced regions is self.noise_std
         #        for voiced regions is self.noise_std
         noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
         noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
-        noise = noise_amp * torch.randn_like(sine_waves)
+        if self.training is False and self.causal is True:
+            noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)  # reuse the fixed noise buffer so chunked inference is reproducible
+        else:
+            noise = noise_amp * torch.randn_like(sine_waves)
 
 
         # first: set the unvoiced part to 0 by uv
         # first: set the unvoiced part to 0 by uv
         # then: additive noise
         # then: additive noise
@@ -339,7 +317,7 @@ class SineGen2(torch.nn.Module):
         return sine_waves, uv, noise
         return sine_waves, uv, noise
 
 
 
 
-class SourceModuleHnNSF2(torch.nn.Module):
+class SourceModuleHnNSF(torch.nn.Module):
     """ SourceModule for hn-nsf
     """ SourceModule for hn-nsf
     SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
     SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                  add_noise_std=0.003, voiced_threshod=0)
                  add_noise_std=0.003, voiced_threshod=0)
@@ -358,19 +336,24 @@ class SourceModuleHnNSF2(torch.nn.Module):
     """
     """
 
 
     def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
     def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
-                 add_noise_std=0.003, voiced_threshod=0):
-        super(SourceModuleHnNSF2, self).__init__()
+                 add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
+        super(SourceModuleHnNSF, self).__init__()
 
 
         self.sine_amp = sine_amp
         self.sine_amp = sine_amp
         self.noise_std = add_noise_std
         self.noise_std = add_noise_std
 
 
         # to produce sine waveforms
         # to produce sine waveforms
-        self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
-                                  sine_amp, add_noise_std, voiced_threshod)
+        if sinegen_type == '1':
+            self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
+        else:
+            self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
 
 
         # to merge source harmonics into a single excitation
         # to merge source harmonics into a single excitation
         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
         self.l_tanh = torch.nn.Tanh()
         self.l_tanh = torch.nn.Tanh()
+        self.causal = causal
+        if causal is True:
+            self.uv = torch.rand(1, 300 * 24000, 1)  # fixed noise buffer for the noise branch (300s at 24kHz)
 
 
     def forward(self, x):
     def forward(self, x):
         """
         """
@@ -385,7 +368,10 @@ class SourceModuleHnNSF2(torch.nn.Module):
         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
 
 
         # source for noise branch, in the same shape as uv
         # source for noise branch, in the same shape as uv
-        noise = torch.randn_like(uv) * self.sine_amp / 3
+        if self.training is False and self.causal is True:
+            noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3  # slice the fixed buffer instead of drawing fresh noise
+        else:
+            noise = torch.randn_like(uv) * self.sine_amp / 3
         return sine_merge, noise, uv
         return sine_merge, noise, uv
 
 
 
 
@@ -425,15 +411,16 @@ class HiFTGenerator(nn.Module):
 
 
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
         self.num_upsamples = len(upsample_rates)
-        # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
-        this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
-        self.m_source = this_SourceModuleHnNSF(
+        # NOTE in CosyVoice2, we use the original SineGen implementation
+        self.m_source = SourceModuleHnNSF(
             sampling_rate=sampling_rate,
             sampling_rate=sampling_rate,
             upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
             upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
             harmonic_num=nb_harmonics,
             harmonic_num=nb_harmonics,
             sine_amp=nsf_alpha,
             sine_amp=nsf_alpha,
             add_noise_std=nsf_sigma,
             add_noise_std=nsf_sigma,
-            voiced_threshod=nsf_voiced_threshold)
+            voiced_threshod=nsf_voiced_threshold,
+            sinegen_type='1' if self.sampling_rate == 22050 else '2',
+            causal=False)
         self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
         self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
 
 
         self.conv_pre = weight_norm(
         self.conv_pre = weight_norm(
@@ -580,3 +567,180 @@ class HiFTGenerator(nn.Module):
             s[:, :, :cache_source.shape[2]] = cache_source
             s[:, :, :cache_source.shape[2]] = cache_source
         generated_speech = self.decode(x=speech_feat, s=s)
         generated_speech = self.decode(x=speech_feat, s=s)
         return generated_speech, s
         return generated_speech, s
+
+
+class CausalHiFTGenerator(HiFTGenerator):
+    """
+    Causal HiFTNet Generator for streaming inference: Neural Source Filter + ISTFTNet
+    https://arxiv.org/abs/2309.09493
+    """
+    def __init__(
+            self,
+            in_channels: int = 80,
+            base_channels: int = 512,
+            nb_harmonics: int = 8,
+            sampling_rate: int = 22050,
+            nsf_alpha: float = 0.1,
+            nsf_sigma: float = 0.003,
+            nsf_voiced_threshold: float = 10,
+            upsample_rates: List[int] = [8, 8],
+            upsample_kernel_sizes: List[int] = [16, 16],
+            istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
+            resblock_kernel_sizes: List[int] = [3, 7, 11],
+            resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            source_resblock_kernel_sizes: List[int] = [7, 11],
+            source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
+            lrelu_slope: float = 0.1,
+            audio_limit: float = 0.99,
+            conv_pre_look_right: int = 4,
+            f0_predictor: torch.nn.Module = None,
+    ):
+        torch.nn.Module.__init__(self)
+
+        self.out_channels = 1
+        self.nb_harmonics = nb_harmonics
+        self.sampling_rate = sampling_rate
+        self.istft_params = istft_params
+        self.lrelu_slope = lrelu_slope
+        self.audio_limit = audio_limit
+
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.m_source = SourceModuleHnNSF(
+            sampling_rate=sampling_rate,
+            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
+            harmonic_num=nb_harmonics,
+            sine_amp=nsf_alpha,
+            add_noise_std=nsf_sigma,
+            voiced_threshod=nsf_voiced_threshold,
+            sinegen_type='1' if self.sampling_rate == 22050 else '2',
+            causal=True)
+        self.upsample_rates = upsample_rates
+        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
+
+        self.conv_pre = weight_norm(
+            CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
+        )
+
+        # Up
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    CausalConv1dUpsample(
+                        base_channels // (2**i),
+                        base_channels // (2**(i + 1)),
+                        k,
+                        u,
+                    )
+                )
+            )
+
+        # Down
+        self.source_downs = nn.ModuleList()
+        self.source_resblocks = nn.ModuleList()
+        downsample_rates = [1] + upsample_rates[::-1][:-1]
+        downsample_cum_rates = np.cumprod(downsample_rates)
+        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
+            if u == 1:
+                self.source_downs.append(
+                    CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
+                )
+            else:
+                self.source_downs.append(
+                    CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
+                )
+
+            self.source_resblocks.append(
+                ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = base_channels // (2**(i + 1))
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(ResBlock(ch, k, d, causal=True))
+
+        self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+        self.reflection_pad = nn.ReflectionPad1d((1, 0))
+        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
+        self.conv_pre_look_right = conv_pre_look_right
+        self.f0_predictor = f0_predictor
+
+    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
+        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
+        if finalize is True:
+            x = self.conv_pre(x)
+        else:
+            x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
+            s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
+            s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
+        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, self.lrelu_slope)
+            x = self.ups[i](x)
+
+            if i == self.num_upsamples - 1:
+                x = self.reflection_pad(x)
+
+            # fusion
+            si = self.source_downs[i](s_stft)
+            si = self.source_resblocks[i](si)
+            x = x + si
+
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
+        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundant
+
+        x = self._istft(magnitude, phase)
+        if finalize is False:
+            x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
+        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
+        return x
+
+    @torch.inference_mode()
+    def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
+        # mel->f0 NOTE f0_predictor numerical precision is crucial for causal inference, so it is run on cpu here
+        self.f0_predictor.to('cpu')
+        f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
+        # f0->source
+        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
+        s, _, _ = self.m_source(s)
+        s = s.transpose(1, 2)
+        if finalize is True:
+            generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
+        else:
+            generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
+        return generated_speech, s
+
+
+if __name__ == '__main__':
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    from hyperpyyaml import load_hyperpyyaml
+    with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
+        configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
+    model = configs['hift']
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model.to(device)
+    model.eval()
+    max_len, chunk_size, context_size = 300, 30, 8
+    mel = torch.rand(1, 80, max_len).to(device)
+    pred_gt, _ = model.inference(mel)
+    for i in range(0, max_len, chunk_size):
+        finalize = True if i + chunk_size + context_size >= max_len else False
+        pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
+        pred_chunk = pred_chunk[:, i * 480:]  # 480 = np.prod(upsample_rates) * istft hop_len output samples per mel frame
+        print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())
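The `__main__` check above compares full-utterance decoding against chunked decoding, and the fixed `rand_ini`, `sine_waves`, and `uv` buffers introduced in this file are what make that comparison meaningful: every chunk must see exactly the same excitation noise as the full pass. A minimal sketch of the property, with illustrative names that are not from the repo:

```python
import torch

torch.manual_seed(0)
# pre-generated once at construction time, like SineGen2.sine_waves / SourceModuleHnNSF.uv
noise_buffer = torch.rand(1, 300 * 24000, 1)

def noise_fixed(num_samples):
    # slicing a fixed buffer: the first t samples are identical on every call
    return noise_buffer[:, :num_samples]

def noise_fresh(num_samples):
    # fresh sampling: earlier samples differ from call to call
    return torch.rand(1, num_samples, 1)

a, b = noise_fixed(100), noise_fixed(200)
print(torch.equal(a, b[:, :100]))  # True: chunked and full inference share one noise stream
a, b = noise_fresh(100), noise_fresh(200)
print(torch.equal(a, b[:, :100]))  # False: chunk boundaries would get different excitation
```

The `mode='nearest'` switch in `SineGen2._f02sine` presumably serves the same goal: linear interpolation of the cumulative phase mixes in the next frame (lookahead), while nearest-neighbour upsampling keeps the excitation strictly causal.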

+ 172 - 44
cosyvoice/llm/llm.py

@@ -17,6 +17,7 @@ import random
 import time
 import time
 import threading
 import threading
 from typing import Dict, Optional, Callable, List, Generator
 from typing import Dict, Optional, Callable, List, Generator
+import numpy as np
 import torch
 import torch
 from torch import nn
 from torch import nn
 import torch.nn.functional as F
 import torch.nn.functional as F
@@ -56,8 +57,9 @@ class TransformerLM(torch.nn.Module):
         )
         )
 
 
         # 2. build speech token language model related modules
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
         self.task_id = 1
+        self.eos_token = self.speech_token_size
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
         self.llm = llm
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
         self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
@@ -85,10 +87,10 @@ class TransformerLM(torch.nn.Module):
         encoder_out = self.text_encoder_affine_layer(encoder_out)
         encoder_out = self.text_encoder_affine_layer(encoder_out)
         return encoder_out, encoder_out_lens
         return encoder_out, encoder_out_lens
 
 
-    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
+    def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+        lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                     for i in range(len(text_token))]
                     for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
@@ -126,15 +128,15 @@ class TransformerLM(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)
         embedding = self.spk_embed_affine_layer(embedding)
         embedding = embedding.unsqueeze(1)
         embedding = embedding.unsqueeze(1)
 
 
-        # 3. eos and task_id
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        # 3. sos and task_id
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
 
 
         # 4. encode speech_token
         # 4. encode speech_token
         speech_token = self.speech_embedding(speech_token)
         speech_token = self.speech_embedding(speech_token)
 
 
         # 5. unpad and pad
         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
                                                          task_id_emb, speech_token, speech_token_len)
                                                          task_id_emb, speech_token, speech_token_len)
 
 
         # 6. run lm forward
         # 6. run lm forward
@@ -154,7 +156,7 @@ class TransformerLM(torch.nn.Module):
         num_trials, max_trials = 0, 100
         num_trials, max_trials = 0, 100
         while True:
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (self.speech_token_size not in top_ids):
+            if (not ignore_eos) or (top_ids < self.speech_token_size):
                 break
                 break
             num_trials += 1
             num_trials += 1
             if num_trials > max_trials:
             if num_trials > max_trials:
@@ -193,13 +195,13 @@ class TransformerLM(torch.nn.Module):
             embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
             embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
 
 
         # 3. concat llm_input
         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
 
         # 4. cal min/max_length
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -215,11 +217,8 @@ class TransformerLM(torch.nn.Module):
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
                                                                                                  device=lm_input.device)).to(torch.bool))
                                                                                                  device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            # force continue decode first token
-            if i == 0:
-                logp[:, self.speech_token_size] = -float('inf')
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
+            if top_ids == self.eos_token:
                 break
                 break
             # in stream mode, yield token one by one
             # in stream mode, yield token one by one
             yield top_ids
             yield top_ids
@@ -276,9 +275,10 @@ class Qwen2LM(TransformerLM):
         self.llm_output_size = llm_output_size
         self.llm_output_size = llm_output_size
         self.speech_token_size = speech_token_size
         self.speech_token_size = speech_token_size
         # 2. build speech token language model related modules
         # 2. build speech token language model related modules
-        self.sos_eos = 0
+        self.sos = 0
         self.task_id = 1
         self.task_id = 1
-        self.fill_token = 2
+        self.eos_token = speech_token_size
+        self.fill_token = speech_token_size + 2
 
 
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
         self.llm = llm
         self.llm = llm
@@ -301,7 +301,7 @@ class Qwen2LM(TransformerLM):
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
         self.vllm_output_queue = {}
 
 
-    def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
+    def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
         lm_target, lm_input = [], []
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
@@ -312,7 +312,7 @@ class Qwen2LM(TransformerLM):
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
             if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
                 this_lm_target, this_lm_input = [], []
                 this_lm_target, this_lm_input = [], []
                 this_lm_target.append(IGNORE_ID)
                 this_lm_target.append(IGNORE_ID)
-                this_lm_input.append(self.llm_embedding.weight[self.sos_eos].reshape(1, -1))
+                this_lm_input.append(sos_emb.squeeze(dim=0))
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                 for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
                     this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
@@ -320,22 +320,21 @@ class Qwen2LM(TransformerLM):
                         assert len(this_speech_token) == self.mix_ratio[1]
                         assert len(this_speech_token) == self.mix_ratio[1]
                         this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                         this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
                         this_lm_target += this_speech_token
                         this_lm_target += this_speech_token
-                        this_lm_target.append(self.speech_token_size + 2)
+                        this_lm_target.append(self.fill_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
                     else:
                     else:
                         this_lm_target += [-1] * len(this_text_token)
                         this_lm_target += [-1] * len(this_text_token)
                         this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
                         this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
-                        this_lm_target.append(self.speech_token_size)
+                        this_lm_target.append(self.eos_token)
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
                         this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
-                        this_lm_input.append(self.llm_embedding.weight[self.task_id].reshape(1, -1))
+                        this_lm_input.append(task_id_emb.squeeze(dim=0))
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                         this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
                 this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
             # unistream sequence
             # unistream sequence
             else:
             else:
-                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.speech_token_size])
-                this_lm_input = torch.concat([self.llm_embedding.weight[self.sos_eos].reshape(1, -1), text_token_emb[i],
-                                              self.llm_embedding.weight[self.task_id].reshape(1, -1), speech_token_emb[i]], dim=0)
+                this_lm_target = torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
+                this_lm_input = torch.concat([sos_emb.squeeze(dim=0), text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
             lm_target.append(this_lm_target)
             lm_target.append(this_lm_target)
             lm_input.append(this_lm_input)
             lm_input.append(this_lm_input)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
@@ -363,11 +362,16 @@ class Qwen2LM(TransformerLM):
         # 1. encode text_token
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
 
 
+        # sos and task_id embeddings
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
         # 2. encode speech_token
         # 2. encode speech_token
         speech_token_emb = self.speech_embedding(speech_token)
         speech_token_emb = self.speech_embedding(speech_token)
 
 
         # 3. prepare llm_input/target
         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len)
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
+                                                                         speech_token, speech_token_emb, speech_token_len)
         lm_target = lm_target.to(device)
         lm_target = lm_target.to(device)
 
 
         # 4. run lm forward
         # 4. run lm forward
@@ -392,6 +396,10 @@ class Qwen2LM(TransformerLM):
         # 1. encode text_token
         # 1. encode text_token
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
         text_token_emb = self.llm.model.model.embed_tokens(text_token)
 
 
+        # sos and task_id embeddings
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
+
         # 2. encode speech_token
         # 2. encode speech_token
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
         reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
         reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
@@ -401,8 +409,8 @@ class Qwen2LM(TransformerLM):
         speech_token_combined_emb = self.speech_embedding(speech_token_combined)
         speech_token_combined_emb = self.speech_embedding(speech_token_combined)
 
 
         # 3. prepare llm_input/target
         # 3. prepare llm_input/target
-        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
-                                                                         speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
+                                                                         task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
         lm_target = lm_target.to(device)
         lm_target = lm_target.to(device)
 
 
         # 4. run lm forward
         # 4. run lm forward
@@ -445,13 +453,13 @@ class Qwen2LM(TransformerLM):
         text = self.llm.model.model.embed_tokens(text)
         text = self.llm.model.model.embed_tokens(text)
 
 
         # 3. concat llm_input
         # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
 
 
         # 4. cal min/max_length
         # 4. cal min/max_length
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
         min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@@ -500,11 +508,9 @@ class Qwen2LM(TransformerLM):
                                                           masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                           masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                           cache=cache)
                                                           cache=cache)
                 logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                 logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-                if top_ids == self.speech_token_size:
+                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
+                if top_ids in self.stop_token_ids:
                     break
                     break
-                if top_ids > self.speech_token_size:
-                    continue
                 # in stream mode, yield token one by one
                 # in stream mode, yield token one by one
                 yield top_ids
                 yield top_ids
                 out_tokens.append(top_ids)
                 out_tokens.append(top_ids)
@@ -526,20 +532,20 @@ class Qwen2LM(TransformerLM):
 
 
         device = prompt_text.device
         device = prompt_text.device
         # 1. prepare input
         # 1. prepare input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
+        sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
         if prompt_speech_token_len != 0:
         if prompt_speech_token_len != 0:
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
             prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
         else:
         else:
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
             prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb], dim=1)
+        lm_input = torch.concat([sos_emb], dim=1)
 
 
         # 2. iterate text
         # 2. iterate text
         out_tokens = []
         out_tokens = []
         cache = None
         cache = None
         # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
         # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
         text_cache = self.llm.model.model.embed_tokens(prompt_text)
         text_cache = self.llm.model.model.embed_tokens(prompt_text)
-        next_fill_index = -1
+        next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]  # tokens left to finish the mix_ratio[1] block started by the prompt
         for this_text in text:
         for this_text in text:
             text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
             text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
             # prompt_speech_token_emb not empty, try append to lm_input
             # prompt_speech_token_emb not empty, try append to lm_input
@@ -554,12 +560,12 @@ class Qwen2LM(TransformerLM):
                     break
                     break
             # no prompt_speech_token_emb remain, can decode some speech token
             # no prompt_speech_token_emb remain, can decode some speech token
             if prompt_speech_token_emb.size(1) == 0:
             if prompt_speech_token_emb.size(1) == 0:
-                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
+                if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
                     logging.info('get fill token, need to append more text token')
                     logging.info('get fill token, need to append more text token')
                     if text_cache.size(1) >= self.mix_ratio[0]:
                     if text_cache.size(1) >= self.mix_ratio[0]:
                         lm_input_text = text_cache[:, :self.mix_ratio[0]]
                         lm_input_text = text_cache[:, :self.mix_ratio[0]]
                         logging.info('append {} text token'.format(lm_input_text.size(1)))
                         logging.info('append {} text token'.format(lm_input_text.size(1)))
-                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
+                        if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
                             lm_input = lm_input_text
                             lm_input = lm_input_text
                         else:
                         else:
                             lm_input = torch.concat([lm_input, lm_input_text], dim=1)
                             lm_input = torch.concat([lm_input, lm_input_text], dim=1)
@@ -574,16 +580,16 @@ class Qwen2LM(TransformerLM):
                                                               cache=cache)
                                                               cache=cache)
                     logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                     logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                     if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                     if next_fill_index != -1 and len(out_tokens) == next_fill_index:
-                        top_ids = self.speech_token_size + 2
+                        top_ids = self.fill_token
                         next_fill_index += (self.mix_ratio[1] + 1)
                         next_fill_index += (self.mix_ratio[1] + 1)
                     else:
                     else:
-                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
-                    if top_ids == self.speech_token_size + 2:
+                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
+                    if top_ids == self.fill_token:
                         next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                         next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                         logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                         logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
                     out_tokens.append(top_ids)
                     out_tokens.append(top_ids)
                     if top_ids >= self.speech_token_size:
                     if top_ids >= self.speech_token_size:
-                        if top_ids == self.speech_token_size + 2:
+                        if top_ids == self.fill_token:
                             break
                             break
                         else:
                         else:
                             raise ValueError('should not get token {}'.format(top_ids))
                             raise ValueError('should not get token {}'.format(top_ids))
@@ -599,13 +605,135 @@ class Qwen2LM(TransformerLM):
                                                       masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                       masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
                                                       cache=cache)
                                                       cache=cache)
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
             out_tokens.append(top_ids)
             out_tokens.append(top_ids)
             if top_ids >= self.speech_token_size:
             if top_ids >= self.speech_token_size:
-                if top_ids == self.speech_token_size:
+                if top_ids == self.eos_token:
                     break
                     break
                 else:
                 else:
                     raise ValueError('should not get token {}'.format(top_ids))
                     raise ValueError('should not get token {}'.format(top_ids))
             # in stream mode, yield token one by one
             # in stream mode, yield token one by one
             yield top_ids
             yield top_ids
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
+
+
+class CosyVoice3LM(Qwen2LM):
+    def __init__(
+            self,
+            llm_input_size: int,
+            llm_output_size: int,
+            speech_token_size: int,
+            llm: torch.nn.Module,
+            sampling: Callable,
+            length_normalized_loss: bool = True,
+            lsm_weight: float = 0.0,
+            mix_ratio: List[int] = [5, 15],
+    ):
+        torch.nn.Module.__init__(self)
+        self.llm_input_size = llm_input_size
+        self.llm_output_size = llm_output_size
+        self.speech_token_size = speech_token_size
+        # 2. build speech token language model related modules
+        self.sos = speech_token_size + 0
+        self.eos_token = speech_token_size + 1
+        self.task_id = speech_token_size + 2
+        self.fill_token = speech_token_size + 3
+
+        self.llm = llm
+        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
+        self.criterion_ce = LabelSmoothingLoss(
+            size=speech_token_size + 200,
+            padding_idx=IGNORE_ID,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        # 3. [Optional] build speech token related modules
+        self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
+
+        # 4. sampling method
+        self.sampling = sampling
+        self.mix_ratio = mix_ratio
+
+        # 5. vllm related
+        self.stop_token_ids = [speech_token_size + i for i in range(200)]
+        self.vllm_output_queue = {}
+
+    def forward(
+            self,
+            batch: dict,
+            device: torch.device,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        """
+        Args:
+            batch: dict with text_token (B, L), text_token_len (B,),
+                   speech_token (B, T), speech_token_len (B,),
+                   instruct_token (B, L'), instruct_token_len (B,)
+            device: device to run the forward pass on
+        """
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+        # NOTE should append instruct_token to sequence, not implemented yet
+        instruct_token = batch['instruct_token'].to(device)
+        instruct_token_len = batch['instruct_token_len'].to(device)
+
+        # 1. encode text_token
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 2. sos and task_id
+        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+
+        # 3. encode speech_token
+        speech_token_emb = self.speech_embedding(speech_token)
+
+        # 4. prepare llm_input/target
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
+                                                                         speech_token, speech_token_emb, speech_token_len)
+        lm_target = lm_target.to(device)
+
+        # 5. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        loss = self.criterion_ce(logits, lm_target.to(device))
+        acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)  # vocab matches llm_decoder output (speech_token_size + 200)
+        return {'loss': loss, 'acc': acc}
+
+    @torch.inference_mode()
+    def inference(
+            self,
+            text: torch.Tensor,
+            text_len: torch.Tensor,
+            prompt_text: torch.Tensor,
+            prompt_text_len: torch.Tensor,
+            prompt_speech_token: torch.Tensor,
+            prompt_speech_token_len: torch.Tensor,
+            embedding: torch.Tensor,
+            sampling: int = 25,
+            max_token_text_ratio: float = 20,
+            min_token_text_ratio: float = 2,
+            uuid: str = '',
+    ) -> Generator[torch.Tensor, None, None]:
+        device = text.device
+        text = torch.concat([prompt_text, text], dim=1)
+        text_len += prompt_text_len
+        text = self.llm.model.model.embed_tokens(text)
+
+        # 3. concat llm_input
+        sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
+        task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
+        if prompt_speech_token_len != 0:
+            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
+        else:
+            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
+        lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
+
+        # 4. cal min/max_length
+        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
+        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
+
+        # 5. step by step decode
+        for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
+            yield token
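CosyVoice3LM moves all four control tokens into the speech-token id space, which is why `llm_embedding` disappears and `speech_embedding`/`llm_decoder` are sized `speech_token_size + 200`. A sketch of the resulting id layout (the vocabulary size below is an assumed placeholder, not read from the config):

```python
# Hypothetical illustration of the CosyVoice3LM token-id layout.
V = 6561  # assumed speech codec vocabulary size; substitute the real config value

sos = V + 0         # start-of-sequence embedding
eos_token = V + 1   # ends speech generation
task_id = V + 2     # separates the text prompt from speech tokens
fill_token = V + 3  # "need more text" marker in bistream (streaming-text) mode

# every id >= V is treated as a stop/control token during decoding
stop_token_ids = [V + i for i in range(200)]
assert {sos, eos_token, task_id, fill_token} <= set(stop_token_ids)
```

The new `next_fill_index` initialization follows the same bistream bookkeeping: with `mix_ratio = [5, 15]` and, say, a 100-token speech prompt, `(100 // 15 + 1) * 15 - 100 = 5` tokens are decoded before the first forced fill token, completing the 15-token block the prompt left unfinished.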

+ 52 - 4
cosyvoice/tokenizer/tokenizer.py

@@ -238,7 +238,7 @@ def get_tokenizer(
     )
     )
 
 
 
 
-class QwenTokenizer():
+class CosyVoice2Tokenizer():
     def __init__(self, token_path, skip_special_tokens=True):
     def __init__(self, token_path, skip_special_tokens=True):
         super().__init__()
         super().__init__()
         # NOTE: non-chat model, all these special tokens keep randomly initialized.
         # NOTE: non-chat model, all these special tokens keep randomly initialized.
@@ -271,9 +271,57 @@ class QwenTokenizer():
         return text
         return text
 
 
 
 
+class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
+    def __init__(self, token_path, skip_special_tokens=True):
+        # NOTE: non-chat model, all these special tokens keep randomly initialized.
+        special_tokens = {
+            'eos_token': '<|endoftext|>',
+            'pad_token': '<|endoftext|>',
+            'additional_special_tokens': [
+                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
+                '[breath]', '<strong>', '</strong>', '[noise]',
+                '[laughter]', '[cough]', '[clucking]', '[accent]',
+                '[quick_breath]',
+                "<laughter>", "</laughter>",
+                "[hissing]", "[sigh]", "[vocalized-noise]",
+                "[lipsmack]", "[mn]", "<|endofsystem|>",
+                "[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
+                "[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
+                "[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
+                "[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
+                "[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
+                "[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
+                "[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
+                "[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
+                "[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
+                "[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
+                "[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
+                "[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
+                "[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
+                "[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
+                "[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
+                "[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
+                "[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
+                "[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
+                "[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
+                "[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
+            ]
+        }
+        self.special_tokens = special_tokens
+        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
+        self.tokenizer.add_special_tokens(special_tokens)
+        self.skip_special_tokens = skip_special_tokens
+
+
 @lru_cache(maxsize=None)
 @lru_cache(maxsize=None)
 def get_qwen_tokenizer(
 def get_qwen_tokenizer(
     token_path: str,
     token_path: str,
-    skip_special_tokens: bool
-) -> QwenTokenizer:
-    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
+    skip_special_tokens: bool,
+    version: str = 'cosyvoice2'
+):
+    if version == 'cosyvoice2':
+        return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
+    elif version == 'cosyvoice3':
+        return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
+    else:
+        raise ValueError('unsupported tokenizer version {}'.format(version))
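A usage sketch of the version dispatch above; the `token_path` is an illustrative local Qwen tokenizer directory, and `encode` is assumed to behave as in the existing tokenizer wrapper:

```python
from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer

# version='cosyvoice3' selects the tokenizer with the phoneme/pinyin special tokens
tokenizer = get_qwen_tokenizer(
    token_path='./pretrained_models/Fun-CosyVoice3-0.5B/CosyVoice-BlankEN',  # hypothetical path
    skip_special_tokens=True,
    version='cosyvoice3',
)
# markers such as [AH0] or [zh] now map to single special-token ids instead of
# being split into BPE pieces, so phoneme-level control survives tokenization
ids = tokenizer.encode('hello [AH0] world <|endofprompt|>')
```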

+ 113 - 0
cosyvoice/transformer/convolution.py

@@ -19,6 +19,7 @@ from typing import Tuple
 
 
 import torch
 import torch
 from torch import nn
 from torch import nn
+import torch.nn.functional as F
 
 
 
 
 class ConvolutionModule(nn.Module):
 class ConvolutionModule(nn.Module):
@@ -143,3 +144,115 @@ class ConvolutionModule(nn.Module):
             x.masked_fill_(~mask_pad, 0.0)
             x.masked_fill_(~mask_pad, 0.0)
 
 
         return x.transpose(1, 2), new_cache
         return x.transpose(1, 2), new_cache
+
+
+# NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        causal_type: str = 'left',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1d, self).__init__(in_channels, out_channels,
+                                           kernel_size, stride=1,
+                                           padding=0, dilation=dilation,
+                                           groups=groups, bias=bias,
+                                           padding_mode=padding_mode,
+                                           device=device, dtype=dtype)
+        assert stride == 1
+        self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2  # context frames the layer needs from its cache
+        assert causal_type in ['left', 'right']
+        self.causal_type = causal_type
+
+    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
+        input_timestep = x.shape[2]
+        if cache.size(2) == 0:
+            cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
+        assert cache.size(2) == self.causal_padding
+        if self.causal_type == 'left':
+            x = torch.concat([cache, x], dim=2)
+        else:
+            x = torch.concat([x, cache], dim=2)
+        x = super(CausalConv1d, self).forward(x)
+        assert x.shape[2] == input_timestep
+        return x
+
+
+class CausalConv1dDownSample(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
+                                                     kernel_size, stride,
+                                                     padding=0, dilation=dilation,
+                                                     groups=groups, bias=bias,
+                                                     padding_mode=padding_mode,
+                                                     device=device, dtype=dtype)
+        assert stride != 1 and dilation == 1
+        assert kernel_size % stride == 0
+        self.causal_padding = stride - 1
+
+    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
+        if cache.size(2) == 0:
+            x = F.pad(x, (self.causal_padding, 0), value=0.0)
+        else:
+            assert cache.size(2) == self.causal_padding
+            x = torch.concat([cache, x], dim=2)
+        x = super(CausalConv1dDownSample, self).forward(x)
+        return x
+
+
+class CausalConv1dUpsample(torch.nn.Conv1d):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = 'zeros',
+        device=None,
+        dtype=None
+    ) -> None:
+        super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
+                                                   kernel_size, 1,
+                                                   padding=0, dilation=dilation,
+                                                   groups=groups, bias=bias,
+                                                   padding_mode=padding_mode,
+                                                   device=device, dtype=dtype)
+        assert dilation == 1
+        self.causal_padding = kernel_size - 1
+        self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
+
+    def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
+        x = self.upsample(x)
+        input_timestep = x.shape[2]
+        if cache.size(2) == 0:
+            x = F.pad(x, (self.causal_padding, 0), value=0.0)
+        else:
+            assert cache.size(2) == self.causal_padding
+            x = torch.concat([cache, x], dim=2)
+        x = super(CausalConv1dUpsample, self).forward(x)
+        assert input_timestep == x.shape[2]
+        return x
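
A quick shape check (offline mode, i.e. empty cache; all numbers illustrative) for the two resampling variants above: with kernel_size equal to stride the downsampler maps T frames to T/stride, and the upsampler maps them back to T via nearest-neighbour upsampling followed by a causal conv.

import torch
from cosyvoice.transformer.convolution import CausalConv1dDownSample, CausalConv1dUpsample

down = CausalConv1dDownSample(4, 4, kernel_size=2, stride=2)  # left-pads stride - 1 zeros
up = CausalConv1dUpsample(4, 4, kernel_size=3, stride=2)      # upsample x2, then causal conv

x = torch.randn(1, 4, 50)
z = down(x)
assert z.shape == (1, 4, 25)
y = up(z)
assert y.shape == (1, 4, 50)  # also checked internally against input_timestep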

+ 5 - 4
cosyvoice/transformer/upsample_encoder.py

@@ -64,17 +64,18 @@ class Upsample1D(nn.Module):
 
 
 
 
 class PreLookaheadLayer(nn.Module):
-    def __init__(self, channels: int, pre_lookahead_len: int = 1):
+    def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
         super().__init__()
+        self.in_channels = in_channels
         self.channels = channels
         self.pre_lookahead_len = pre_lookahead_len
         self.conv1 = nn.Conv1d(
-            channels, channels,
+            in_channels, channels,
             kernel_size=pre_lookahead_len + 1,
             stride=1, padding=0,
         )
         self.conv2 = nn.Conv1d(
-            channels, channels,
+            channels, in_channels,
             kernel_size=3, stride=1, padding=0,
         )
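
The widened signature keeps the lookahead mechanics unchanged: assuming the forward pass right-pads by pre_lookahead_len before conv1 (as in the released implementation), every output frame t sees input frames up to t + pre_lookahead_len and nothing further ahead. A standalone sketch of that length bookkeeping:

import torch
import torch.nn.functional as F

pre_lookahead_len = 3
x = torch.randn(1, 512, 10)               # (batch, channels, time)
x_pad = F.pad(x, (0, pre_lookahead_len))  # pad the future side only
conv1 = torch.nn.Conv1d(512, 512, kernel_size=pre_lookahead_len + 1)
out = conv1(x_pad)
assert out.shape[-1] == x.shape[-1]       # length preserved; frame t saw frames t..t+3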
 
 
@@ -199,7 +200,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
         # convolution module definition
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                   cnn_module_norm, causal)
-        self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
+        self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
         self.encoders = torch.nn.ModuleList([
             ConformerEncoderLayer(
                 output_size,

+ 6 - 4
cosyvoice/utils/class_utils.py

@@ -32,10 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
                                              RelPositionMultiHeadedAttention)
 from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
 from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
-from cosyvoice.llm.llm import TransformerLM, Qwen2LM
-from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
-from cosyvoice.hifigan.generator import HiFTGenerator
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
+from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
+from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model


 COSYVOICE_ACTIVATION_CLASSES = {
@@ -80,4 +80,6 @@ def get_model_type(configs):
         return CosyVoiceModel
     if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
         return CosyVoice2Model
+    if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
+        return CosyVoice3Model
     raise TypeError('No valid model type found!')
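
A hedged sketch of how this dispatch is reached in practice: the model yaml is instantiated with hyperpyyaml and the resulting objects are matched here. The paths and override value below are illustrative.

from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.class_utils import get_model_type

# illustrative paths; any CosyVoice yaml defining llm/flow/hift entries works
with open('pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
    configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': 'pretrained_models/Fun-CosyVoice3-0.5B/CosyVoice-BlankEN'})
model_cls = get_model_type(configs)  # CosyVoiceModel, CosyVoice2Model or CosyVoice3Model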
     raise TypeError('No valid model type found!')

+ 29 - 2
cosyvoice/utils/common.py

@@ -25,6 +25,33 @@ import torch
 
 
 IGNORE_ID = -1
 IGNORE_ID = -1
 
 
+instruct_list = ["You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用东北话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用甘肃话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用贵州话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用河南话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用湖北话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用湖南话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用江西话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用闽南话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用宁夏话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用山西话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用陕西话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用山东话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用上海话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用四川话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用天津话表达。<|endofprompt|>",
+                 "You are a helpful assistant. 请用云南话表达。<|endofprompt|>",
+                 "You are a helpful assistant. Please say a sentence as loudly as possible.<|endofprompt|>",
+                 "You are a helpful assistant. Please say a sentence in a very soft voice.<|endofprompt|>",
+                 "You are a helpful assistant. 请用尽可能慢地语速说一句话。<|endofprompt|>",
+                 "You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
+                 "You are a helpful assistant. 请非常开心地说一句话。<|endofprompt|>",
+                 "You are a helpful assistant. 请非常伤心地说一句话。<|endofprompt|>",
+                 "You are a helpful assistant. 请非常生气地说一句话。<|endofprompt|>",
+                 "You are a helpful assistant. 我想体验一下小猪佩奇风格,可以吗?<|endofprompt|>",
+                 "You are a helpful assistant. 你可以尝试用机器人的方式解答吗?<|endofprompt|>"]
+
 
 
 def pad_list(xs: List[torch.Tensor], pad_value: int):
     """Perform padding for the list of tensors.
@@ -130,12 +157,12 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
             break
     prob = torch.tensor(prob).to(weighted_scores)
     indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
-    top_ids = indices[prob.multinomial(1, replacement=True)]
+    top_ids = indices[prob.multinomial(1, replacement=True)].item()
     return top_ids
 
 
 
 
 def random_sampling(weighted_scores, decoded_tokens, sampling):
-    top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
+    top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item()
     return top_ids
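
Both samplers now return a plain Python int rather than a one-element tensor, which is what the .item() calls above do. For reference, a self-contained top-p/top-k sketch (not the repo's ras_sampling wrapper) showing where such a call sits:

import torch

def top_p_top_k_sample(weighted_scores: torch.Tensor, top_p: float = 0.8, top_k: int = 25) -> int:
    probs = weighted_scores.softmax(dim=0)
    sorted_probs, sorted_idx = probs.sort(descending=True)
    keep = sorted_probs.cumsum(dim=0) - sorted_probs < top_p  # stop once top_p mass is covered
    keep[top_k:] = False                                      # cap the candidate set at top_k
    keep[0] = True                                            # always keep the best token
    choice = (sorted_probs * keep).multinomial(1, replacement=True)
    return sorted_idx[choice].item()                          # plain int token id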
 
 
 
 

+ 10 - 21
cosyvoice/utils/file_utils.py

@@ -41,11 +41,11 @@ def read_json_lists(list_file):
     return results


-def load_wav(wav, target_sr):
+def load_wav(wav, target_sr, min_sr=16000):
     speech, sample_rate = torchaudio.load(wav, backend='soundfile')
     speech = speech.mean(dim=0, keepdim=True)
     if sample_rate != target_sr:
-        assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+        assert sample_rate >= min_sr, 'wav sample rate {} must be no less than {}'.format(sample_rate, min_sr)
         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech
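
Usage is unchanged; the relaxed check only rejects prompts below min_sr (16 kHz by default) instead of anything below the target rate, so a 16 kHz prompt can now be upsampled to a 24 kHz target:

from cosyvoice.utils.file_utils import load_wav

prompt = load_wav('./asset/zero_shot_prompt.wav', 24000)  # mono, resampled to 24 kHz
# an 8 kHz file would still fail the assert, since 8000 < min_sr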
 
 
@@ -88,30 +88,18 @@ def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
     logging.info("Succesfully convert onnx to trt...")
     logging.info("Succesfully convert onnx to trt...")
 
 
 
 
+# NOTE bistream inference is not supported, as only the speech token embedding/head is kept
 def export_cosyvoice2_vllm(model, model_path, device):
     if os.path.exists(model_path):
         return
-    pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
-    vocab_size = model.speech_embedding.num_embeddings
-    feature_size = model.speech_embedding.embedding_dim
-    pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to

     dtype = torch.bfloat16
     # lm_head
-    new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
-    with torch.no_grad():
-        new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
-        new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
-        new_lm_head.weight[vocab_size:] = 0
-        new_lm_head.bias[vocab_size:] = 0
-    model.llm.model.lm_head = new_lm_head
-    new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
+    use_bias = model.llm_decoder.bias is not None
+    model.llm.model.lm_head = model.llm_decoder
     # embed_tokens
     embed_tokens = model.llm.model.model.embed_tokens
-    with torch.no_grad():
-        new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
-        new_codec_embed.weight[vocab_size:] = 0
-    model.llm.model.set_input_embeddings(new_codec_embed)
+    model.llm.model.set_input_embeddings(model.speech_embedding)
     model.llm.model.to(device)
     model.llm.model.to(dtype)
     tmp_vocab_size = model.llm.model.config.vocab_size
@@ -119,11 +107,12 @@ def export_cosyvoice2_vllm(model, model_path, device):
     del model.llm.model.generation_config.eos_token_id
     del model.llm.model.config.bos_token_id
     del model.llm.model.config.eos_token_id
-    model.llm.model.config.vocab_size = pad_vocab_size
+    model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
     model.llm.model.config.tie_word_embeddings = False
-    model.llm.model.config.use_bias = True
+    model.llm.model.config.use_bias = use_bias
     model.llm.model.save_pretrained(model_path)
-    os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
+    if use_bias is True:
+        os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
     model.llm.model.config.vocab_size = tmp_vocab_size
     model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
     model.llm.model.set_input_embeddings(embed_tokens)
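
A hypothetical call sketch; in the repo this export runs internally when a model is loaded with load_vllm=True, and cosyvoice.model.llm is assumed here to be the Qwen2LM that carries speech_embedding and llm_decoder.

import sys, torch
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import export_cosyvoice2_vllm

cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')
export_cosyvoice2_vllm(cosyvoice.model.llm, 'pretrained_models/CosyVoice2-0.5B/vllm', torch.device('cuda'))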

+ 106 - 0
example.py

@@ -0,0 +1,106 @@
+import sys
+sys.path.append('third_party/Matcha-TTS')
+from cosyvoice.cli.cosyvoice import AutoModel
+import torchaudio
+
+
+def cosyvoice_example():
+    """ CosyVoice Usage, check https://fun-audio-llm.github.io/ for more details
+    """
+    cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-SFT')
+    # sft usage
+    print(cosyvoice.list_available_spks())
+    # set stream=True for chunked streaming inference
+    for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
+        torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+    cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M')
+    # zero_shot usage
+    for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
+        torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+    # cross_lingual usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
+    for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.',
+                                                            './asset/cross_lingual_prompt.wav')):
+        torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+    # vc usage
+    for i, j in enumerate(cosyvoice.inference_vc('./asset/cross_lingual_prompt.wav', './asset/zero_shot_prompt.wav')):
+        torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+    cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-Instruct')
+    # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
+    for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男',
+                                                       'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.<|endofprompt|>')):
+        torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+
+def cosyvoice2_example():
+    """ CosyVoice2 Usage, check https://funaudiollm.github.io/cosyvoice2/ for more details
+    """
+    cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')
+
+    # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
+    # zero_shot usage
+    for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
+        torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+    # save zero_shot spk for future usage
+    assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', 'my_zero_shot_spk') is True
+    for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk')):
+        torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+    cosyvoice.save_spkinfo()
+
+    # fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
+    for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', './asset/zero_shot_prompt.wav')):
+        torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+    # instruct usage
+    for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话<|endofprompt|>', './asset/zero_shot_prompt.wav')):
+        torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+    # bistream usage, you can use generator as input, this is useful when using text llm model as input
+    # NOTE you should still have some basic sentence split logic, because the llm cannot handle arbitrarily long sentences
+    def text_generator():
+        yield '收到好友从远方寄来的生日礼物,'
+        yield '那份意外的惊喜与深深的祝福'
+        yield '让我心中充满了甜蜜的快乐,'
+        yield '笑容如花儿般绽放。'
+    for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
+        torchaudio.save('zero_shot_bistream_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+
+def cosyvoice3_example():
+    """ CosyVoice3 Usage, check https://funaudiollm.github.io/cosyvoice3/ for more details
+    """
+    cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')
+    # zero_shot usage
+    for i, j in enumerate(cosyvoice.inference_zero_shot('八百标兵奔北坡,北坡炮兵并排跑,炮兵怕把标兵碰,标兵怕碰炮兵炮。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
+                                                        './asset/zero_shot_prompt.wav', stream=False)):
+        torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+    # fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L280
+    for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>[breath]因为他们那一辈人[breath]在乡里面住的要习惯一点,[breath]邻居都很活络,[breath]嗯,都很熟悉。[breath]',
+                                                            './asset/zero_shot_prompt.wav', stream=False)):
+        torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+    # instruct usage, for supported control, check cosyvoice/utils/common.py#L28
+    for i, j in enumerate(cosyvoice.inference_instruct2('好少咯,一般系放嗰啲国庆啊,中秋嗰啲可能会咯。', 'You are a helpful assistant. 请用广东话表达。<|endofprompt|>',
+                                                        './asset/zero_shot_prompt.wav', stream=False)):
+        torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+    for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>',
+                                                        './asset/zero_shot_prompt.wav', stream=False)):
+        torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+    # hotfix usage
+    for i, j in enumerate(cosyvoice.inference_zero_shot('高管也通过电话、短信、微信等方式对报道[j][ǐ]予好评。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
+                                                        './asset/zero_shot_prompt.wav', stream=False)):
+        torchaudio.save('hotfix_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
+
+
+def main():
+    # cosyvoice_example()
+    # cosyvoice2_example()
+    cosyvoice3_example()
+
+
+if __name__ == '__main__':
+    main()

+ 9 - 2
examples/libritts/cosyvoice/local/prepare_data.py

@@ -40,6 +40,11 @@ def main():
     with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
         for k, v in spk2utt.items():
             f.write('{} {}\n'.format(k, ' '.join(v)))
+    if args.instruct is True:
+        with open('{}/instruct'.format(args.des_dir), 'w') as f:
+            for k in utt2text:
+                # NOTE in CosyVoice3, the instruct prompt is prepended to the sequence
+                f.write('{} You are a helpful assistant.<|endofprompt|>\n'.format(k))
     return
 
 
 
 
@@ -49,7 +54,9 @@ if __name__ == "__main__":
                         type=str)
     parser.add_argument('--des_dir',
                         type=str)
-    parser.add_argument('--ref_model',
-                        type=str)
+    parser.add_argument('--instruct',
+                        action='store_true',
+                        default=False,
+                        help='create instruct file or not')
     args = parser.parse_args()
     main()
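
With --instruct set, each line of the generated data/$x/instruct file pairs an utterance id with the default instruct prompt, e.g. (LibriTTS-style id, purely illustrative):

103_1241_000000_000001 You are a helpful assistant.<|endofprompt|>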

+ 234 - 0
examples/libritts/cosyvoice3/conf/cosyvoice3.yaml

@@ -0,0 +1,234 @@
+# set random seed, so that you may reproduce your result.
+__set_seed1: !apply:random.seed [1986]
+__set_seed2: !apply:numpy.random.seed [1986]
+__set_seed3: !apply:torch.manual_seed [1986]
+__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
+
+# fixed params
+sample_rate: 24000
+llm_input_size: 896
+llm_output_size: 896
+spk_embed_dim: 192
+qwen_pretrain_path: ''
+token_frame_rate: 25
+token_mel_ratio: 2
+
+# stream related params
+chunk_size: 25 # streaming inference chunk size, in tokens
+num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
+
+# model params
+# for all classes/functions included in this repo, we use !<name> or !<new> for initialization, so that users can find every corresponding class/function from one single yaml.
+# for system/third_party class/function, we do not require this.
+llm: !new:cosyvoice.llm.llm.Qwen2LM
+    llm_input_size: !ref <llm_input_size>
+    llm_output_size: !ref <llm_output_size>
+    speech_token_size: 6561
+    length_normalized_loss: True
+    lsm_weight: 0
+    mix_ratio: [5, 15]
+    llm: !new:cosyvoice.llm.llm.Qwen2Encoder
+        pretrain_path: !ref <qwen_pretrain_path>
+    sampling: !name:cosyvoice.utils.common.ras_sampling
+        top_p: 0.8
+        top_k: 25
+        win_size: 10
+        tau_r: 0.1
+
+flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
+    input_size: 512
+    output_size: 80
+    spk_embed_dim: !ref <spk_embed_dim>
+    output_type: 'mel'
+    vocab_size: 6561
+    input_frame_rate: !ref <token_frame_rate>
+    only_mask_loss: True
+    token_mel_ratio: !ref <token_mel_ratio>
+    pre_lookahead_len: 3
+    encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
+        output_size: 512
+        attention_heads: 8
+        linear_units: 2048
+        num_blocks: 6
+        dropout_rate: 0.1
+        positional_dropout_rate: 0.1
+        attention_dropout_rate: 0.1
+        normalize_before: True
+        input_layer: 'linear'
+        pos_enc_layer_type: 'rel_pos_espnet'
+        selfattention_layer_type: 'rel_selfattn'
+        input_size: 512
+        use_cnn_module: False
+        macaron_style: False
+        static_chunk_size: !ref <chunk_size>
+    decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
+        in_channels: 240
+        n_spks: 1
+        spk_emb_dim: 80
+        cfm_params: !new:omegaconf.DictConfig
+            content:
+                sigma_min: 1e-06
+                solver: 'euler'
+                t_scheduler: 'cosine'
+                training_cfg_rate: 0.2
+                inference_cfg_rate: 0.7
+                reg_loss_type: 'l1'
+        estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
+            in_channels: 320
+            out_channels: 80
+            channels: [256]
+            dropout: 0.0
+            attention_head_dim: 64
+            n_blocks: 4
+            num_mid_blocks: 12
+            num_heads: 8
+            act_fn: 'gelu'
+            static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
+            num_decoding_left_chunks: !ref <num_decoding_left_chunks>
+
+hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
+    in_channels: 80
+    base_channels: 512
+    nb_harmonics: 8
+    sampling_rate: !ref <sample_rate>
+    nsf_alpha: 0.1
+    nsf_sigma: 0.003
+    nsf_voiced_threshold: 10
+    upsample_rates: [8, 5, 3]
+    upsample_kernel_sizes: [16, 11, 7]
+    istft_params:
+        n_fft: 16
+        hop_len: 4
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    source_resblock_kernel_sizes: [7, 7, 11]
+    source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    lrelu_slope: 0.1
+    audio_limit: 0.99
+    f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
+        num_class: 1
+        in_channels: 80
+        cond_channels: 512
+
+# gan related module
+mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1920
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 480
+    win_size: 1920
+    fmin: 0
+    fmax: null
+    center: False
+hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
+    generator: !ref <hift>
+    discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
+        mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
+        mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
+    mel_spec_transform: [
+        !ref <mel_spec_transform1>
+    ]
+
+# processor functions
+parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
+get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
+    token_path: !ref <qwen_pretrain_path>
+    skip_special_tokens: True
+allowed_special: 'all'
+tokenize: !name:cosyvoice.dataset.processor.tokenize
+    get_tokenizer: !ref <get_tokenizer>
+    allowed_special: !ref <allowed_special>
+filter: !name:cosyvoice.dataset.processor.filter
+    max_length: 40960
+    min_length: 100
+    token_max_length: 200
+    token_min_length: 1
+resample: !name:cosyvoice.dataset.processor.resample
+    resample_rate: !ref <sample_rate>
+truncate: !name:cosyvoice.dataset.processor.truncate
+    truncate_length: 24480 # must be a multiplier of hop_size
+feat_extractor: !name:matcha.utils.audio.mel_spectrogram
+    n_fft: 1920
+    num_mels: 80
+    sampling_rate: !ref <sample_rate>
+    hop_size: 480
+    win_size: 1920
+    fmin: 0
+    fmax: 8000
+    center: False
+compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
+    feat_extractor: !ref <feat_extractor>
+    token_mel_ratio: 2
+compute_f0: !name:cosyvoice.dataset.processor.compute_f0
+    sample_rate: !ref <sample_rate>
+    hop_size: 480
+parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
+    normalize: True
+shuffle: !name:cosyvoice.dataset.processor.shuffle
+    shuffle_size: 1000
+sort: !name:cosyvoice.dataset.processor.sort
+    sort_size: 500  # sort_size should be less than shuffle_size
+batch: !name:cosyvoice.dataset.processor.batch
+    batch_type: 'dynamic'
+    max_frames_in_batch: 2000
+padding: !name:cosyvoice.dataset.processor.padding
+    use_spk_embedding: False # change to True during sft
+
+
+# dataset processor pipeline
+data_pipeline: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <compute_fbank>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+data_pipeline_gan: [
+    !ref <parquet_opener>,
+    !ref <tokenize>,
+    !ref <filter>,
+    !ref <resample>,
+    !ref <truncate>,
+    !ref <compute_fbank>,
+    !ref <compute_f0>,
+    !ref <parse_embedding>,
+    !ref <shuffle>,
+    !ref <sort>,
+    !ref <batch>,
+    !ref <padding>,
+]
+
+# llm flow train conf
+train_conf:
+    optim: adam
+    optim_conf:
+        lr: 1e-5 # change to 1e-5 during sft
+    scheduler: constantlr # change to constantlr during sft
+    scheduler_conf:
+        warmup_steps: 2500
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 2
+    log_interval: 100
+    save_per_step: -1
+
+# gan train conf
+train_conf_gan:
+    optim: adam
+    optim_conf:
+        lr: 0.0002 # use small lr for gan training
+    scheduler: constantlr
+    optim_d: adam
+    optim_conf_d:
+        lr: 0.0002 # use small lr for gan training
+    scheduler_d: constantlr
+    max_epoch: 200
+    grad_clip: 5
+    accum_grad: 1 # in gan training, accum_grad must be 1
+    log_interval: 100
+    save_per_step: -1
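
For orientation, the streaming arithmetic implied by the header parameters of this config (all values taken from the file above):

token_frame_rate = 25     # speech tokens per second
token_mel_ratio = 2       # mel frames per speech token
chunk_size = 25           # streaming chunk, in tokens
hop_size, sample_rate = 480, 24000

mel_frames = chunk_size * token_mel_ratio            # 50, the decoder's static_chunk_size
chunk_seconds = mel_frames * hop_size / sample_rate  # 1.0 s of audio per streaming chunk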

+ 42 - 0
examples/libritts/cosyvoice3/conf/ds_stage2.json

@@ -0,0 +1,42 @@
+{
+  "train_micro_batch_size_per_gpu": 1,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 100,
+  "gradient_clipping": 5,
+  "fp16": {
+    "enabled": false,
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 16,
+    "loss_scale_window": 256,
+    "hysteresis": 2,
+    "consecutive_hysteresis": false,
+    "min_loss_scale": 1
+  },
+  "bf16": {
+    "enabled": false
+  },
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_optimization": {
+    "stage": 2,
+    "offload_optimizer": {
+      "device": "none",
+      "pin_memory": true
+    },
+    "allgather_partitions": true,
+    "allgather_bucket_size": 5e8,
+    "overlap_comm": false,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 5e8,
+    "contiguous_gradients" : true
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+        "lr": 0.001,
+        "weight_decay": 0.0001,
+        "torch_adam": true,
+        "adam_w_mode": true
+    }
+  }
+}

+ 1 - 0
examples/libritts/cosyvoice3/cosyvoice

@@ -0,0 +1 @@
+../../../cosyvoice

+ 1 - 0
examples/libritts/cosyvoice3/local

@@ -0,0 +1 @@
+../cosyvoice/local

+ 1 - 0
examples/libritts/cosyvoice3/path.sh

@@ -0,0 +1 @@
+../cosyvoice/path.sh

+ 112 - 0
examples/libritts/cosyvoice3/run.sh

@@ -0,0 +1,112 @@
+#!/bin/bash
+# Copyright 2024 Alibaba Inc. All Rights Reserved.
+. ./path.sh || exit 1;
+
+stage=-1
+stop_stage=3
+
+data_url=www.openslr.org/resources/60
+data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
+pretrained_model_dir=../../../pretrained_models/Fun-CosyVoice3-0.5B
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+  echo "Data Download"
+  for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+    local/download_and_untar.sh ${data_dir} ${data_url} ${part}
+  done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+  echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    mkdir -p data/$x
+    python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct
+  done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    tools/extract_embedding.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/campplus.onnx
+  done
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    tools/extract_speech_token.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/speech_tokenizer_v3.onnx
+  done
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    mkdir -p data/$x/parquet
+    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+      --num_processes 10 \
+      --instruct \
+      --src_dir data/$x \
+      --des_dir data/$x/parquet
+  done
+fi
+
+# train llm
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+job_id=1986
+dist_backend="nccl"
+num_workers=2
+prefetch=100
+train_engine=torch_ddp
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  echo "Run train. We only support llm traning for now"
+  if [ $train_engine == 'deepspeed' ]; then
+    echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
+  fi
+  cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
+  cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
+  # NOTE will update llm/hift training later
+  for model in llm flow hifigan; do
+    torchrun --nnodes=1 --nproc_per_node=$num_gpus \
+        --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
+      cosyvoice/bin/train.py \
+      --train_engine $train_engine \
+      --config conf/cosyvoice3.yaml \
+      --train_data data/train.data.list \
+      --cv_data data/dev.data.list \
+      --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
+      --model $model \
+      --checkpoint $pretrained_model_dir/$model.pt \
+      --model_dir `pwd`/exp/cosyvoice3/$model/$train_engine \
+      --tensorboard_dir `pwd`/tensorboard/cosyvoice3/$model/$train_engine \
+      --ddp.dist_backend $dist_backend \
+      --num_workers ${num_workers} \
+      --prefetch ${prefetch} \
+      --pin_memory \
+      --use_amp \
+      --deepspeed_config ./conf/ds_stage2.json \
+      --deepspeed.save_states model+optimizer
+  done
+fi
+
+# average model
+average_num=5
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+  for model in llm flow hifigan; do
+    decode_checkpoint=`pwd`/exp/cosyvoice3/$model/$train_engine/${model}.pt
+    echo "do model average and final checkpoint is $decode_checkpoint"
+    python cosyvoice/bin/average_model.py \
+      --dst_model $decode_checkpoint \
+      --src_path `pwd`/exp/cosyvoice3/$model/$train_engine  \
+      --num ${average_num} \
+      --val_best
+  done
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+  echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
+  python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
+  python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
+fi

+ 1 - 0
examples/libritts/cosyvoice3/tools

@@ -0,0 +1 @@
+../../../tools

+ 3 - 3
requirements.txt

@@ -29,9 +29,9 @@ pyworld==0.3.4
 rich==13.7.1
 rich==13.7.1
 soundfile==0.12.1
 soundfile==0.12.1
 tensorboard==2.14.0
 tensorboard==2.14.0
-tensorrt-cu12==10.0.1; sys_platform == 'linux'
-tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
-tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
+tensorrt-cu12==10.13.3.9; sys_platform == 'linux'
+tensorrt-cu12-bindings==10.13.3.9; sys_platform == 'linux'
+tensorrt-cu12-libs==10.13.3.9; sys_platform == 'linux'
 torch==2.3.1
 torch==2.3.1
 torchaudio==2.3.1
 torchaudio==2.3.1
 transformers==4.51.3
 transformers==4.51.3

+ 13 - 0
tools/make_parquet_list.py

@@ -37,6 +37,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
     speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
     if args.dpo:
         reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
+    if args.instruct:
+        instruct_list = [utt2instruct[utt] for utt in utt_list]

     # save to parquet, plus utt2parquet_file and spk2parquet_file
     df = pd.DataFrame()
@@ -50,6 +52,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
     df['speech_token'] = speech_token_list
     if args.dpo:
         df['reject_speech_token'] = reject_speech_token_list
+    if args.instruct:
+        df['instruct'] = instruct_list
     df.to_parquet(parquet_file)
     with open(utt2parquet_file, 'w') as f:
         json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
@@ -68,6 +72,10 @@ if __name__ == "__main__":
                         type=int,
                         default=1,
                         help='num processes for make parquets')
+    parser.add_argument('--instruct',
+                        action='store_true',
+                        default=False,
+                        help='has instruct file or not')
     parser.add_argument('--src_dir',
                         type=str)
     parser.add_argument('--des_dir',
@@ -91,6 +99,11 @@ if __name__ == "__main__":
         for l in f:
             l = l.replace('\n', '').split()
             utt2spk[l[0]] = l[1]
+    if args.instruct is True:
+        with open('{}/instruct'.format(args.src_dir)) as f:
+            for l in f:
+                l = l.replace('\n', '').split()
+                utt2instruct[l[0]] = ' '.join(l[1:])
     utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
     spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
     utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
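
A quick way to verify the new column after running the tool; the parquet file name below is illustrative, take a real path from data/$x/parquet/data.list:

import pandas as pd

df = pd.read_parquet('data/train-clean-100/parquet/parquet_000000000.tar.gz')  # illustrative name
print(df['instruct'].head())  # one instruct string per utterance when --instruct was set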

+ 22 - 6
vllm_example.py

@@ -4,20 +4,36 @@ from vllm import ModelRegistry
 from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
 ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)

-from cosyvoice.cli.cosyvoice import CosyVoice2
-from cosyvoice.utils.file_utils import load_wav
+from cosyvoice.cli.cosyvoice import AutoModel
 from cosyvoice.utils.common import set_all_random_seed
 from tqdm import tqdm
 
 
 
 
-def main():
-    cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
-    prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
+def cosyvoice2_example():
+    """ CosyVoice2 vllm usage
+    """
+    cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
     for i in tqdm(range(100)):
         set_all_random_seed(i)
-        for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+        for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
             continue
 
 
 
 
+def cosyvoice3_example():
+    """ CosyVoice3 vllm usage
+    """
+    cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B', load_trt=True, load_vllm=True, fp16=False)
+    for i in tqdm(range(100)):
+        set_all_random_seed(i)
+        for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
+                                                            './asset/zero_shot_prompt.wav', stream=False)):
+            continue
+
+
+def main():
+    # cosyvoice2_example()
+    cosyvoice3_example()
+
+
 if __name__ == '__main__':
     main()

+ 6 - 26
webui.py

@@ -22,8 +22,8 @@ import random
 import librosa
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
-from cosyvoice.utils.file_utils import load_wav, logging
+from cosyvoice.cli.cosyvoice import AutoModel
+from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.common import set_all_random_seed

 inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
@@ -43,18 +43,6 @@ def generate_seed():
     }
 
 
 
 
-def postprocess(speech, top_db=60, hop_length=220, win_length=440):
-    speech, _ = librosa.effects.trim(
-        speech, top_db=top_db,
-        frame_length=win_length,
-        hop_length=hop_length
-    )
-    if speech.abs().max() > max_val:
-        speech = speech / speech.abs().max() * max_val
-    speech = torch.concat([speech, torch.zeros(1, int(cosyvoice.sample_rate * 0.2))], dim=1)
-    return speech
-
-
 def change_instruction(mode_checkbox_group):
 def change_instruction(mode_checkbox_group):
     return instruct_dict[mode_checkbox_group]
 
@@ -118,15 +106,13 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
             yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
             yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '3s极速复刻':
         logging.info('get zero_shot inference request')
-        prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
+        for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_wav, stream=stream, speed=speed):
             yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
     elif mode_checkbox_group == '跨语种复刻':
         logging.info('get cross_lingual inference request')
-        prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
         set_all_random_seed(seed)
-        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
+        for i in cosyvoice.inference_cross_lingual(tts_text, prompt_wav, stream=stream, speed=speed):
             yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten())
     else:
         logging.info('get instruct inference request')
                         default=8000)
                         default=8000)
     parser.add_argument('--model_dir',
                         type=str,
+                        default='pretrained_models/CosyVoice3-0.5B',
                         help='local path or modelscope repo id')
                         help='local path or modelscope repo id')
     args = parser.parse_args()
-        cosyvoice = CosyVoice(args.model_dir)
-    except Exception:
-        try:
-            cosyvoice = CosyVoice2(args.model_dir)
-        except Exception:
-            raise TypeError('no valid model_type!')
+    cosyvoice = AutoModel(model_dir=args.model_dir)
 
 
     sft_spk = cosyvoice.list_available_spks()
     if len(sft_spk) == 0: