# server.py
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import sys
  16. from concurrent import futures
  17. import argparse
  18. import cosyvoice_pb2
  19. import cosyvoice_pb2_grpc
  20. import logging
  21. logging.getLogger('matplotlib').setLevel(logging.WARNING)
  22. import grpc
  23. import torch
  24. import numpy as np
  25. ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
  26. sys.path.append('{}/../../..'.format(ROOT_DIR))
  27. sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
  28. from cosyvoice.cli.cosyvoice import CosyVoice
  29. logging.basicConfig(level=logging.DEBUG,
  30. format='%(asctime)s %(levelname)s %(message)s')
  31. class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
  32. def __init__(self, args):
  33. self.cosyvoice = CosyVoice(args.model_dir)
  34. logging.info('grpc service initialized')
  35. def Inference(self, request, context):
  36. if request.HasField('sft_request'):
  37. logging.info('get sft inference request')
  38. model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)
  39. elif request.HasField('zero_shot_request'):
  40. logging.info('get zero_shot inference request')
  41. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  42. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  43. model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,
  44. request.zero_shot_request.prompt_text,
  45. prompt_speech_16k)
  46. elif request.HasField('cross_lingual_request'):
  47. logging.info('get cross_lingual inference request')
  48. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  49. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  50. model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
  51. else:
  52. logging.info('get instruct inference request')
  53. model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
  54. request.instruct_request.spk_id,
  55. request.instruct_request.instruct_text)
  56. logging.info('send inference response')
  57. for i in model_output:
  58. response = cosyvoice_pb2.Response()
  59. response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
  60. yield response
  61. def main():
  62. grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
  63. cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
  64. grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))
  65. grpcServer.start()
  66. logging.info("server listening on 0.0.0.0:{}".format(args.port))
  67. grpcServer.wait_for_termination()
  68. if __name__ == '__main__':
  69. parser = argparse.ArgumentParser()
  70. parser.add_argument('--port',
  71. type=int,
  72. default=50000)
  73. parser.add_argument('--max_conc',
  74. type=int,
  75. default=4)
  76. parser.add_argument('--model_dir',
  77. type=str,
  78. default='iic/CosyVoice-300M',
  79. help='local path or modelscope repo id')
  80. args = parser.parse_args()
  81. main()