cosyvoice2.py

# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only CosyVoice2 LLM (Qwen2 architecture) compatible with HuggingFace weights."""
from typing import Optional

from packaging.version import parse as vparse

import vllm

# vLLM 0.11.0+ supports only the V1 engine.
VLLM_V1_ENGINE_ONLY: bool = vparse(vllm.__version__) >= vparse("0.11.0")
if VLLM_V1_ENGINE_ONLY:
    from vllm.v1.sample.metadata import SamplingMetadata

from vllm.model_executor.models.qwen2 import *
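
# NOTE: the wildcard import above is what makes this file self-contained. It
# re-exports the names used below from vLLM's Qwen2 module: torch, nn, Union,
# Iterable, Qwen2Model, ParallelLMHead, LogitsProcessor, PPMissingLayer,
# AutoWeightsLoader, VllmConfig, IntermediateTensors, SupportsLoRA,
# SupportsPP, get_pp_group and maybe_prefix. On pre-0.11 (V0-capable) vLLM
# versions, the Qwen2 module also imports the legacy SamplingMetadata, so the
# annotation in compute_logits below resolves on both code paths.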


class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # Tell the weight loader (and LoRA) which checkpoint tensors are packed
    # into vLLM's fused projection modules: q/k/v into qkv_proj, gate/up
    # into gate_up_proj.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                # Unlike the stock Qwen2 head, CosyVoice2's LM head carries
                # a bias term.
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              bias=True,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: Optional[SamplingMetadata] = None,
    ) -> Optional[torch.Tensor]:
        # The V1 engine dropped the sampling_metadata argument from
        # LogitsProcessor.forward(); the LM head bias is passed explicitly
        # on both paths because the head is constructed with bias=True above.
        if VLLM_V1_ENGINE_ONLY:
            logits = self.logits_processor(self.lm_head, hidden_states,
                                           self.lm_head.bias)
        else:
            logits = self.logits_processor(self.lm_head, hidden_states,
                                           sampling_metadata,
                                           self.lm_head.bias)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        # When embeddings are tied, lm_head aliases model.embed_tokens, so
        # any lm_head.* tensors in the checkpoint must be skipped.
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
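

# ---------------------------------------------------------------------------
# Usage sketch (an assumption; the registration call may live elsewhere in
# the repo): vLLM discovers out-of-tree architectures through its
# ModelRegistry plugin hook, so a caller would register this class once
# before constructing an engine, e.g.:
#
#     from vllm import ModelRegistry
#     ModelRegistry.register_model("CosyVoice2ForCausalLM",
#                                  CosyVoice2ForCausalLM)
# ---------------------------------------------------------------------------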