浏览代码

Merge pull request #865 from FunAudioLLM/dev/lyuxiang.lx

Dev/lyuxiang.lx
Xiang Lyu 10 月之前
父节点
当前提交
8a1bce6c81

+ 6 - 2
cosyvoice/cli/cosyvoice.py

@@ -53,7 +53,9 @@ class CosyVoice:
                                 '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs
 
     def list_available_spks(self):
@@ -149,7 +151,9 @@ class CosyVoice2(CosyVoice):
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
-            self.model.load_trt('{}/flow.decoder.estimator.{}.v100.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                self.fp16)
         del configs
 
     def inference_instruct(self, *args, **kwargs):

+ 15 - 4
cosyvoice/cli/model.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import numpy as np
 import threading
@@ -19,6 +20,7 @@ from torch.nn import functional as F
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
+from cosyvoice.utils.file_utils import convert_onnx_to_trt
 
 
 class CosyVoiceModel:
@@ -35,6 +37,9 @@ class CosyVoiceModel:
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_min_hop_len = 2 * self.flow.input_frame_rate
         self.token_max_hop_len = 4 * self.flow.input_frame_rate
         self.token_overlap_len = 20
@@ -69,9 +74,6 @@ class CosyVoiceModel:
         hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
         self.hift.load_state_dict(hift_state_dict, strict=True)
         self.hift.to(self.device).eval()
-        if self.fp16 is True:
-            self.llm.half()
-            self.flow.half()
 
     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
@@ -81,7 +83,10 @@ class CosyVoiceModel:
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
 
-    def load_trt(self, flow_decoder_estimator_model):
+    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
+        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
+        if not os.path.exists(flow_decoder_estimator_model):
+            convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
         del self.flow.decoder.estimator
         import tensorrt as trt
         with open(flow_decoder_estimator_model, 'rb') as f:
@@ -204,6 +209,7 @@ class CosyVoiceModel:
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
             self.flow_cache_dict.pop(this_uuid)
+        torch.cuda.empty_cache()
 
     def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
@@ -257,6 +263,7 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
+        torch.cuda.empty_cache()
 
 
 class CosyVoice2Model(CosyVoiceModel):
@@ -273,6 +280,9 @@ class CosyVoice2Model(CosyVoiceModel):
         self.fp16 = fp16
         self.llm.fp16 = fp16
         self.flow.fp16 = fp16
+        if self.fp16 is True:
+            self.llm.half()
+            self.flow.half()
         self.token_hop_len = 2 * self.flow.input_frame_rate
         # here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
         self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
@@ -385,3 +395,4 @@ class CosyVoice2Model(CosyVoiceModel):
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
             self.llm_end_dict.pop(this_uuid)
+        torch.cuda.empty_cache()

+ 0 - 1
cosyvoice/dataset/processor.py

@@ -21,7 +21,6 @@ import torchaudio
 from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
 
-torchaudio.set_audio_backend('soundfile')
 
 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
 

+ 6 - 6
cosyvoice/flow/flow_matching.py

@@ -134,12 +134,12 @@ class ConditionalCFM(BASECFM):
                 self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
                 # run trt engine
                 self.estimator.execute_v2([x.contiguous().data_ptr(),
-                                        mask.contiguous().data_ptr(),
-                                        mu.contiguous().data_ptr(),
-                                        t.contiguous().data_ptr(),
-                                        spks.contiguous().data_ptr(),
-                                        cond.contiguous().data_ptr(),
-                                        x.data_ptr()])
+                                           mask.contiguous().data_ptr(),
+                                           mu.contiguous().data_ptr(),
+                                           t.contiguous().data_ptr(),
+                                           spks.contiguous().data_ptr(),
+                                           cond.contiguous().data_ptr(),
+                                           x.data_ptr()])
             return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):

+ 1 - 1
cosyvoice/hifigan/discriminator.py

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from typing import List, Optional, Tuple
 from einops import rearrange
 from torchaudio.transforms import Spectrogram

+ 1 - 1
cosyvoice/hifigan/f0_predictor.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import torch
 import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 
 
 class ConvRNNF0Predictor(nn.Module):

