1
0
lyuxiang.lx 11 kuukautta sitten
vanhempi
commit
c693039d14
6 muutettua tiedostoa jossa 145 lisäystä ja 71 poistoa
  1. 38 7
      README.md
  2. 20 16
      cosyvoice/cli/cosyvoice.py
  3. 8 8
      cosyvoice/cli/frontend.py
  4. 22 25
      cosyvoice/cli/model.py
  5. 54 11
      cosyvoice/flow/flow_matching.py
  6. 3 4
      requirements.txt

+ 38 - 7
README.md

@@ -1,4 +1,8 @@
 # CosyVoice
+
+## 👉🏻 [CosyVoice2 Demos](https://funaudiollm.github.io/cosyvoice2/) 👈🏻
+[[CosyVoice2 Paper](https://fun-audio-llm.github.io/pdf/CosyVoice_v1.pdf)][[CosyVoice2 Studio](https://www.modelscope.cn/studios/iic/CosyVoice-300M)]
+
 ## 👉🏻 [CosyVoice Demos](https://fun-audio-llm.github.io/) 👈🏻
 [[CosyVoice Paper](https://fun-audio-llm.github.io/pdf/CosyVoice_v1.pdf)][[CosyVoice Studio](https://www.modelscope.cn/studios/iic/CosyVoice-300M)][[CosyVoice Code](https://github.com/FunAudioLLM/CosyVoice)]
 
@@ -6,6 +10,11 @@ For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVo
 
 ## Roadmap
 
+- [x] 2024/12
+
+    - [x] CosyVoice2-0.5B model release
+    - [x] CosyVoice2-0.5B streaming inference with no quality degradation
+
 - [x] 2024/07
 
     - [x] Flow matching training support
@@ -24,9 +33,8 @@ For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVo
 
 - [ ] TBD
 
-    - [ ] 25hz llama based llm model which supports lora finetune
-    - [ ] Support more instruction mode
-    - [ ] Music generation
+    - [ ] CosyVoice2-0.5B bistream inference support
+    - [ ] CosyVoice2-0.5B training and finetune recipie
     - [ ] CosyVoice-500M trained with more multi-lingual data
     - [ ] More...
 
@@ -46,7 +54,7 @@ git submodule update --init --recursive
 - Create Conda env:
 
 ``` sh
-conda create -n cosyvoice python=3.8
+conda create -n cosyvoice python=3.10
 conda activate cosyvoice
 # pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform.
 conda install -y -c conda-forge pynini==2.1.5
@@ -68,6 +76,7 @@ If you are expert in this field, and you are only interested in training your ow
 ``` python
 # SDK模型下载
 from modelscope import snapshot_download
+snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
 snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
 snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
 snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
@@ -78,6 +87,7 @@ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice
 ``` sh
 # git模型下载,请确保已安装git lfs
 mkdir -p pretrained_models
+git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
 git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
 git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
 git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
@@ -97,9 +107,11 @@ pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
 
 **Basic Usage**
 
-For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
+For zero_shot/cross_lingual inference, please use `CosyVoice2-0.5B` or `CosyVoice-300M` model.
 For sft inference, please use `CosyVoice-300M-SFT` model.
 For instruct inference, please use `CosyVoice-300M-Instruct` model.
+We strongly recommend using `CosyVoice2-0.5B` model for better streaming performance.
+
 First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
 
 ``` sh
@@ -107,10 +119,18 @@ export PYTHONPATH=third_party/Matcha-TTS
 ```
 
 ``` python
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav
 import torchaudio
 
+## cosyvoice2 usage
+cosyvoice2 = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, load_trt=False)
+# sft usage
+prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
+for i, j in enumerate(cosyvoice2.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=True)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice2.sample_rate)
+
+## cosyvoice usage
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
 # sft usage
 print(cosyvoice.list_avaliable_spks())
@@ -189,5 +209,16 @@ You can also scan the QR code to join our official Dingding chat group.
 4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
 5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
 
+## Citations
+
+``` bibtex
+@article{du2024cosyvoice,
+  title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
+  author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
+  journal={arXiv preprint arXiv:2407.05407},
+  year={2024}
+}
+```
+
 ## Disclaimer
-The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
+The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.

+ 20 - 16
cosyvoice/cli/cosyvoice.py

@@ -38,6 +38,7 @@ class CosyVoice:
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
                                           configs['allowed_special'])
+        self.sample_rate = configs['sample_rate']
         if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
             load_jit = False
             fp16 = False
@@ -64,7 +65,7 @@ class CosyVoice:
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
@@ -74,11 +75,11 @@ class CosyVoice:
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             if len(i) < 0.5 * len(prompt_text):
                 logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
-            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
+            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
@@ -87,11 +88,11 @@ class CosyVoice:
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
-            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
+            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
@@ -105,23 +106,23 @@ class CosyVoice:
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
             for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
-                speech_len = model_output['tts_speech'].shape[1] / 22050
+                speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                 yield model_output
                 start_time = time.time()
 
     def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
-        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
+        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
         start_time = time.time()
         for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
-            speech_len = model_output['tts_speech'].shape[1] / 22050
+            speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
             logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
             yield model_output
             start_time = time.time()
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
+    def __init__(self, model_dir, load_jit=False, load_onnx=False, load_trt=False):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -135,18 +136,21 @@ class CosyVoice2(CosyVoice):
                                           '{}/spk2info.pt'.format(model_dir),
                                           instruct,
                                           configs['allowed_special'])
