server.py 4.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
  17. sys.path.append('{}/../..'.format(ROOT_DIR))
  18. sys.path.append('{}/../../third_party/AcademiCodec'.format(ROOT_DIR))
  19. sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
  20. from concurrent import futures
  21. import argparse
  22. import cosyvoice_pb2
  23. import cosyvoice_pb2_grpc
  24. import logging
  25. logging.getLogger('matplotlib').setLevel(logging.WARNING)
  26. import grpc
  27. import torch
  28. import numpy as np
  29. from cosyvoice.cli.cosyvoice import CosyVoice
# Root logger setup, executed at import time: DEBUG level and a
# "timestamp level message" record format.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(message)s')
  32. class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
  33. def __init__(self, args):
  34. self.cosyvoice = CosyVoice(args.model_dir)
  35. logging.info('grpc service initialized')
  36. def Inference(self, request, context):
  37. if request.HasField('sft_request'):
  38. logging.info('get sft inference request')
  39. model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)
  40. elif request.HasField('zero_shot_request'):
  41. logging.info('get zero_shot inference request')
  42. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  43. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  44. model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, request.zero_shot_request.prompt_text, prompt_speech_16k)
  45. elif request.HasField('cross_lingual_request'):
  46. logging.info('get cross_lingual inference request')
  47. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  48. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  49. model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
  50. else:
  51. logging.info('get instruct inference request')
  52. model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
  53. logging.info('send inference response')
  54. response = cosyvoice_pb2.Response()
  55. response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
  56. return response
  57. def main():
  58. grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
  59. cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
  60. grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))
  61. grpcServer.start()
  62. logging.info("server listening on 0.0.0.0:{}".format(args.port))
  63. grpcServer.wait_for_termination()
  64. if __name__ == '__main__':
  65. parser = argparse.ArgumentParser()
  66. parser.add_argument('--port',
  67. type=int,
  68. default=50000)
  69. parser.add_argument('--max_conc',
  70. type=int,
  71. default=4)
  72. parser.add_argument('--model_dir',
  73. type=str,
  74. required=True,
  75. default='speech_tts/CosyVoice-300M',
  76. help='local path or modelscope repo id')
  77. args = parser.parse_args()
  78. main()