token2wav_asr_server.py

# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTriton server for token2wav conversion and ASR."""
import argparse
import io
import logging
import random
import re
import sys
from typing import Any, List

import numpy as np
import torch
from datasets import load_dataset
from jiwer import wer
from pypinyin import lazy_pinyin, Style
from scipy.signal import resample
from tn.chinese.normalizer import Normalizer as ZhNormalizer

from pytriton.decorators import batch
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.proxy.types import Request
from pytriton.triton import Triton, TritonConfig

from omnisense.models import OmniSenseVoiceSmall

# CosyVoice vendors Matcha-TTS as a third-party dependency; it must be on
# sys.path before cosyvoice is imported.
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
from cosyvoice.cli.cosyvoice import CosyVoice2

# Chinese text normalizer (cached globally)
zh_tn_model = ZhNormalizer(
    cache_dir="./cache",
    remove_erhua=False,
    remove_interjections=False,
    remove_puncts=True,
    overwrite_cache=True,
)

logger = logging.getLogger("token2wav_asr_server")


class _ASR_Server:
    """Wraps a single OmniSenseVoiceSmall model instance for Triton."""

    def __init__(self, device_id: int):
        self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)

    @batch
    def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
        """
        WAV: padded float32 audio batch; WAV_LENS: valid number of samples per row.
        LANGUAGE, TEXT_NORM: accepted for backward compatibility, not used.
        See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
        """
        logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
        wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
        results = self._model.transcribe_single_batch(
            wavs,
            language="zh",
            textnorm="woitn",
        )
        texts = [result.text for result in results]
        transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
        return {"TRANSCRIPTS": transcripts}


def audio_decode_cosyvoice2(
    audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
    """
    Generate audio from tokens with optional tone and prompt embedding.
    """
    model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
        "empty", prompt_text, prompt_speech_16k, 24000
    )
    tts_mel, _ = codec_decoder.model.flow.inference(
        token=audio_tokens.to(codec_decoder.model.device),
        token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
            codec_decoder.model.device
        ),
        prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
            codec_decoder.model.device
        ),
        prompt_token_len=torch.tensor(
            [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
        ).to(codec_decoder.model.device),
        prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
            codec_decoder.model.device
        ),
        prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
            codec_decoder.model.device
        ),
        embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
        finalize=True,
    )
    audio_hat, _ = codec_decoder.model.hift.inference(
        speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
    )
    return audio_hat


def get_random_prompt_from_dataset(dataset):
    """
    Get random prompt text and speech from the pre-loaded dataset.
    Returns (prompt_text, prompt_speech_16k)
    """
    random_idx = random.randint(0, len(dataset) - 1)
    sample = dataset[random_idx]
    # Extract audio data
    audio_data = sample["audio"]
    audio_array = audio_data["array"]
    sample_rate = audio_data["sampling_rate"]
    # Convert audio to 16 kHz if needed
    if sample_rate != 16000:
        num_samples = int(len(audio_array) * (16000 / sample_rate))
        audio_array = resample(audio_array, num_samples)
    # Convert to torch tensor
    prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
    prompt_text = sample["text"]
    # Remove spaces in prompt_text
    prompt_text = prompt_text.replace(" ", "")
    return prompt_text, prompt_speech_16k


class _Token2Wav_ASR:
    """Decodes CosyVoice2 speech tokens to audio, transcribes the audio with
    SenseVoice, and scores the transcript against the ground-truth text."""

    def __init__(self, device_id: int):
        self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
        self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
        # Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model.
        # CosyVoice2 internally uses the generic "cuda" device, so we switch the
        # current CUDA context to the desired card before the object is created;
        # all parameters loaded with the generic "cuda" device then reside on
        # this GPU. We keep the selected id in `self.device_id` and set the
        # context again on every forward call to avoid race conditions when
        # several instances share the same process.
        self.device_id = device_id
        # Construct the TTS codec decoder under the correct CUDA device context
        with torch.cuda.device(self.device_id):
            self.codec_decoder = CosyVoice2(
                "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
            )

    @batch
    def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
        """
        TOKENS: padded int32 speech-token batch; TOKEN_LENS: valid length of each row.
        GT_TEXT: ground-truth transcript (BYTES) used to compute the reward.
        See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
        """
        # Ensure the default CUDA device is set correctly for this invocation
        torch.cuda.set_device(self.device_id)
        if self.device_id == 0:
            print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
        tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
        # Decode ground-truth text strings (BYTES -> str)
        if GT_TEXT.ndim == 2:
            gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
        else:
            gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
        wavs = []
        for tokens in tokens_list:
            prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
            audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
            audio_hat = audio_decode_cosyvoice2(
                audio_tokens,
                prompt_text,
                prompt_speech_16k,
                self.codec_decoder,
            )
            # Resample from 24 kHz (CosyVoice2 output) to 16 kHz (SenseVoice input)
            audio_hat = audio_hat.squeeze(0).float().cpu()
            audio_hat = audio_hat.numpy()
            num_samples = int(len(audio_hat) * (16000 / 24000))
            audio_hat = resample(audio_hat, num_samples)
            wavs.append(audio_hat)
        results = self.asr_model.transcribe_single_batch(
            wavs,
            language="zh",
            textnorm="woitn",
        )
        texts = [result.text for result in results]
        # ---------------- Reward computation ----------------
        rewards = []
        for gt_text, hyp_text in zip(gt_texts, texts):
            gt_norm = zh_tn_model.normalize(gt_text).lower()
            hyp_norm = zh_tn_model.normalize(hyp_text).lower()
            # Compare at the pinyin level so homophone substitutions are not penalized
            gt_pinyin = lazy_pinyin(
                gt_norm,
                style=Style.TONE3,
                tone_sandhi=True,
                neutral_tone_with_five=True,
            )
            hyp_pinyin = lazy_pinyin(
                hyp_norm,
                style=Style.TONE3,
                tone_sandhi=True,
                neutral_tone_with_five=True,
            )
            # reward = 1 - tanh(3 * pinyin_error_rate), clipped to [0, 1]
            c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
            reward_val = 1.0 - np.tanh(3.0 * c)
            reward_val = max(0.0, min(1.0, reward_val))
            rewards.append(reward_val)
            print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
        transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
        rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
        return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
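
# Note (illustrative, not part of the serving logic): the reward mapping used in
# _Token2Wav_ASR.__call__ decays quickly with the pinyin-level error rate, e.g.
#   error 0.0 -> reward 1.00
#   error 0.1 -> reward ~0.71
#   error 0.5 -> reward ~0.09
#   error 1.0 -> reward ~0.005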


def _infer_function_factory(device_ids: List[int], model_name: str):
    """Creates a list of inference functions, one for each requested device ID."""
    infer_funcs = []
    for device_id in device_ids:
        if model_name == "sensevoice":
            infer_funcs.append(_ASR_Server(device_id=device_id))
        else:
            infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
    return infer_funcs


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=32,
        help="Maximum batch size per request.",
        required=False,
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--number-of-instances-per-device",
        type=int,
        default=1,
        help="Number of model instances to load per device.",
        required=False,
    )
    parser.add_argument(
        "--number-of-devices",
        type=int,
        default=8,
        help="Number of devices to use.",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="token2wav_asr",
        choices=["token2wav_asr", "sensevoice"],
        help="Model name.",
    )
    args = parser.parse_args()

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")

    triton_config = TritonConfig(
        http_port=8000,
        grpc_port=8001,
        metrics_port=8002,
    )
    device_ids = [i for i in range(args.number_of_devices)]
    device_ids = device_ids * args.number_of_instances_per_device

    with Triton(config=triton_config) as triton:
        logger.info("Loading %s model on device ids: %s", args.model_name, device_ids)
        if args.model_name == "sensevoice":
            triton.bind(
                model_name="sensevoice",
                infer_func=_infer_function_factory(device_ids, args.model_name),
                inputs=[
                    Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
                    Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
                    Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
                    Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
                ],
                outputs=[
                    Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
                ],
                config=ModelConfig(
                    max_batch_size=args.max_batch_size,
                    batcher=DynamicBatcher(max_queue_delay_microseconds=10000),  # 10 ms
                ),
                strict=True,
            )
        else:
            triton.bind(
                model_name="token2wav_asr",
                infer_func=_infer_function_factory(device_ids, args.model_name),
                inputs=[
                    Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
                    Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
                    Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
                ],
                outputs=[
                    Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
                    Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
                ],
                config=ModelConfig(
                    max_batch_size=args.max_batch_size,
                    batcher=DynamicBatcher(max_queue_delay_microseconds=10000),  # 10 ms
                ),
                strict=True,
            )
        logger.info("Serving inference")
        triton.serve()


if __name__ == "__main__":
    main()
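
# Example client call (sketch, illustrative only): assumes the server above is
# running locally with the default HTTP port and uses pytriton's ModelClient.
# The token values below are dummies, not real CosyVoice2 speech tokens.
#
#   from pytriton.client import ModelClient
#
#   tokens = np.zeros((1, 16), dtype=np.int32)              # dummy speech tokens
#   token_lens = np.array([[16]], dtype=np.int32)           # valid length per row
#   gt_text = np.char.encode(np.array([["你好"]]), "utf-8")  # ground-truth text
#
#   with ModelClient("localhost:8000", "token2wav_asr") as client:
#       result = client.infer_batch(TOKENS=tokens, TOKEN_LENS=token_lens, GT_TEXT=gt_text)
#       print(result["REWARDS"], result["TRANSCRIPTS"])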