-        if torch.cuda.is_available() is False and (fp16 is True or load_jit is True):
+        self.sample_rate = configs['sample_rate']
+        if torch.cuda.is_available() is False and load_jit is True:
             load_jit = False
-            fp16 = False
-            logging.warning('cpu do not support fp16 and jit, force set to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+            logging.warning('cpu do not support jit, force set to False')
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
         self.model.load('{}/llm.pt'.format(model_dir),
                         '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
-            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-                                '{}/llm.llm.fp16.zip'.format(model_dir),
-                                '{}/flow.encoder.fp32.zip'.format(model_dir))
+            self.model.load_jit('{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_trt is True and load_onnx is True:
+            load_onnx = False
+            logging.warning('can not set both load_trt and load_onnx to True, force set load_onnx to False')
         if load_onnx:
             self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
+        if load_trt:
+            self.model.load_trt('{}/flow.decoder.estimator.fp16.Volta.plan'.format(model_dir))
         del configs

+ 8 - 8
cosyvoice/cli/frontend.py

@@ -142,11 +142,11 @@ class CosyVoiceFrontEnd:
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input
 
-    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
         tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
         prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
-        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
-        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
         speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
         embedding = self._extract_spk_embedding(prompt_speech_16k)
         model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
@@ -157,8 +157,8 @@ class CosyVoiceFrontEnd:
                        'llm_embedding': embedding, 'flow_embedding': embedding}
         return model_input
 
-    def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
-        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
+    def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
+        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
         # in cross lingual mode, we remove prompt in llm
         del model_input['prompt_text']
         del model_input['prompt_text_len']
@@ -175,10 +175,10 @@ class CosyVoiceFrontEnd:
         model_input['prompt_text_len'] = instruct_text_token_len
         return model_input
 
-    def frontend_vc(self, source_speech_16k, prompt_speech_16k):
+    def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
         prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
-        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
-        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
+        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
         embedding = self._extract_spk_embedding(prompt_speech_16k)
         source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
         model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,

+ 22 - 25
cosyvoice/cli/model.py

@@ -261,16 +261,15 @@ class CosyVoice2Model:
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
-                 hift: torch.nn.Module,
-                 fp16: bool):
+                 hift: torch.nn.Module):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.fp16 = fp16
-        self.token_min_hop_len = 1 * self.flow.input_frame_rate
-        self.token_max_hop_len = 2 * self.flow.input_frame_rate
-        self.token_right_context = self.flow.encoder.pre_lookahead_layer.pre_lookahead_len
+        self.token_hop_len = 2 * self.flow.input_frame_rate
+        # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
+        self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
+        self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -278,7 +277,6 @@ class CosyVoice2Model:
         self.speech_window = np.hamming(2 * self.source_cache_len)
         # rtf and decoding related
         self.stream_scale_factor = 1
