@@ -32,12 +32,16 @@ from typing import List, Dict
 
 import torch
 from torch.utils.dlpack import to_dlpack
+from torch.nn import functional as F
 
 import triton_python_backend_utils as pb_utils
 
 from hyperpyyaml import load_hyperpyyaml
+from cosyvoice.utils.common import fade_in_out
 from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
 from cosyvoice.utils.common import TrtContextWrapper
+from collections import defaultdict
+import numpy as np
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
@@ -81,6 +85,13 @@ class CosyVoice2Model:
         if self.fp16 is True:
             self.flow.half()
 
+        # streaming tts config
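+        # mel_cache_len mel frames and source_cache_len waveform samples from the end of each chunk
+        # are cached per request in hift_cache_dict and cross-faded into the next chunk using the
+        # Hamming speech_window; token_hop_len is the nominal number of new speech tokens per chunk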
+        self.token_hop_len = 25
+        self.mel_cache_len = 8
+        self.source_cache_len = int(self.mel_cache_len * 480)
+        self.speech_window = np.hamming(2 * self.source_cache_len)
+        self.hift_cache_dict = defaultdict(lambda: None)
+
     def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder
@@ -112,6 +123,43 @@ class CosyVoice2Model:
         return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
 
 
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
+        with torch.cuda.amp.autocast(self.fp16):
+            tts_mel, _ = self.flow.inference(token=token.to(self.device),
+                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_token=prompt_token.to(self.device),
+                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                             prompt_feat=prompt_feat.to(self.device),
+                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                             embedding=embedding.to(self.device),
+                                             streaming=stream,
+                                             finalize=finalize)
+            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
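+            # mel frames before token_offset * token_mel_ratio were already generated for earlier
+            # chunks of this request, so only the newly generated frames are kept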
+        # append hift cache
+        if self.hift_cache_dict[uuid] is not None:
+            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+        else:
+            hift_cache_source = torch.zeros(1, 1, 0)
+        # keep overlap mel and hift cache
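+        # in streaming mode, hold back the trailing source_cache_len samples from this chunk's
+        # output; they are cached and cross-faded into the beginning of the next chunk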
+        if finalize is False:
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                          'source': tts_source[:, :, -self.source_cache_len:],
+                                          'speech': tts_speech[:, -self.source_cache_len:]}
+            tts_speech = tts_speech[:, :-self.source_cache_len]
+        else:
+            if speed != 1.0:
+                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-streaming inference mode'
+                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
+            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
+            if self.hift_cache_dict[uuid] is not None:
+                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+        return tts_speech
+
+
 class TritonPythonModel:
     """Triton Python model for vocoder.
 
@@ -166,25 +214,49 @@ class TritonPythonModel:
             prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
             target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
 
-            tts_mel, _ = self.token2wav_model.model.flow.inference(
-                token=target_speech_tokens,
-                token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
-                    self.device
-                ),
-                prompt_token=prompt_speech_tokens,
-                prompt_token_len=torch.tensor(
-                    [prompt_speech_tokens.shape[1]], dtype=torch.int32
-                ).to(self.device),
-                prompt_feat=prompt_speech_feat,
-                prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
-                embedding=prompt_spk_embedding,
-                streaming=False,
-                finalize=True,
-            )
-
-            audio_hat, _ = self.token2wav_model.model.hift.inference(
-                speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
-            )
+            # token_offset is an optional input used to support both streaming and offline TTS; it must be None for offline TTS
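+            # streaming clients send finalize=False for intermediate chunks and finalize=True for
+            # the last chunk; offline clients omit token_offset and get the full utterance at once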
+            token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
+            if token_offset is not None:
+                token_offset = token_offset.as_numpy().item()
+                finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
+                if not finalize:
+                    stream = True
+                else:
+                    stream = False
+                request_id = request.request_id()
+                logger.debug(f"token_offset: {token_offset}, finalize: {finalize}, request_id: {request_id}")
+                audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
+                                                                 prompt_token=prompt_speech_tokens,
+                                                                 prompt_feat=prompt_speech_feat,
+                                                                 embedding=prompt_spk_embedding,
+                                                                 token_offset=token_offset,
+                                                                 uuid=request_id,
+                                                                 stream=stream,
+                                                                 finalize=finalize)
+                if finalize:
+                    logger.debug(f"dict keys: {self.token2wav_model.model.hift_cache_dict.keys()}")
+                    self.token2wav_model.model.hift_cache_dict.pop(request_id)
+
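+            # no token_offset: offline mode, synthesize the whole utterance in a single pass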
+            else:
+                tts_mel, _ = self.token2wav_model.model.flow.inference(
+                    token=target_speech_tokens,
+                    token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
+                        self.device
+                    ),
+                    prompt_token=prompt_speech_tokens,
+                    prompt_token_len=torch.tensor(
+                        [prompt_speech_tokens.shape[1]], dtype=torch.int32
+                    ).to(self.device),
+                    prompt_feat=prompt_speech_feat,
+                    prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
+                    embedding=prompt_spk_embedding,
+                    streaming=False,
+                    finalize=True,
+                )
+
+                audio_hat, _ = self.token2wav_model.model.hift.inference(
+                    speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
+                )
 
             generated_wave = audio_hat.squeeze(0).cpu().numpy()