1
0

fastapi_client.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import argparse
  2. import logging
  3. import requests
  4. def saveResponse(path, response):
  5. # 以二进制写入模式打开文件
  6. with open(path, 'wb') as file:
  7. # 将响应的二进制内容写入文件
  8. file.write(response.content)
  9. def main():
  10. api = args.api_base
  11. if args.mode == 'sft':
  12. url = api + "/api/inference/sft"
  13. payload={
  14. 'tts': args.tts_text,
  15. 'role': args.spk_id
  16. }
  17. response = requests.request("POST", url, data=payload)
  18. saveResponse(args.tts_wav, response)
  19. elif args.mode == 'zero_shot':
  20. url = api + "/api/inference/zero-shot"
  21. payload={
  22. 'tts': args.tts_text,
  23. 'prompt': args.prompt_text
  24. }
  25. files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
  26. response = requests.request("POST", url, data=payload, files=files)
  27. saveResponse(args.tts_wav, response)
  28. elif args.mode == 'cross_lingual':
  29. url = api + "/api/inference/cross-lingual"
  30. payload={
  31. 'tts': args.tts_text,
  32. }
  33. files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
  34. response = requests.request("POST", url, data=payload, files=files)
  35. saveResponse(args.tts_wav, response)
  36. else:
  37. url = api + "/api/inference/instruct"
  38. payload = {
  39. 'tts': args.tts_text,
  40. 'role': args.spk_id,
  41. 'instruct': args.instruct_text
  42. }
  43. response = requests.request("POST", url, data=payload)
  44. saveResponse(args.tts_wav, response)
  45. logging.info("Response save to {}", args.tts_wav)
  46. if __name__ == "__main__":
  47. parser = argparse.ArgumentParser()
  48. parser.add_argument('--api_base',
  49. type=str,
  50. default='http://127.0.0.1:6006')
  51. parser.add_argument('--mode',
  52. default='sft',
  53. choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
  54. help='request mode')
  55. parser.add_argument('--tts_text',
  56. type=str,
  57. default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
  58. parser.add_argument('--spk_id',
  59. type=str,
  60. default='中文女')
  61. parser.add_argument('--prompt_text',
  62. type=str,
  63. default='希望你以后能够做的比我还好呦。')
  64. parser.add_argument('--prompt_wav',
  65. type=str,
  66. default='../../zero_shot_prompt.wav')
  67. parser.add_argument('--instruct_text',
  68. type=str,
  69. default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
  70. parser.add_argument('--tts_wav',
  71. type=str,
  72. default='demo.wav')
  73. args = parser.parse_args()
  74. prompt_sr, target_sr = 16000, 22050
  75. main()