|
|
@@ -23,6 +23,15 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
|
|
+from typing import Optional
|
|
|
+from packaging.version import parse as vparse
|
|
|
+import vllm
|
|
|
+
|
|
|
+# vLLM-0.11.0+ only support V1 engine
|
|
|
+VLLM_V1_ENGINE_ONLY: bool = vparse(vllm.__version__) >= vparse("0.11.0")
|
|
|
+if VLLM_V1_ENGINE_ONLY:
|
|
|
+ from vllm.v1.sample.metadata import SamplingMetadata
|
|
|
+
|
|
|
from vllm.model_executor.models.qwen2 import *
|
|
|
|
|
|
|
|
|
@@ -87,10 +96,14 @@ class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|
|
def compute_logits(
|
|
|
self,
|
|
|
hidden_states: torch.Tensor,
|
|
|
- sampling_metadata: SamplingMetadata,
|
|
|
+ sampling_metadata: Optional[SamplingMetadata] = None,
|
|
|
) -> Optional[torch.Tensor]:
|
|
|
- logits = self.logits_processor(self.lm_head, hidden_states,
|
|
|
- sampling_metadata, self.lm_head.bias)
|
|
|
+ if VLLM_V1_ENGINE_ONLY:
|
|
|
+ logits = self.logits_processor(self.lm_head, hidden_states,
|
|
|
+ self.lm_head.bias)
|
|
|
+ else:
|
|
|
+ logits = self.logits_processor(self.lm_head, hidden_states,
|
|
|
+ sampling_metadata, self.lm_head.bias)
|
|
|
return logits
|
|
|
|
|
|
def load_weights(self, weights: Iterable[tuple[str,
|