Explorar el Código

add fastapi client

iflamed hace 1 año
padre
commit
eb53ccbc19
Se han modificado 3 ficheros con 90 adiciones y 2 borrados
  1. 5 2
      README.md
  2. 78 0
      runtime/python/fastapi_client.py
  3. 7 0
      runtime/python/fastapi_server.py

+ 5 - 2
README.md

@@ -121,10 +121,13 @@ You can get familiar with CosyVoice following this recipie.
 The `main.py` file has added a `TTS` api with `CosyVoice-300M-SFT` model, you can update the code based on **Basic Usage** as above.
 
 ```sh
+cd runtime/python
+# Set inference model
+export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
 # For development
-fastapi dev --port 3003
+fastapi dev --port 6006 fastapi_server.py 
 # For production
-fastapi run --port 3003
+fastapi run --port 6006 fastapi_server.py 
 ```
 
 **Build for deployment**

+ 78 - 0
runtime/python/fastapi_client.py

@@ -0,0 +1,78 @@
+import argparse
+import logging
+import requests
+
+def saveResponse(path, response):
+    # 以二进制写入模式打开文件
+    with open(path, 'wb') as file:
+        # 将响应的二进制内容写入文件
+        file.write(response.content)
+
+def main():
+    api = args.api_base
+    if args.mode == 'sft':
+        url = api + "/api/inference/sft"
+        payload={
+            'tts': args.tts_text,
+            'role': args.spk_id
+        }
+        response = requests.request("POST", url, data=payload)
+        saveResponse(args.tts_wav, response)
+    elif args.mode == 'zero_shot':
+        url = api + "/api/inference/zero-shot"
+        payload={
+            'tts': args.tts_text,
+            'prompt': args.prompt_text
+        }
+        files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
+        response = requests.request("POST", url, data=payload, files=files)
+        saveResponse(args.tts_wav, response)
+    elif args.mode == 'cross_lingual':
+        url = api + "/api/inference/cross-lingual"
+        payload={
+            'tts': args.tts_text,
+        }
+        files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
+        response = requests.request("POST", url, data=payload, files=files)
+        saveResponse(args.tts_wav, response)
+    else:
+        url = api + "/api/inference/instruct"
+        payload = {
+            'tts': args.tts_text,
+            'role': args.spk_id,
+            'instruct': args.instruct_text
+        }
+        response = requests.request("POST", url, data=payload)
+        saveResponse(args.tts_wav, response)
+    logging.info("Response save to {}", args.tts_wav)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--api_base',
+                        type=str,
+                        default='http://127.0.0.1:6006')
+    parser.add_argument('--mode',
+                        default='sft',
+                        choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
+                        help='request mode')
+    parser.add_argument('--tts_text',
+                        type=str,
+                        default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
+    parser.add_argument('--spk_id',
+                        type=str,
+                        default='中文女')
+    parser.add_argument('--prompt_text',
+                        type=str,
+                        default='希望你以后能够做的比我还好呦。')
+    parser.add_argument('--prompt_wav',
+                        type=str,
+                        default='../../zero_shot_prompt.wav')
+    parser.add_argument('--instruct_text',
+                        type=str,
+                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
+    parser.add_argument('--tts_wav',
+                        type=str,
+                        default='demo.wav')
+    args = parser.parse_args()
+    prompt_sr, target_sr = 16000, 22050
+    main()

+ 7 - 0
runtime/python/fastapi_server.py

@@ -1,3 +1,10 @@
+# Set inference model
+# export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
+# For development
+# fastapi dev --port 6006 fastapi_server.py 
+# For production deployment
+# fastapi run --port 6006 fastapi_server.py 
+
 import os
 import sys
 import io,time