-        assert self.stream_scale_factor == 1, 'fix stream_scale_factor to 1 as we haven\'t implement cache in flow matching yet, this constraint will be loosen in the future'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
         # dict used to store session related variable
@@ -293,17 +291,13 @@ class CosyVoice2Model:
             self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
         self.flow.to(self.device).eval()
+        self.flow.decoder.fp16 = False
         # in case hift_model is a hifigan model
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
 
-    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
-        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
-        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
-        self.llm.text_encoder = llm_text_encoder
-        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
-        self.llm.llm = llm_llm
+    def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
@@ -316,6 +310,14 @@ class CosyVoice2Model:
         del self.flow.decoder.estimator
         self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
 
+    def load_trt(self, flow_decoder_estimator_model):
+        del self.flow.decoder.estimator
+        import tensorrt as trt
+        with open(flow_decoder_estimator_model, 'rb') as f:
+            self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
+        self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
+        self.flow.decoder.fp16 = True
+
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         if self.fp16 is True:
             llm_embedding = llm_embedding.half()
@@ -339,7 +341,7 @@ class CosyVoice2Model:
                                                   prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                   embedding=embedding.to(self.device),
                                                   finalize=finalize)
-        tts_mel = tts_mel[:, :, token_offset * self.flow.encoder.up_layer.stride:]
+        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append hift cache
         if self.hift_cache_dict[uuid] is not None:
             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
@@ -377,13 +379,11 @@ class CosyVoice2Model:
         p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
         p.start()
         if stream is True:
-            token_hop_len, token_offset = self.token_min_hop_len, 0
-            self.flow.encoder.static_chunk_size = self.token_min_hop_len
-            self.flow.decoder.estimator.static_chunk_size = self.token_min_hop_len * self.flow.encoder.up_layer.stride
+            token_offset = 0
             while True:
                 time.sleep(0.1)
-                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= token_hop_len + self.token_right_context:
-                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + token_hop_len + self.token_right_context]) \
+                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]) \
                         .unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
@@ -392,11 +392,9 @@ class CosyVoice2Model:
                                                      uuid=this_uuid,
                                                      token_offset=token_offset,
                                                      finalize=False)
-                    token_offset += token_hop_len
+                    token_offset += self.token_hop_len
                     yield {'tts_speech': this_tts_speech.cpu()}
-                    # increase token_hop_len for better speech quality
-                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
-                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < token_hop_len + self.token_right_context:
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
                     break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
@@ -412,14 +410,13 @@ class CosyVoice2Model:
         else:
             # deal with all tokens
             p.join()
-            self.flow.encoder.static_chunk_size = 0
-            self.flow.decoder.estimator.static_chunk_size = 0
             this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
                                              embedding=flow_embedding,
                                              uuid=this_uuid,
+                                             token_offset=0,
                                              finalize=True,
                                              speed=speed)
             yield {'tts_speech': this_tts_speech.cpu()}

+ 54 - 11
cosyvoice/flow/flow_matching.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import onnxruntime
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -88,15 +89,25 @@ class ConditionalCFM(BASECFM):
         # Or in future might add like a return_all_steps flag
         sol = []
 
