server.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import os
import sys
# Make the directory three levels above this script (presumably the CosyVoice
# repository root -- confirm) and its third_party/Matcha-TTS copy importable,
# so the `from cosyvoice...` import below resolves when this script is run
# directly. These appends must stay before that import.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
from concurrent import futures
import argparse
# NOTE(review): presumably the protoc-generated message / service stubs that
# live next to this script -- confirm.
import cosyvoice_pb2
import cosyvoice_pb2_grpc
import logging
# Root logging is configured at DEBUG below; keep matplotlib's logger at
# WARNING so its debug output (matplotlib is presumably pulled in by the model
# code) does not flood the log.
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import grpc
import torch
import numpy as np
from cosyvoice.cli.cosyvoice import CosyVoice
# Log everything from DEBUG upward, prefixed with a timestamp and level name.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(message)s')
  31. class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
  32. def __init__(self, args):
  33. self.cosyvoice = CosyVoice(args.model_dir)
  34. logging.info('grpc service initialized')
  35. def Inference(self, request, context):
  36. if request.HasField('sft_request'):
  37. logging.info('get sft inference request')
  38. model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)
  39. elif request.HasField('zero_shot_request'):
  40. logging.info('get zero_shot inference request')
  41. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  42. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  43. model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, request.zero_shot_request.prompt_text, prompt_speech_16k)
  44. elif request.HasField('cross_lingual_request'):
  45. logging.info('get cross_lingual inference request')
  46. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  47. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  48. model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
  49. else:
  50. logging.info('get instruct inference request')
  51. model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
  52. logging.info('send inference response')
  53. for i in model_output:
  54. response = cosyvoice_pb2.Response()
  55. response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
  56. yield response
  57. def main():
  58. grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
  59. cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
  60. grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))
  61. grpcServer.start()
  62. logging.info("server listening on 0.0.0.0:{}".format(args.port))
  63. grpcServer.wait_for_termination()
  64. if __name__ == '__main__':
  65. parser = argparse.ArgumentParser()
  66. parser.add_argument('--port',
  67. type=int,
  68. default=50000)
  69. parser.add_argument('--max_conc',
  70. type=int,
  71. default=4)
  72. parser.add_argument('--model_dir',
  73. type=str,
  74. default='iic/CosyVoice-300M',
  75. help='local path or modelscope repo id')
  76. args = parser.parse_args()
  77. main()