@@ -301,7 +301,7 @@ class CosyVoice2Model(CosyVoiceModel):
self.flow.half()
# stream related params, check examples/libritts/cosyvoice2/conf/cosyvoice2.yaml
self.token_hop_len = 25
- self.flow_decoder_required_cache_size = -1 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
+ self.flow_decoder_required_cache_size = 0 if use_flow_cache is False else 1 * self.token_hop_len * self.flow.token_mel_ratio
# hift cache
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
@@ -325,11 +325,11 @@ class CosyVoice2Model(CosyVoiceModel):
'upsample_kv_cache': torch.zeros(4, 1, 8, 0, 64 * 2).to(self.device)}
decoder_cache = {'offset': 0,
'down_blocks_conv_cache': torch.zeros(10, 1, 2, 832, 2).to(self.device),
- 'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
+ 'down_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
'mid_blocks_conv_cache': torch.zeros(10, 12, 2, 512, 2).to(self.device),
- 'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, 0, 512, 2).to(self.device),
+ 'mid_blocks_kv_cache': torch.zeros(10, 12, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
'up_blocks_conv_cache': torch.zeros(10, 1, 2, 1024, 2).to(self.device),
- 'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, 0, 512, 2).to(self.device),
+ 'up_blocks_kv_cache': torch.zeros(10, 1, 4, 2, self.flow_decoder_required_cache_size, 512, 2).to(self.device),
'final_blocks_conv_cache': torch.zeros(10, 2, 256, 2).to(self.device)}
if self.fp16 is True:
for cache in [encoder_cache, decoder_cache]:
@@ -339,13 +339,6 @@ class CosyVoice2Model(CosyVoiceModel):
cache = {'encoder_cache': encoder_cache, 'decoder_cache': decoder_cache}
return cache

- def trim_flow_cache(self, cache):
- if self.flow_decoder_required_cache_size > 0 and cache['decoder_cache']['down_blocks_kv_cache'].size(4) > self.flow_decoder_required_cache_size:
- cache['decoder_cache']['down_blocks_kv_cache'] = cache['decoder_cache']['down_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
- cache['decoder_cache']['mid_blocks_kv_cache'] = cache['decoder_cache']['mid_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
- cache['decoder_cache']['up_blocks_kv_cache'] = cache['decoder_cache']['up_blocks_kv_cache'][:, :, :, :, -self.flow_decoder_required_cache_size:]
- return cache
-
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
@@ -369,7 +362,6 @@ class CosyVoice2Model(CosyVoiceModel):
embedding=embedding.to(self.device),
cache=self.flow_cache_dict[uuid],
finalize=finalize)
- self.flow_cache_dict[uuid] = self.trim_flow_cache(self.flow_cache_dict[uuid])
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
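
The patch drops the post-hoc `trim_flow_cache` step and instead allocates the streaming decoder KV caches at `flow_decoder_required_cache_size` from the start. Below is a minimal, standalone sketch of that pattern, not part of the patch and not the CosyVoice2 API: the names and the `token_mel_ratio = 2` value are assumptions, and the rolling-window update is shown explicitly only to illustrate why a fixed-size cache makes the separate trimming pass unnecessary.

```python
# Illustrative sketch only; assumes token_mel_ratio = 2 (the real value comes
# from self.flow.token_mel_ratio) and uses the 'down_blocks_kv_cache' layout
# from the hunk above, with the cache length on dim 4.
import torch

token_hop_len = 25
token_mel_ratio = 2  # assumption
required_cache_size = 1 * token_hop_len * token_mel_ratio  # 50 mel frames

# preallocated at the required length, mirroring the '+' lines in init_flow_cache
kv_cache = torch.zeros(10, 1, 4, 2, required_cache_size, 512, 2)


def append_and_keep_window(cache: torch.Tensor, new_kv: torch.Tensor) -> torch.Tensor:
    """Concatenate new key/value frames on dim 4 and keep only the trailing
    required_cache_size frames (the same slicing trim_flow_cache applied
    after the fact in the removed code)."""
    merged = torch.cat([cache, new_kv], dim=4)
    return merged[:, :, :, :, -required_cache_size:]


# one streaming chunk worth of new KV frames; the cache length never grows
new_kv = torch.randn(10, 1, 4, 2, token_hop_len * token_mel_ratio, 512, 2)
kv_cache = append_and_keep_window(kv_cache, new_kv)
assert kv_cache.shape[4] == required_cache_size
```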