server.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import argparse
  17. import logging
  18. logging.getLogger('matplotlib').setLevel(logging.WARNING)
  19. from fastapi import FastAPI, UploadFile, Form, File
  20. from fastapi.responses import StreamingResponse
  21. from fastapi.middleware.cors import CORSMiddleware
  22. import uvicorn
  23. import numpy as np
  24. ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
  25. sys.path.append('{}/../../..'.format(ROOT_DIR))
  26. sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
  27. from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
  28. from cosyvoice.utils.file_utils import load_wav
  29. app = FastAPI()
  30. # set cross region allowance
  31. app.add_middleware(
  32. CORSMiddleware,
  33. allow_origins=["*"],
  34. allow_credentials=True,
  35. allow_methods=["*"],
  36. allow_headers=["*"])
  37. def generate_data(model_output):
  38. for i in model_output:
  39. tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
  40. yield tts_audio
  41. @app.get("/inference_sft")
  42. @app.post("/inference_sft")
  43. async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
  44. model_output = cosyvoice.inference_sft(tts_text, spk_id)
  45. return StreamingResponse(generate_data(model_output))
  46. @app.get("/inference_zero_shot")
  47. @app.post("/inference_zero_shot")
  48. async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
  49. prompt_speech_16k = load_wav(prompt_wav.file, 16000)
  50. model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
  51. return StreamingResponse(generate_data(model_output))
  52. @app.get("/inference_cross_lingual")
  53. @app.post("/inference_cross_lingual")
  54. async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
  55. prompt_speech_16k = load_wav(prompt_wav.file, 16000)
  56. model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
  57. return StreamingResponse(generate_data(model_output))
  58. @app.get("/inference_instruct")
  59. @app.post("/inference_instruct")
  60. async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
  61. model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
  62. return StreamingResponse(generate_data(model_output))
  63. @app.get("/inference_instruct2")
  64. @app.post("/inference_instruct2")
  65. async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()):
  66. prompt_speech_16k = load_wav(prompt_wav.file, 16000)
  67. model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k)
  68. return StreamingResponse(generate_data(model_output))
  69. if __name__ == '__main__':
  70. parser = argparse.ArgumentParser()
  71. parser.add_argument('--port',
  72. type=int,
  73. default=50000)
  74. parser.add_argument('--model_dir',
  75. type=str,
  76. default='iic/CosyVoice-300M',
  77. help='local path or modelscope repo id')
  78. args = parser.parse_args()
  79. try:
  80. cosyvoice = CosyVoice(args.model_dir)
  81. except Exception:
  82. try:
  83. cosyvoice = CosyVoice2(args.model_dir)
  84. except Exception:
  85. raise TypeError('no valid model_type!')
  86. uvicorn.run(app, host="0.0.0.0", port=args.port)