# server.py

# Set the inference model:
# export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
# For development:
# fastapi dev --port 6006 fastapi_server.py
# For production deployment:
# fastapi run --port 6006 fastapi_server.py
import os
import sys
import io
import time
import logging

from fastapi import FastAPI, Response, File, UploadFile, Form
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware  # CORS middleware module
from contextlib import asynccontextmanager

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))

from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav

import numpy as np
import torch
import torchaudio

logging.getLogger('matplotlib').setLevel(logging.WARNING)

class LaunchFailed(Exception):
    pass

@asynccontextmanager
async def lifespan(app: FastAPI):
    model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
    if model_dir:
        logging.info("MODEL_DIR is %s", model_dir)
        app.cosyvoice = CosyVoice(model_dir)
        # SFT usage: log the speakers bundled with the model
        logging.info("Available speakers %s", app.cosyvoice.list_avaliable_spks())
    else:
        # Only reachable if MODEL_DIR is explicitly set to an empty string,
        # since a default is supplied above
        raise LaunchFailed("MODEL_DIR environment variable must be set")
    yield

app = FastAPI(lifespan=lifespan)

# Configure which origins may call the API
origins = ["*"]  # "*" allows every origin; narrow this to specific hosts if needed
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,   # origins allowed to make cross-origin requests
    allow_credentials=True,
    allow_methods=["*"],     # HTTP methods allowed cross-origin, e.g. GET, POST, PUT
    allow_headers=["*"])     # request headers allowed cross-origin, e.g. for identifying the caller

def buildResponse(output):
    buffer = io.BytesIO()
    # CosyVoice-300M produces 22050 Hz speech; wrap it as an in-memory WAV
    torchaudio.save(buffer, output, 22050, format="wav")
    buffer.seek(0)
    return Response(content=buffer.read(), media_type="audio/wav")

@app.post("/api/inference/sft")
@app.get("/api/inference/sft")
async def sft(tts: str = Form(), role: str = Form()):
    start = time.process_time()
    output = app.cosyvoice.inference_sft(tts, role)
    end = time.process_time()
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
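
# A sketch of a client call, assuming the server runs on localhost:6006 as in
# the commands at the top of this file; take the role name from /api/roles:
# curl -F "tts=Hello" -F "role=<speaker>" \
#     http://localhost:6006/api/inference/sft -o sft.wav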

@app.post("/api/inference/zero-shot")
async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
    start = time.process_time()
    prompt_speech = load_wav(audio.file, 16000)
    # Quantize the prompt to 16-bit PCM bytes, then rebuild a float tensor in [-1, 1)
    prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2**15)
    output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
    end = time.process_time()
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
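
# A sketch of a zero-shot request, assuming the same localhost:6006 server and
# a hypothetical speech prompt file prompt.wav whose transcript is passed as
# the "prompt" field:
# curl -F "tts=Text to synthesize" -F "prompt=Transcript of prompt.wav" \
#     -F "audio=@prompt.wav" http://localhost:6006/api/inference/zero-shot -o zero_shot.wav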

@app.post("/api/inference/cross-lingual")
async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
    start = time.process_time()
    prompt_speech = load_wav(audio.file, 16000)
    # Same 16-bit PCM round-trip as in the zero-shot endpoint
    prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2**15)
    output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
    end = time.process_time()
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
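
# Cross-lingual synthesis takes the same multipart form as zero-shot, minus the
# transcript field; again a sketch using the hypothetical prompt.wav:
# curl -F "tts=Text in another language" -F "audio=@prompt.wav" \
#     http://localhost:6006/api/inference/cross-lingual -o cross_lingual.wav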

@app.post("/api/inference/instruct")
@app.get("/api/inference/instruct")
async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
    start = time.process_time()
    output = app.cosyvoice.inference_instruct(tts, role, instruct)
    end = time.process_time()
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
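
# A sketch of an instruct request; assumes an instruct-capable model such as
# CosyVoice-300M-Instruct is loaded via MODEL_DIR (see the top of this file),
# and the instruction text here is only an illustrative example:
# curl -F "tts=Hello" -F "role=<speaker>" -F "instruct=Speak slowly and softly." \
#     http://localhost:6006/api/inference/instruct -o instruct.wav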

@app.get("/api/roles")
async def roles():
    # list_avaliable_spks (sic) is the upstream CosyVoice method name
    return {"roles": app.cosyvoice.list_avaliable_spks()}
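
# List the available speakers first, e.g.:
# curl http://localhost:6006/api/roles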

@app.get("/", response_class=HTMLResponse)
async def root():
    return """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>API information</title>
</head>
<body>
Get the supported voices from the roles API first, then pass a voice and the text to synthesize to one of the TTS APIs. <a href='./docs'>API documentation</a>
</body>
</html>
"""