server.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. from concurrent import futures
  17. import argparse
  18. import cosyvoice_pb2
  19. import cosyvoice_pb2_grpc
  20. import logging
  21. logging.getLogger('matplotlib').setLevel(logging.WARNING)
  22. import grpc
  23. import torch
  24. import numpy as np
# Absolute directory of this script; anchors the relative paths below.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# The two appends must run before the ``cosyvoice`` import that follows:
# they add the directory three levels up (presumably the repository root --
# confirm against the checkout layout) and its bundled Matcha-TTS copy to the
# module search path.
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
# Root logger: DEBUG level, "<time> <level> <message>" line format.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(message)s')
  31. class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
  32. def __init__(self, args):
  33. try:
  34. self.cosyvoice = CosyVoice(args.model_dir)
  35. except Exception:
  36. try:
  37. self.cosyvoice = CosyVoice2(args.model_dir)
  38. except Exception:
  39. raise TypeError('no valid model_type!')
  40. logging.info('grpc service initialized')
  41. def Inference(self, request, context):
  42. if request.HasField('sft_request'):
  43. logging.info('get sft inference request')
  44. model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)
  45. elif request.HasField('zero_shot_request'):
  46. logging.info('get zero_shot inference request')
  47. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  48. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  49. model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,
  50. request.zero_shot_request.prompt_text,
  51. prompt_speech_16k)
  52. elif request.HasField('cross_lingual_request'):
  53. logging.info('get cross_lingual inference request')
  54. prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
  55. prompt_speech_16k = prompt_speech_16k.float() / (2**15)
  56. model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
  57. else:
  58. logging.info('get instruct inference request')
  59. model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
  60. request.instruct_request.spk_id,
  61. request.instruct_request.instruct_text)
  62. logging.info('send inference response')
  63. for i in model_output:
  64. response = cosyvoice_pb2.Response()
  65. response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
  66. yield response
  67. def main():
  68. grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
  69. cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
  70. grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))
  71. grpcServer.start()
  72. logging.info("server listening on 0.0.0.0:{}".format(args.port))
  73. grpcServer.wait_for_termination()
  74. if __name__ == '__main__':
  75. parser = argparse.ArgumentParser()
  76. parser.add_argument('--port',
  77. type=int,
  78. default=50000)
  79. parser.add_argument('--max_conc',
  80. type=int,
  81. default=4)
  82. parser.add_argument('--model_dir',
  83. type=str,
  84. default='iic/CosyVoice-300M',
  85. help='local path or modelscope repo id')
  86. args = parser.parse_args()
  87. main()