model.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


class CosyVoiceModel:
    """Three-stage TTS pipeline: an LLM predicts speech tokens from text,
    a flow model turns those tokens into a mel spectrogram, and the HiFT
    vocoder renders the spectrogram as a waveform."""

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift

    def load(self, llm_model, flow_model, hift_model):
        # Load each sub-module's checkpoint and put it in eval mode on the target device.
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
        self.flow.to(self.device).eval()
        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
        self.hift.to(self.device).eval()

    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
        # The prompt arguments default to empty tensors, i.e. synthesis with no prompt audio/text.
        # Stage 1: the autoregressive LLM maps text (plus any prompt) to discrete speech tokens.
        tts_speech_token = self.llm.inference(text=text.to(self.device),
                                              text_len=text_len.to(self.device),
                                              prompt_text=prompt_text.to(self.device),
                                              prompt_text_len=prompt_text_len.to(self.device),
                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                              prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
                                              embedding=llm_embedding.to(self.device),
                                              beam_size=1,
                                              sampling=25,
                                              max_token_text_ratio=30,
                                              min_token_text_ratio=3)
        # Stage 2: the flow model converts speech tokens into a mel spectrogram,
        # conditioned on the speaker embedding and any prompt features.
        tts_mel = self.flow.inference(token=tts_speech_token,
                                      token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
                                      prompt_token=flow_prompt_speech_token.to(self.device),
                                      prompt_token_len=flow_prompt_speech_token_len.to(self.device),
                                      prompt_feat=prompt_speech_feat.to(self.device),
                                      prompt_feat_len=prompt_speech_feat_len.to(self.device),
                                      embedding=flow_embedding.to(self.device))
        # Stage 3: the HiFT vocoder renders the mel spectrogram as a waveform, moved to CPU.
        tts_speech = self.hift.inference(mel=tts_mel).cpu()
        torch.cuda.empty_cache()
        return {'tts_speech': tts_speech}
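
Below is a minimal usage sketch, not part of model.py itself: the _StubLLM, _StubFlow, and _StubHiFT classes are hypothetical stand-ins whose inference() signatures mirror what CosyVoiceModel calls, so the snippet runs end to end without real checkpoints. In practice the llm/flow/hift modules and their weights come from a CosyVoice model directory.

import torch

class _StubLLM(torch.nn.Module):
    def inference(self, text, text_len, prompt_text, prompt_text_len,
                  prompt_speech_token, prompt_speech_token_len, embedding,
                  beam_size, sampling, max_token_text_ratio, min_token_text_ratio):
        return torch.zeros(1, 50, dtype=torch.int32)  # stand-in speech-token ids

class _StubFlow(torch.nn.Module):
    def inference(self, token, token_len, prompt_token, prompt_token_len,
                  prompt_feat, prompt_feat_len, embedding):
        return torch.zeros(1, 80, 100)  # stand-in 80-bin mel spectrogram

class _StubHiFT(torch.nn.Module):
    def inference(self, mel):
        return torch.zeros(1, 24000)  # stand-in waveform

model = CosyVoiceModel(llm=_StubLLM(), flow=_StubFlow(), hift=_StubHiFT())
# model.load('llm.pt', 'flow.pt', 'hift.pt')  # hypothetical checkpoint paths; skipped for the stubs
text = torch.randint(0, 1000, (1, 20), dtype=torch.int32)  # dummy text-token ids
text_len = torch.tensor([20], dtype=torch.int32)
out = model.inference(text, text_len, flow_embedding=torch.zeros(1, 192))
print(out['tts_speech'].shape)  # torch.Size([1, 24000])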