Parcourir la source

use amp in flow

lyuxiang.lx il y a 4 mois
Parent
commit
5bc4b23f02
3 fichiers modifiés avec 7 ajouts et 141 suppressions
  1. 0 126
      cosyvoice/bin/inference_deprecated.py
  2. 0 9
      cosyvoice/cli/model.py
  3. 7 6
      cosyvoice/flow/flow_matching.py

+ 0 - 126
cosyvoice/bin/inference_deprecated.py

@@ -1,126 +0,0 @@
-# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import argparse
-import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
-import os
-import torch
-from torch.utils.data import DataLoader
-import torchaudio
-from hyperpyyaml import load_hyperpyyaml
-from tqdm import tqdm
-from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
-from cosyvoice.dataset.dataset import Dataset
-
-
-def get_args():
-    parser = argparse.ArgumentParser(description='inference with your model')
-    parser.add_argument('--config', required=True, help='config file')
-    parser.add_argument('--prompt_data', required=True, help='prompt data file')
-    parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
-    parser.add_argument('--tts_text', required=True, help='tts input file')
-    parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
-    parser.add_argument('--llm_model', required=True, help='llm model file')
-    parser.add_argument('--flow_model', required=True, help='flow model file')
-    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
-    parser.add_argument('--gpu',
-                        type=int,
-                        default=-1,
-                        help='gpu id for this rank, -1 for cpu')
-    parser.add_argument('--mode',
-                        default='sft',
-                        choices=['sft', 'zero_shot'],
-                        help='inference mode')
-    parser.add_argument('--result_dir', required=True, help='asr result file')
-    args = parser.parse_args()
-    print(args)
-    return args
-
-
-def main():
-    args = get_args()
-    logging.basicConfig(level=logging.DEBUG,
-                        format='%(asctime)s %(levelname)s %(message)s')
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
-
-    # Init cosyvoice models from configs
-    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
-    device = torch.device('cuda' if use_cuda else 'cpu')
-    try:
-        with open(args.config, 'r') as f:
-            configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
-        model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
-    except Exception:
-        try:
-            with open(args.config, 'r') as f:
-                configs = load_hyperpyyaml(f)
-            model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
-        except Exception:
-            raise TypeError('no valid model_type!')
-
-    model.load(args.llm_model, args.flow_model, args.hifigan_model)
-
-    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
-                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
-    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
-
-    sample_rate = configs['sample_rate']
-    del configs
-    os.makedirs(args.result_dir, exist_ok=True)
-    fn = os.path.join(args.result_dir, 'wav.scp')
-    f = open(fn, 'w')
-    with torch.no_grad():
-        for _, batch in tqdm(enumerate(test_data_loader)):
-            utts = batch["utts"]
-            assert len(utts) == 1, "inference mode only support batchsize 1"
-            text_token = batch["text_token"].to(device)
-            text_token_len = batch["text_token_len"].to(device)
-            tts_index = batch["tts_index"]
-            tts_text_token = batch["tts_text_token"].to(device)
-            tts_text_token_len = batch["tts_text_token_len"].to(device)
-            speech_token = batch["speech_token"].to(device)
-            speech_token_len = batch["speech_token_len"].to(device)
-            speech_feat = batch["speech_feat"].to(device)
-            speech_feat_len = batch["speech_feat_len"].to(device)
-            utt_embedding = batch["utt_embedding"].to(device)
-            spk_embedding = batch["spk_embedding"].to(device)
-            if args.mode == 'sft':
-                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
-            else:
-                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
-                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
-                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
-                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
-                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
-                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
-            tts_speeches = []
-            for model_output in model.tts(**model_input):
-                tts_speeches.append(model_output['tts_speech'])
-            tts_speeches = torch.concat(tts_speeches, dim=1)
-            tts_key = '{}_{}'.format(utts[0], tts_index[0])
-            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-            torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
-            f.write('{} {}\n'.format(tts_key, tts_fn))
-            f.flush()
-    f.close()
-    logging.info('Result wav.scp saved in {}'.format(fn))
-
-
-if __name__ == '__main__':
-    logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
-    main()

+ 0 - 9
cosyvoice/cli/model.py

@@ -38,9 +38,6 @@ class CosyVoiceModel:
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
@@ -249,9 +246,6 @@ class CosyVoice2Model(CosyVoiceModel):
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
         # hift cache
@@ -398,9 +392,6 @@ class CosyVoice3Model(CosyVoice2Model):
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()
         # NOTE must matching training static_chunk_size
         self.token_hop_len = 25
         # rtf and decoding related

+ 7 - 6
cosyvoice/flow/flow_matching.py

@@ -91,12 +91,13 @@ class ConditionalCFM(BASECFM):
         sol = []
 
         # Do not use concat, it may cause memory format changed and trt infer with wrong results!
-        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
-        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
-        t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
-        spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
-        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
+        # NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
+        x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
+        mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
+        mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
+        t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
+        spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
+        cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
         for step in range(1, len(t_span)):
             # Classifier-Free Guidance inference introduced in VoiceBox
             x_in[:] = x