@@ -0,0 +1,144 @@
+import argparse
+import asyncio
+import io
+import os
+import sys
+
+from enum import Enum
+from typing import Optional, Annotated
+
+import numpy as np
+import torch
+import torchaudio
+import uvicorn
+from fastapi import FastAPI, HTTPException, Form, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, Response, JSONResponse
+from pydantic import BaseModel, Field
+
+# Put the bundled Matcha-TTS on the module search path before importing
+# cosyvoice, which depends on it
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(f'{ROOT_DIR}/third_party/Matcha-TTS')
+
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from cosyvoice.utils.common import set_all_random_seed
+
+# FastAPI instance
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+class ModeEnum(Enum):
+    zero_shot = "zero_shot"
+    instruct = "instruct"
+    sft = "sft"
+
+
+def load_wav_from_upload_file(wav, target_sr):
+    # Load from a path or file-like object, downmix to mono, and resample
+    speech, sample_rate = torchaudio.load(wav)
+    speech = speech.mean(dim=0, keepdim=True)
+    if sample_rate != target_sr:
+        assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
+        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
+    return speech
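+# This mirrors cosyvoice.utils.file_utils.load_wav but also works for in-memory
+# uploads, e.g. (wav_bytes is a hypothetical bytes object):
+#   speech = load_wav_from_upload_file(io.BytesIO(wav_bytes), 16000)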
+
+
+class AudioForm(BaseModel):
+    tts_text: str = Field(default="你好,世界", description="Text to be converted to speech.")
+    mode: ModeEnum = Field(description="Synthesis mode; one of: zero_shot, instruct, sft.")
+    sft_dropdown: Optional[str] = Field(default=None, description="Name of a built-in voice; required only in sft mode.")
+    prompt_text: Optional[str] = Field(default=None, description="Transcript of the prompt audio; required only in zero_shot mode.")
+    instruct_text: Optional[str] = Field(default=None, description="Instruction text; required only in instruct mode.")
+    seed: Optional[int] = Field(default=0, description="Random seed controlling generation; defaults to 0.")
+    stream: Optional[bool] = Field(default=False, description="Enable streaming output; defaults to false.")
+    speed: Optional[float] = Field(default=1.0, description="Speech speed; defaults to 1.0.")
+    prompt_voice: Optional[UploadFile] = Field(default=None, description="Prompt audio file; defaults to None.")
+    model_config = {"extra": "forbid"}
+
+
+# Audio generation: dispatch the request to the matching CosyVoice inference method
+async def generate_audio(data: AudioForm):
+    set_all_random_seed(data.seed)
+    # UploadFile.read() returns bytes, so wrap them in a file-like object for torchaudio
+    prompt_speech_16k = load_wav_from_upload_file(io.BytesIO(await data.prompt_voice.read()), 16000) if data.prompt_voice else None
+
+    inference_map = {
+        'zero_shot': cosyvoice.inference_zero_shot,
+        'instruct': cosyvoice.inference_instruct2,
+        'sft': cosyvoice.inference_sft
+    }
+
+    if data.mode.value not in inference_map:
+        raise HTTPException(status_code=400, detail="Invalid mode")
+
+    if data.mode.value == 'sft':
+        args = (data.tts_text, data.sft_dropdown, data.stream, data.speed)
+    elif data.mode.value == 'zero_shot':
+        args = (data.tts_text, data.prompt_text, prompt_speech_16k, data.stream, data.speed)
+    else:  # instruct
+        args = (data.tts_text, data.instruct_text, prompt_speech_16k, data.stream, data.speed)
+
+    try:
+        # The inference_* methods are generator functions, so this call only
+        # builds the generator; synthesis itself runs as it is iterated
+        result = await asyncio.to_thread(inference_map[data.mode.value], *args)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Audio generation error: {str(e)}")
+
+    if result is None:
+        raise HTTPException(status_code=500, detail="Failed to generate audio")
+
+    return result
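+
+# Each item yielded by the generator is expected to be a dict whose
+# 'tts_speech' value is a (1, num_samples) float tensor at cosyvoice.sample_rate;
+# both consumers below rely on that shape.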
+
+
+# Streaming: yield raw 16-bit PCM chunks as segments are generated
+async def generate_audio_stream(data: AudioForm):
+    result = await generate_audio(data)
+    for i in result:
+        # Scale float samples in [-1, 1] to 16-bit signed integers
+        audio_data = i['tts_speech'].numpy().flatten()
+        audio_bytes = (audio_data * (2**15)).astype(np.int16).tobytes()
+        yield audio_bytes
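+
+# One way to audition the raw PCM stream (assuming CosyVoice2's native 24 kHz
+# output and a voice name obtained from /sft_spk):
+#   curl -N -X POST http://localhost:50001/text-tts -F tts_text=... \
+#     -F mode=sft -F sft_dropdown=... -F stream=true | ffplay -f s16le -ar 24000 -ac 1 -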
+
+
+# Non-streaming: collect all segments, then pack them into a single WAV file
+async def generate_audio_buffer(data: AudioForm):
+    result = await generate_audio(data)
+    buffer = io.BytesIO()
+    # Concatenate the generated segments along the time axis
+    audio_data = torch.cat([j['tts_speech'] for j in result], dim=1)
+    torchaudio.save(buffer, audio_data, cosyvoice.sample_rate, format="wav")
+    buffer.seek(0)
+    return buffer
+
+
+@app.post("/text-tts")
+async def text_tts(data: Annotated[AudioForm, Form(media_type="multipart/form-data")]):
+    if not data.tts_text:
+        raise HTTPException(status_code=400, detail="Form field 'tts_text' is required")
+
+    if data.stream:
+        # Streaming output
+        return StreamingResponse(generate_audio_stream(data), media_type="audio/pcm")
+    else:
+        # Non-streaming output
+        buffer = await generate_audio_buffer(data)
+        return Response(buffer.read(), media_type="audio/wav")
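+
+# Example non-streaming request (assumes the server started below on port 50001
+# and a local prompt.wav; field names match AudioForm):
+#   curl -X POST http://localhost:50001/text-tts -F tts_text="你好,世界" \
+#     -F mode=zero_shot -F prompt_text=... -F prompt_voice=@prompt.wav -o out.wav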
+
+
+# List the built-in voices usable as sft_dropdown
+@app.get("/sft_spk")
+async def get_sft_spk():
+    sft_spk = cosyvoice.list_available_spks()
+    return JSONResponse(content=sft_spk)
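+
+# Example:
+#   curl http://localhost:50001/sft_spk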
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice2-0.5B', help='local path or modelscope repo id')
+    args = parser.parse_args()
+
+    # Initialize the CosyVoice model; the handlers above read this module-level
+    # global, so start the server by running this script directly rather than
+    # via an external uvicorn command
+    cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=False)
+    uvicorn.run(app, host='0.0.0.0', port=50001)