# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from typing import Optional

import numpy as np
import requests
import soundfile as sf
  30. def get_args():
  31. parser = argparse.ArgumentParser(
  32. formatter_class=argparse.ArgumentDefaultsHelpFormatter
  33. )
  34. parser.add_argument(
  35. "--server-url",
  36. type=str,
  37. default="localhost:8000",
  38. help="Address of the server",
  39. )
  40. parser.add_argument(
  41. "--reference-audio",
  42. type=str,
  43. default="../../example/prompt_audio.wav",
  44. help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
  45. )
  46. parser.add_argument(
  47. "--reference-text",
  48. type=str,
  49. default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
  50. help="",
  51. )
  52. parser.add_argument(
  53. "--target-text",
  54. type=str,
  55. default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
  56. help="",
  57. )
  58. parser.add_argument(
  59. "--model-name",
  60. type=str,
  61. default="spark_tts",
  62. choices=[
  63. "f5_tts",
  64. "spark_tts",
  65. "cosyvoice2"],
  66. help="triton model_repo module name to request",
  67. )
  68. parser.add_argument(
  69. "--output-audio",
  70. type=str,
  71. default="output.wav",
  72. help="Path to save the output audio",
  73. )
  74. return parser.parse_args()
  75. def prepare_request(
  76. waveform,
  77. reference_text,
  78. target_text,
  79. sample_rate=16000,
  80. padding_duration: int = None,
  81. audio_save_dir: str = "./",
  82. ):
  83. assert len(waveform.shape) == 1, "waveform should be 1D"
  84. lengths = np.array([[len(waveform)]], dtype=np.int32)
  85. if padding_duration:
  86. # padding to nearset 10 seconds
  87. samples = np.zeros(
  88. (
  89. 1,
  90. padding_duration
  91. * sample_rate
  92. * ((int(len(waveform) / sample_rate) // padding_duration) + 1),
  93. ),
  94. dtype=np.float32,
  95. )
  96. samples[0, : len(waveform)] = waveform
  97. else:
  98. samples = waveform
  99. samples = samples.reshape(1, -1).astype(np.float32)
  100. data = {
  101. "inputs": [
  102. {
  103. "name": "reference_wav",
  104. "shape": samples.shape,
  105. "datatype": "FP32",
  106. "data": samples.tolist()
  107. },
  108. {
  109. "name": "reference_wav_len",
  110. "shape": lengths.shape,
  111. "datatype": "INT32",
  112. "data": lengths.tolist(),
  113. },
  114. {
  115. "name": "reference_text",
  116. "shape": [1, 1],
  117. "datatype": "BYTES",
  118. "data": [reference_text]
  119. },
  120. {
  121. "name": "target_text",
  122. "shape": [1, 1],
  123. "datatype": "BYTES",
  124. "data": [target_text]
  125. }
  126. ]
  127. }
  128. return data
  129. if __name__ == "__main__":
  130. args = get_args()
  131. server_url = args.server_url
  132. if not server_url.startswith(("http://", "https://")):
  133. server_url = f"http://{server_url}"
  134. url = f"{server_url}/v2/models/{args.model_name}/infer"
  135. waveform, sr = sf.read(args.reference_audio)
  136. assert sr == 16000, "sample rate hardcoded in server"
  137. samples = np.array(waveform, dtype=np.float32)
  138. data = prepare_request(samples, args.reference_text, args.target_text)
  139. rsp = requests.post(
  140. url,
  141. headers={"Content-Type": "application/json"},
  142. json=data,
  143. verify=False,
  144. params={"request_id": '0'}
  145. )
  146. result = rsp.json()
  147. audio = result["outputs"][0]["data"]
  148. audio = np.array(audio, dtype=np.float32)
  149. if args.model_name == "spark_tts":
  150. sample_rate = 16000
  151. else:
  152. sample_rate = 24000
  153. sf.write(args.output_audio, audio, sample_rate, "PCM_16")