wuyue 9 maanden geleden
bovenliggende
commit
42373bcf69
1 gewijzigd bestand met 144 toevoegingen en 0 verwijderingen
  1. 144 0
      api.py

+ 144 - 0
api.py

@@ -0,0 +1,144 @@
+import argparse
+import asyncio
+import io
+import os
+import sys
+
+from enum import Enum
+from fastapi import FastAPI, HTTPException, Form, UploadFile, File
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, Response, JSONResponse, FileResponse
+import uvicorn
+from pydantic import BaseModel, Field
+from typing import Optional, Annotated
+import numpy as np
+import torch
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from cosyvoice.utils.file_utils import load_wav
+from cosyvoice.utils.common import set_all_random_seed
+import torchaudio
+
# FastAPI application instance.
app = FastAPI()

# Open CORS policy: any origin, method and header.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests — confirm whether
# credentials are actually required here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], 
    allow_credentials=True,
    allow_methods=["*"], 
    allow_headers=["*"], 
)

# Make the bundled Matcha-TTS third-party package importable.
# NOTE(review): this append runs *after* the `cosyvoice` imports at the top
# of the file — if cosyvoice needs Matcha-TTS at import time this is too
# late; consider moving it above those imports.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(f'{ROOT_DIR}/third_party/Matcha-TTS')
+
+
# The three supported synthesis modes; each member's value equals its name so
# clients can send the mode string verbatim in the form payload.
ModeEnum = Enum("ModeEnum", {name: name for name in ("zero_shot", "instruct", "sft")})
+
+
def load_wav_from_upload_file(wav, target_sr):
    """Load an uploaded audio payload, downmix to mono and resample.

    Args:
        wav: a file path, a file-like object, or raw ``bytes`` as returned by
            ``UploadFile.read()`` (bytes are wrapped in ``io.BytesIO`` because
            ``torchaudio.load`` only accepts a path or file-like object).
        target_sr: desired output sample rate in Hz.

    Returns:
        A mono waveform tensor of shape (1, num_samples) at ``target_sr``.

    Raises:
        ValueError: if the source sample rate is lower than ``target_sr``
            (only downsampling is supported).
    """
    if isinstance(wav, (bytes, bytearray)):
        # The caller passes the raw upload body; torchaudio needs file-like.
        wav = io.BytesIO(wav)
    speech, sample_rate = torchaudio.load(wav)
    # Average all channels into a single mono channel.
    speech = speech.mean(dim=0, keepdim=True)
    if sample_rate != target_sr:
        # Use an explicit exception instead of `assert`, which is stripped
        # when Python runs with -O.
        if sample_rate < target_sr:
            raise ValueError('wav sample rate {} must be greater than {}'.format(sample_rate, target_sr))
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
    return speech
+
+
class AudioForm(BaseModel):
    """Multipart form payload for the /text-tts endpoint.

    Field ``description`` strings are surfaced in the generated OpenAPI docs
    and are intentionally left in Chinese.
    """
    # Text to synthesize.
    tts_text: str = Field(default="你好,世界", description="需要转换为语音的文本内容。")
    # Inference mode: zero_shot, instruct or sft.
    mode: ModeEnum = Field(description="指定转换模式,可选值包括:zero_shot, instruct, sft。")
    # Built-in speaker name; only used in sft mode.
    sft_dropdown: Optional[str] = Field(default=None, description="自定义音色名称,仅在使用 sft 模式时需要。")
    # Transcript of the prompt audio; only used in zero_shot mode.
    prompt_text: Optional[str] = Field(default=None, description="额外的提示文本,仅在使用 zero_shot 模式时需要。")
    # Natural-language instruction; only used in instruct mode.
    instruct_text: Optional[str] = Field(default=None, description="指令文本,仅在使用 instruct 模式时需要。")
    # Random seed forwarded to set_all_random_seed().
    seed: Optional[int] = Field(default=0, description="随机种子,用于控制生成的随机性,默认为 0。")
    # True -> raw PCM streaming response; False -> single WAV response.
    stream: Optional[bool] = Field(default=False, description="是否启用流式输出,默认为 false。")
    # Playback speed multiplier.
    speed: Optional[float] = Field(default=1.0, description="语音速度,默认为 1.0。")
    # Optional reference voice recording (resampled to 16 kHz before use).
    prompt_voice: Optional[UploadFile] = Field(default=None, description="自定义音频文件,默认为None")
    # Reject unknown form fields instead of silently ignoring them.
    model_config = {"extra": "forbid"}
