Browse Source

add flow cache inference code

lyuxiang.lx 10 months ago
parent
commit
39ffc50dec
4 changed files with 19 additions and 18 deletions
  1. 1 1
      README.md
  2. 3 3
      cosyvoice/cli/cosyvoice.py
  3. 11 10
      cosyvoice/cli/model.py
  4. 4 4
      examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

+ 1 - 1
README.md

@@ -128,7 +128,7 @@ import torchaudio
 
 **CosyVoice2 Usage**
 ```python
-cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
+cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False, use_flow_cache=False)
 
 # NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
 # zero_shot usage

+ 3 - 3
cosyvoice/cli/cosyvoice.py

@@ -129,7 +129,7 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -151,9 +151,9 @@ class CosyVoice2(CosyVoice):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
         self.model.load('{}/llm.pt'.format(model_dir),
-                        '{}/flow.pt'.format(model_dir),
+                        '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))

+ 11 - 10
cosyvoice/cli/model.py

@@ -288,19 +288,20 @@ class CosyVoice2Model(CosyVoiceModel):
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  hift: torch.nn.Module,
-                 fp16: bool):
+                 fp16: bool,
+                 use_flow_cache: bool):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
         self.hift = hift
         self.fp16 = fp16
+        self.use_flow_cache = use_flow_cache
         if self.fp16 is True:
             self.llm.half()
             self.flow.half()
-        self.token_hop_len = self.flow.encoder.static_chunk_size
-        # flow decoder required_cache_size
-        # TODO 基模型训练时没有设置num_decoding_left_chunks,需要重新训一下才能指定flow_decoder_required_cache_size
-        self.flow_decoder_required_cache_size = 999
+        # stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
+        self.token_hop_len = 25
+        self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len
         # hift cache
         self.mel_cache_len = 8
         self.source_cache_len = int(self.mel_cache_len * 480)
@@ -339,7 +340,7 @@ class CosyVoice2Model(CosyVoiceModel):
         return cache
 
     def trim_flow_cache(self, cache):
-        if cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size:
+        if self.flow_decoder_required_cache_size > 0:
             cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
             cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
             cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
@@ -399,10 +400,10 @@ class CosyVoice2Model(CosyVoiceModel):
             prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
-        # NOTE flow model is only trained with static_chunk_size, so we need to trim flow prompt
-        n_chunk = int(flow_prompt_speech_token.size(1) / self.token_hop_len)
-        flow_prompt_speech_token = flow_prompt_speech_token[:, :n_chunk * self.token_hop_len]
-        prompt_speech_feat = prompt_speech_feat[:, :n_chunk * self.token_hop_len * 2]
+        # NOTE in cache mode, trim flow_prompt to same size as flow_decoder_required_cache_size
+        if self.use_flow_cache is True:
+            flow_prompt_speech_token = flow_prompt_speech_token[:, -self.flow_decoder_required_cache_size:]
+            prompt_speech_feat = prompt_speech_feat[:, -self.flow_decoder_required_cache_size * 2:]
         with self.lock:
             self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
             self.hift_cache_dict[this_uuid] = None

+ 4 - 4
examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

@@ -14,8 +14,8 @@ token_frame_rate: 25
 token_mel_ratio: 2
 
 # stream related params
-chunk_size: 2 # streaming inference chunk size, in second
-num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size
+chunk_size: 25 # streaming inference chunk size, in token
+num_decoding_left_chunks: 1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
 
 # model params
 # for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
@@ -60,7 +60,7 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
         input_size: 512
         use_cnn_module: False
         macaron_style: False
-        static_chunk_size: !ref <chunk_size> * <token_frame_rate>
+        static_chunk_size: !ref <chunk_size>
     decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
         in_channels: 240
         n_spks: 1
@@ -83,7 +83,7 @@ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
             num_mid_blocks: 12
             num_heads: 8
             act_fn: 'gelu'
-            static_chunk_size: !ref <chunk_size> * <token_frame_rate> * <token_mel_ratio> # here we use static_chunk_size because we want to fix kv cache size during inference
+            static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
             num_decoding_left_chunks: !ref <num_decoding_left_chunks>
 
 hift: !new:cosyvoice.hifigan.generator.HiFTGenerator