server.py

# Set inference model
# export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
# For development
# fastapi dev --port 6006 fastapi_server.py
# For production deployment
# fastapi run --port 6006 fastapi_server.py
import os
import sys
import io
import time
from fastapi import FastAPI, Response, File, UploadFile, Form
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))

from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav

import numpy as np
import torch
import torchaudio
import logging

# Enable INFO-level output (the root logger defaults to WARNING,
# which would swallow the logging.info calls below), but keep
# matplotlib's chatty logger quiet.
logging.basicConfig(level=logging.INFO)
logging.getLogger('matplotlib').setLevel(logging.WARNING)

class LaunchFailed(Exception):
    pass

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once at startup and keep it on the app instance.
    model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
    if model_dir:
        logging.info("MODEL_DIR is %s", model_dir)
        app.cosyvoice = CosyVoice(model_dir)
        # sft usage
        logging.info("Available speakers: %s", app.cosyvoice.list_avaliable_spks())
    else:
        raise LaunchFailed("MODEL_DIR environment variable must be set")
    yield

app = FastAPI(lifespan=lifespan)

# Serialize a CosyVoice output tensor to an in-memory WAV file
# (22050 Hz is the model's output sample rate) and wrap it in an
# HTTP response.
def buildResponse(output):
    buffer = io.BytesIO()
    torchaudio.save(buffer, output, 22050, format="wav")
    buffer.seek(0)
    return Response(content=buffer.read(), media_type="audio/wav")

@app.post("/api/inference/sft")
@app.get("/api/inference/sft")
async def sft(tts: str = Form(), role: str = Form()):
    start = time.process_time()
    output = app.cosyvoice.inference_sft(tts, role)
    end = time.process_time()
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
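
# Example request (a sketch; assumes the server is running on
# localhost:6006 and <role> is a speaker name obtained from the
# /api/roles endpoint):
# curl -F 'tts=Text to synthesize' -F 'role=<role>' \
#      http://localhost:6006/api/inference/sft -o sft.wav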

@app.post("/api/inference/zero-shot")
async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
    start = time.process_time()
    prompt_speech = load_wav(audio.file, 16000)
    # Round-trip the prompt through 16-bit PCM: quantize to int16,
    # then convert back to float in [-1, 1). The copy() makes the
    # frombuffer view writable before handing it to torch.
    prompt_audio = (prompt_speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()
    prompt_speech_16k = torch.from_numpy(np.frombuffer(prompt_audio, dtype=np.int16).copy()).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2 ** 15)
    output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
    end = time.process_time()
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
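
# Example request (a sketch; assumes the server is on localhost:6006
# and prompt.wav is a mono recording whose transcript is passed as
# the 'prompt' field; the server loads it at 16 kHz):
# curl -F 'tts=Text to synthesize' -F 'prompt=Transcript of prompt.wav' \
#      -F 'audio=@prompt.wav' \
#      http://localhost:6006/api/inference/zero-shot -o zero_shot.wav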

@app.post("/api/inference/cross-lingual")
async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
    start = time.process_time()
    prompt_speech = load_wav(audio.file, 16000)
    # Same 16-bit PCM round-trip as the zero-shot endpoint.
    prompt_audio = (prompt_speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()
    prompt_speech_16k = torch.from_numpy(np.frombuffer(prompt_audio, dtype=np.int16).copy()).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2 ** 15)
    output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
    end = time.process_time()
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
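
# Example request (a sketch; assumes the server is on localhost:6006;
# no transcript is needed, only the reference audio):
# curl -F 'tts=Text to synthesize' -F 'audio=@prompt.wav' \
#      http://localhost:6006/api/inference/cross-lingual -o cross_lingual.wav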

@app.post("/api/inference/instruct")
@app.get("/api/inference/instruct")
async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
    start = time.process_time()
    output = app.cosyvoice.inference_instruct(tts, role, instruct)
    end = time.process_time()
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
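
# Example request (a sketch; assumes MODEL_DIR points at an -Instruct
# model, the server is on localhost:6006, and <role> comes from
# /api/roles):
# curl -F 'tts=Text to synthesize' -F 'role=<role>' \
#      -F 'instruct=A calm, slow speaking style.' \
#      http://localhost:6006/api/inference/instruct -o instruct.wav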

@app.get("/api/roles")
async def roles():
    return {"roles": app.cosyvoice.list_avaliable_spks()}
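
# Example request (assumes the server is on localhost:6006):
# curl http://localhost:6006/api/roles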

@app.get("/", response_class=HTMLResponse)
async def root():
    return """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>API information</title>
</head>
<body>
Fetch the available voices from the Roles API first, then pass a voice name together with the text to synthesize to one of the TTS endpoints. <a href='./docs'>API documentation</a>
</body>
</html>
"""