cosyvoice2.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  24. """Inference-only Qwen2 model compatible with HuggingFace weights."""
  25. from vllm.model_executor.models.qwen2 import *
  26. class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
  27. packed_modules_mapping = {
  28. "qkv_proj": [
  29. "q_proj",
  30. "k_proj",
  31. "v_proj",
  32. ],
  33. "gate_up_proj": [
  34. "gate_proj",
  35. "up_proj",
  36. ],
  37. }
  38. def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
  39. super().__init__()
  40. config = vllm_config.model_config.hf_config
  41. quant_config = vllm_config.quant_config
  42. lora_config = vllm_config.lora_config
  43. self.config = config
  44. self.lora_config = lora_config
  45. self.quant_config = quant_config
  46. self.model = Qwen2Model(vllm_config=vllm_config,
  47. prefix=maybe_prefix(prefix, "model"))
  48. if get_pp_group().is_last_rank:
  49. if config.tie_word_embeddings:
  50. self.lm_head = self.model.embed_tokens
  51. else:
  52. self.lm_head = ParallelLMHead(config.vocab_size,
  53. config.hidden_size,
  54. True,
  55. quant_config=quant_config,
  56. prefix=maybe_prefix(
  57. prefix, "lm_head"))
  58. else:
  59. self.lm_head = PPMissingLayer()
  60. self.logits_processor = LogitsProcessor(config.vocab_size)
  61. self.make_empty_intermediate_tensors = (
  62. self.model.make_empty_intermediate_tensors)
  63. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  64. return self.model.get_input_embeddings(input_ids)
  65. def forward(
  66. self,
  67. input_ids: torch.Tensor,
  68. positions: torch.Tensor,
  69. intermediate_tensors: Optional[IntermediateTensors] = None,
  70. inputs_embeds: Optional[torch.Tensor] = None,
  71. ) -> Union[torch.Tensor, IntermediateTensors]:
  72. hidden_states = self.model(input_ids, positions, intermediate_tensors,
  73. inputs_embeds)
  74. return hidden_states
  75. def compute_logits(
  76. self,
  77. hidden_states: torch.Tensor,
  78. sampling_metadata: SamplingMetadata,
  79. ) -> Optional[torch.Tensor]:
  80. logits = self.logits_processor(self.lm_head, hidden_states,
  81. sampling_metadata, self.lm_head.bias)
  82. return logits
  83. def load_weights(self, weights: Iterable[tuple[str,
  84. torch.Tensor]]) -> set[str]:
  85. loader = AutoWeightsLoader(
  86. self,
  87. skip_prefixes=(["lm_head."]
  88. if self.config.tie_word_embeddings else None),
  89. )
  90. return loader.load_weights(weights)