+
# Core audio generation: dispatches the request to the matching CosyVoice call.
async def generate_audio(data: Annotated[AudioForm, Form(..., media_type="multipart/form-data")]):
    """Run CosyVoice inference for the requested mode.

    Seeds the RNGs, loads the uploaded prompt voice (if any) at 16 kHz,
    builds the mode-specific argument tuple and executes the blocking model
    call in a worker thread so the event loop stays responsive.

    Returns:
        The generator of result dicts produced by the CosyVoice inference call.

    Raises:
        HTTPException(400): unknown mode.
        HTTPException(500): inference raised, or produced no result.
    """
    set_all_random_seed(data.seed)

    # Reference voice for zero_shot / instruct modes; None when not uploaded.
    prompt_speech_16k = None
    if data.prompt_voice:
        prompt_speech_16k = load_wav_from_upload_file(await data.prompt_voice.read(), 16000)

    inference_map = {
        'zero_shot': cosyvoice.inference_zero_shot,
        'instruct': cosyvoice.inference_instruct2,
        'sft': cosyvoice.inference_sft,
    }

    mode = data.mode.value
    if mode not in inference_map:
        raise HTTPException(status_code=400, detail="Invalid mode")

    # Positional arguments per mode, mirroring the CosyVoice call signatures.
    args_by_mode = {
        'sft': (data.tts_text, data.sft_dropdown, data.stream, data.speed),
        'zero_shot': (data.tts_text, data.prompt_text, prompt_speech_16k, data.stream, data.speed),
        'instruct': (data.tts_text, data.instruct_text, prompt_speech_16k, data.stream, data.speed),
    }

    try:
        # Inference is synchronous and compute-bound; run it off the event loop.
        result = await asyncio.to_thread(inference_map[mode], *args_by_mode[mode])
    except Exception as e:
        # Chain the original exception so the full traceback is preserved.
        raise HTTPException(status_code=500, detail=f"Audio generation error: {str(e)}") from e

    if result is None:
        raise HTTPException(status_code=500, detail="Failed to generate audio")

    return result
+
# Streaming path: yield each generated chunk as raw 16-bit little-endian PCM.
async def generate_audio_stream(data: Annotated[AudioForm, Form(..., media_type="multipart/form-data")]):
    """Yield successive inference chunks converted from float samples to int16 PCM bytes."""
    result = await generate_audio(data)
    for chunk in result:
        samples = chunk['tts_speech'].numpy().flatten()
        # Clip before casting: a full-scale sample (>= 1.0) scaled by 2**15
        # exceeds the int16 range and would otherwise wrap around.
        pcm = np.clip(samples * (2 ** 15), -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
        yield pcm.tobytes()
+
# Non-streaming path: concatenate every generated chunk into one WAV payload.
async def generate_audio_buffer(data: Annotated[AudioForm, Form(..., media_type="multipart/form-data")]):
    """Run inference and return an in-memory WAV file of the full utterance."""
    chunks = await generate_audio(data)
    # Join all chunk waveforms along the time axis.
    waveform = torch.cat([chunk['tts_speech'] for chunk in chunks], dim=1)
    wav_io = io.BytesIO()
    torchaudio.save(wav_io, waveform, cosyvoice.sample_rate, format="wav")
    wav_io.seek(0)
    return wav_io
+
+@app.post("/text-tts")
+async def text_tts(data: Annotated[AudioForm, Form(..., media_type="multipart/form-data")]):
+    print(data)
+    if not data.tts_text:
+        raise HTTPException(status_code=400, detail="Query parameter 'tts_text' is required")
+    
+    if data.stream:
+        # 流式输出
+        return StreamingResponse(generate_audio_stream(data), media_type="audio/pcm")
+    else:
+        # 非流式输出
+        buffer = await generate_audio_buffer(data)
+        return Response(buffer.read(), media_type="audio/wav")
+        # audio = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
+        # audio.write(buffer.read())
+        # audio.close()
+        # return FileResponse(audio.name, media_type="audio/wav")
+
# Speaker list endpoint.
@app.get("/sft_spk")
async def get_sft_spk():
    """Return the model's built-in (SFT) speaker names as a JSON array."""
    return JSONResponse(content=cosyvoice.list_available_spks())
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice2-0.5B', help='local path or modelscope repo id')
+    args = parser.parse_args()
+
+    # 初始化CosyVoice模型
+    cosyvoice = CosyVoice2(args.model_dir, load_jit=False, load_trt=False, fp16=False)
+    uvicorn.run(app, host='0.0.0.0', port=50001)