+ 1 - 1
cosyvoice/hifigan/generator.py

@@ -23,7 +23,7 @@ import torch.nn.functional as F
 from torch.nn import Conv1d
 from torch.nn import ConvTranspose1d
 from torch.nn.utils import remove_weight_norm
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from torch.distributions.uniform import Uniform
 
 from cosyvoice.transformer.activation import Snake

+ 43 - 1
cosyvoice/utils/file_utils.py

@@ -1,5 +1,5 @@
 # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
-#               2024 Alibaba Inc (authors: Xiang Lyu)
+#               2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import json
+import tensorrt as trt
 import torchaudio
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
@@ -45,3 +46,44 @@ def load_wav(wav, target_sr):
         assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech
+
+
+def convert_onnx_to_trt(trt_model, onnx_model, fp16):
+    _min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
+    _opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
+    _max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
+    input_names = ["x", "mask", "mu", "t", "spks", "cond"]
+
+    logging.info("Converting onnx to trt...")
+    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    logger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(network_flags)
+    parser = trt.OnnxParser(network, logger)
+    config = builder.create_builder_config()
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33)  # 8GB
+    if fp16:
+        config.set_flag(trt.BuilderFlag.FP16)
+    profile = builder.create_optimization_profile()
+    # load onnx model
+    with open(onnx_model, "rb") as f:
+        if not parser.parse(f.read()):
+            for error in range(parser.num_errors):
+                print(parser.get_error(error))
+            raise ValueError('failed to parse {}'.format(onnx_model))
+    # set input shapes
+    for i in range(len(input_names)):
+        profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
+    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
+    # set input and output data type
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.dtype = tensor_dtype
+    for i in range(network.num_outputs):
+        output_tensor = network.get_output(i)
+        output_tensor.dtype = tensor_dtype
+    config.add_optimization_profile(profile)
+    engine_bytes = builder.build_serialized_network(network, config)
+    # save trt engine
+    with open(trt_model, "wb") as f:
+        f.write(engine_bytes)

+ 8 - 2
runtime/python/fastapi/server.py

@@ -24,7 +24,7 @@ import numpy as np
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../../..'.format(ROOT_DIR))
 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 from cosyvoice.utils.file_utils import load_wav
 
 app = FastAPI()
@@ -79,5 +79,11 @@ if __name__ == '__main__':
                         default='iic/CosyVoice-300M',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
-    cosyvoice = CosyVoice(args.model_dir)
+    try:
+        cosyvoice = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            cosyvoice = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
     uvicorn.run(app, host="0.0.0.0", port=args.port)

+ 8 - 2
runtime/python/grpc/server.py

@@ -25,7 +25,7 @@ import numpy as np
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../../..'.format(ROOT_DIR))
 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
 
 logging.basicConfig(level=logging.DEBUG,
                     format='%(asctime)s %(levelname)s %(message)s')
@@ -33,7 +33,13 @@ logging.basicConfig(level=logging.DEBUG,
 
 class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
     def __init__(self, args):
-        self.cosyvoice = CosyVoice(args.model_dir)
+        try:
+            self.cosyvoice = CosyVoice(args.model_dir)
+        except Exception:
+            try:
+                self.cosyvoice = CosyVoice2(args.model_dir)
+            except Exception:
+                raise TypeError('no valid model_type!')
         logging.info('grpc service initialized')
 
     def Inference(self, request, context):

+ 8 - 1
webui.py

@@ -184,7 +184,14 @@ if __name__ == '__main__':
                         default='pretrained_models/CosyVoice2-0.5B',
                         help='local path or modelscope repo id')
     args = parser.parse_args()
-    cosyvoice = CosyVoice2(args.model_dir) if 'CosyVoice2' in args.model_dir else CosyVoice(args.model_dir)
+    try:
+        cosyvoice = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            cosyvoice = CosyVoice2(args.model_dir)
+        except Exception:
+            raise TypeError('no valid model_type!')
+
     sft_spk = cosyvoice.list_available_spks()
     prompt_sr = 16000
     default_data = np.zeros(cosyvoice.sample_rate)