# fastapi_server.py — scraped-page artifacts (line-number gutter, file-size header) removed
  1. import os
  2. import sys
  3. import io,time
  4. from fastapi import FastAPI, Response, File, UploadFile, Form
  5. from fastapi.responses import HTMLResponse
  6. from contextlib import asynccontextmanager
  7. ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
  8. sys.path.append('{}/../..'.format(ROOT_DIR))
  9. sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
  10. from cosyvoice.cli.cosyvoice import CosyVoice
  11. from cosyvoice.utils.file_utils import load_wav
  12. import numpy as np
  13. import torch
  14. import torchaudio
  15. import logging
  16. logging.getLogger('matplotlib').setLevel(logging.WARNING)
  17. class LaunchFailed(Exception):
  18. pass
  19. @asynccontextmanager
  20. async def lifespan(app: FastAPI):
  21. model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
  22. if model_dir:
  23. logging.info("MODEL_DIR is {}", model_dir)
  24. app.cosyvoice = CosyVoice('../../'+model_dir)
  25. # sft usage
  26. logging.info("Avaliable speakers {}", app.cosyvoice.list_avaliable_spks())
  27. else:
  28. raise LaunchFailed("MODEL_DIR environment must set")
  29. yield
  30. app = FastAPI(lifespan=lifespan)
  31. def buildResponse(output):
  32. buffer = io.BytesIO()
  33. torchaudio.save(buffer, output, 22050, format="wav")
  34. buffer.seek(0)
  35. return Response(content=buffer.read(-1), media_type="audio/wav")
  36. @app.post("/api/inference/sft")
  37. @app.get("/api/inference/sft")
  38. async def sft(tts: str = Form(), role: str = Form()):
  39. start = time.process_time()
  40. output = app.cosyvoice.inference_sft(tts, role)
  41. end = time.process_time()
  42. logging.info("infer time is {} seconds", end-start)
  43. return buildResponse(output['tts_speech'])
  44. @app.post("/api/inference/zero-shot")
  45. async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
  46. start = time.process_time()
  47. prompt_speech = load_wav(audio.file, 16000)
  48. prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
  49. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  50. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  51. output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
  52. end = time.process_time()
  53. logging.info("infer time is {} seconds", end-start)
  54. return buildResponse(output['tts_speech'])
  55. @app.post("/api/inference/cross-lingual")
  56. async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
  57. start = time.process_time()
  58. prompt_speech = load_wav(audio.file, 16000)
  59. prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
  60. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  61. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  62. output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
  63. end = time.process_time()
  64. logging.info("infer time is {} seconds", end-start)
  65. return buildResponse(output['tts_speech'])
  66. @app.post("/api/inference/instruct")
  67. @app.get("/api/inference/instruct")
  68. async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
  69. start = time.process_time()
  70. output = app.cosyvoice.inference_instruct(tts, role, instruct)
  71. end = time.process_time()
  72. logging.info("infer time is {} seconds", end-start)
  73. return buildResponse(output['tts_speech'])
  74. @app.get("/api/roles")
  75. async def roles():
  76. return {"roles": app.cosyvoice.list_avaliable_spks()}
  77. @app.get("/", response_class=HTMLResponse)
  78. async def root():
  79. return """
  80. <!DOCTYPE html>
  81. <html lang=zh-cn>
  82. <head>
  83. <meta charset=utf-8>
  84. <title>Api information</title>
  85. </head>
  86. <body>
  87. Get the supported tones from the Roles API first, then enter the tones and textual content in the TTS API for synthesis. <a href='./docs'>Documents of API</a>
  88. </body>
  89. </html>
  90. """