+        if self.inference_cfg_rate > 0:
+            # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+            x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
+            mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+            t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
+            spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
+            cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        else:
+            x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                x_in = torch.concat([x, x], dim=0)
-                mask_in = torch.concat([mask, mask], dim=0)
-                mu_in = torch.concat([mu, torch.zeros_like(mu).to(x.device)], dim=0)
-                t_in = torch.concat([t, t], dim=0)
-                spks_in = torch.concat([spks, torch.zeros_like(spks).to(x.device)], dim=0) if spks is not None else None
-                cond_in = torch.concat([cond, torch.zeros_like(cond).to(x.device)], dim=0) if cond is not None else None
+                x_in[:] = x
+                mask_in[:] = mask
+                mu_in[0] = mu
+                t_in[:] = t.unsqueeze(0)
+                spks_in[0] = spks
+                cond_in[0] = cond
             else:
                 x_in, mask_in, mu_in, t_in, spks_in, cond_in = x, mask, mu, t, spks, cond
             dphi_dt = self.forward_estimator(
@@ -114,22 +125,53 @@ class ConditionalCFM(BASECFM):
             if step < len(t_span) - 1:
                 dt = t_span[step + 1] - t
 
-        return sol[-1]
+        return sol[-1].float()
 
     def forward_estimator(self, x, mask, mu, t, spks, cond):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator.forward(x, mask, mu, t, spks, cond)
-        else:
+        elif isinstance(self.estimator, onnxruntime.InferenceSession):
             ort_inputs = {
                 'x': x.cpu().numpy(),
                 'mask': mask.cpu().numpy(),
                 'mu': mu.cpu().numpy(),
                 't': t.cpu().numpy(),
-                'spks': spks.cpu().numpy(),
-                'cond': cond.cpu().numpy()
+                'spk': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy(),
+                'mask_rand': torch.randn(1, 1, 1).numpy()
             }
             output = self.estimator.run(None, ort_inputs)[0]
             return torch.tensor(output, dtype=x.dtype, device=x.device)
+        else:
+            if not x.is_contiguous():
+                x = x.contiguous()
+            if not mask.is_contiguous():
+                mask = mask.contiguous()
+            if not mu.is_contiguous():
+                mu = mu.contiguous()
+            if not t.is_contiguous():
+                t = t.contiguous()
+            if not spks.is_contiguous():
+                spks = spks.contiguous()
+            if not cond.is_contiguous():
+                cond = cond.contiguous()
+            self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+            self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('t', (2,))
+            self.estimator.set_input_shape('spk', (2, 80))
+            self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+            self.estimator.set_input_shape('mask_rand', (1, 1, 1))
+            # run trt engine
+            self.estimator.execute_v2([x.data_ptr(),
+                                       mask.data_ptr(),
+                                       mu.data_ptr(),
+                                       t.data_ptr(),
+                                       spks.data_ptr(),
+                                       cond.data_ptr(),
+                                       torch.randn(1, 1, 1).to(x.device).data_ptr(),
+                                       x.data_ptr()])
+            return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss
@@ -199,7 +241,8 @@ class CausalConditionalCFM(ConditionalCFM):
         """
 
         z = self.rand_noise[:, :, :mu.size(2)].to(mu.device) * temperature
-        z[:] = 0
+        if self.sp16 is True:
+            z = z.half()
         # fix prompt and overlap part mu and z
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':

+ 3 - 4
requirements.txt

@@ -1,5 +1,4 @@
---extra-index-url https://download.pytorch.org/whl/torch_stable.html
-conformer==0.3.2
+--extra-index-url https://download.pytorch.org/whl/cu121
 deepspeed==0.14.2; sys_platform == 'linux'
 diffusers==0.27.2
 gdown==5.1.0
@@ -26,8 +25,8 @@ rich==13.7.1
 soundfile==0.12.1
 tensorboard==2.14.0
 tensorrt-cu12==10.0.1
-torch==2.3.1+cu121
-torchaudio==2.3.1+cu121
+torch==2.3.1
+torchaudio==2.3.1
 uvicorn==0.30.0
 wget==3.2
 fastapi==0.111.0