diff --git a/tests/model_executor/test_weight_utils.py b/tests/model_executor/test_weight_utils.py index 260ebdcefb3b..9e67609b78e4 100644 --- a/tests/model_executor/test_weight_utils.py +++ b/tests/model_executor/test_weight_utils.py @@ -160,5 +160,126 @@ def test_missing_target_returns_none(self): assert result is None +class TestKvCacheScaleMapper: + """The `WeightsMapper` returned by `get_cache_scale_mapper` replaces the + per-model `maybe_remap_kv_scale_name` calls. It must remap the same set of + checkpoint formats (the non-`params_dict`-dependent ones) and be idempotent + so it composes safely with a model's own qkv/gate_up `hf_to_vllm_mapper`.""" + + def _mapper(self): + # `get_cache_scale_mapper` does not use `self`; call it on the base + # class to get the default (non-config-specific) mapper. + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + + return QuantizationConfig.get_cache_scale_mapper() + + def _map(self, name: str) -> str | None: + return self._mapper()._map_name(name) + + @pytest.mark.parametrize( + "name,expected", + [ + # Qwen3-MoE / llm-compressor fused qkv_proj + ( + "model.layers.0.self_attn.qkv_proj.k_scale", + "model.layers.0.self_attn.attn.k_scale", + ), + ( + "model.layers.0.self_attn.qkv_proj.v_scale", + "model.layers.0.self_attn.attn.v_scale", + ), + # ModelOpt / NVFP4 k_proj/v_proj + ( + "model.layers.0.self_attn.k_proj.k_scale", + "model.layers.0.self_attn.attn.k_scale", + ), + ( + "model.layers.0.self_attn.v_proj.v_scale", + "model.layers.0.self_attn.attn.v_scale", + ), + # deprecated fused kv_scale and bare scales + ( + "model.layers.0.self_attn.kv_scale", + "model.layers.0.self_attn.attn.k_scale", + ), + ( + "model.layers.0.self_attn.k_scale", + "model.layers.0.self_attn.attn.k_scale", + ), + # NemotronH mixer + ( + "model.layers.0.mixer.k_proj.k_scale", + "model.layers.0.mixer.attn.k_scale", + ), + # already in vLLM form -> unchanged (idempotent) + ( + "model.layers.0.self_attn.attn.k_scale", + "model.layers.0.self_attn.attn.k_scale", + ), + # non-kv scales must not be touched + ( + "model.layers.0.self_attn.k_proj.weight_scale", + "model.layers.0.self_attn.k_proj.weight_scale", + ), + ( + "model.layers.0.self_attn.k_proj.input_scale", + "model.layers.0.self_attn.k_proj.input_scale", + ), + # regular weights untouched + ( + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.q_proj.weight", + ), + ], + ) + def test_remap(self, name, expected): + assert self._map(name) == expected + + @pytest.mark.parametrize( + "name", + [ + "model.layers.0.self_attn.k_scale", + "model.layers.0.self_attn.k_proj.k_scale", + "model.layers.0.self_attn.qkv_proj.v_scale", + "model.layers.0.mixer.k_proj.k_scale", + ], + ) + def test_idempotent(self, name): + once = self._map(name) + assert once is not None + assert self._map(once) == once + + def test_composes_with_qkv_mapper(self): + """Applied together with a model's qkv/gate_up mapper, the regex scale + rules run before the substr rename, so scales are normalized to `.attn.` + and regular projections are still fused correctly.""" + from vllm.model_executor.models.utils import WeightsMapper + + model_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + } + ) + # AutoWeightsLoader does `mapper |= cache_scale_mapper` + combined = model_mapper | self._mapper() + + assert ( + combined._map_name("model.layers.0.self_attn.q_proj.weight") + == "model.layers.0.self_attn.qkv_proj.q.weight" + ) + assert ( + combined._map_name("model.layers.0.self_attn.k_proj.k_scale") + == "model.layers.0.self_attn.attn.k_scale" + ) + assert ( + combined._map_name("model.layers.0.self_attn.k_scale") + == "model.layers.0.self_attn.attn.k_scale" + ) + + if __name__ == "__main__": test_download_weights_from_hf() diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 166d5c36ba57..785df09fe400 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -122,9 +122,13 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: peft_helper.validate_legal(self.lora_config) # For some models like Qwen2VL, we need to use hf_to_vllm_mapper - # to ensure correct loading of lora weights. + # to ensure correct loading of lora weights. Drop the QKV/MLP fusion + # substr maps so constituent names (e.g. `q_proj`) survive for the + # LoRA manager to pack, while keeping genuine renames/prefixes. model = self._adapter_manager.model hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None) + if hf_to_vllm_mapper is not None: + hf_to_vllm_mapper = hf_to_vllm_mapper.get_unfused_mapper() # Get model-defined prefixes to skip during LoRA loading. lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f7f9fe4c3dbe..0ad5702f35df 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -3,6 +3,7 @@ import itertools from abc import abstractmethod +from collections.abc import Iterable import torch from torch.nn.parameter import Parameter @@ -910,6 +911,41 @@ def weight_loader_v2( tp_rank=self.tp_rank, ) + def load_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[str]: + for name, loaded_weight in weights: + if "." in name: + # Checkpoint is sharded + shard_id_str, _, name = name.partition(".") + shard_id = int(shard_id_str) + logger.debug( + "Loaded shard %s into %s for layer %s.%s", + shard_id, + name, + self.prefix, + name, + ) + else: + shard_id = None + logger.debug( + "Loaded weight %s.%s with shape %s", + self.prefix, + name, + loaded_weight.shape, + ) + # Load into self if name is not an attr of self + param: Parameter = getattr(self, name, self) + if ( + param is None + and name == "bias" + and self.quant_config is not None + and "gptq" in self.quant_config.get_name() + ): + continue + param.weight_loader(param, loaded_weight, shard_id) + yield name + class QKVParallelLinear(ColumnParallelLinear): """Linear layers for the attention's QKV transformation. @@ -1301,6 +1337,42 @@ def weight_loader( assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + def load_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[str]: + for name, loaded_weight in weights: + if "." in name: + # Checkpoint is sharded + shard_id, _, name = name.partition(".") + self.validate_shard_id(shard_id) + logger.debug( + "Loaded shard %s into %s for layer %s.%s", + shard_id, + name, + self.prefix, + name, + ) + else: + # Checkpoint is fused + shard_id = None + logger.debug( + "Loaded weight %s.%s with shape %s", + self.prefix, + name, + loaded_weight.shape, + ) + # Load into self if name is not an attr of self + param: Parameter = getattr(self, name, self) + if ( + param is None + and name == "bias" + and self.quant_config is not None + and "gptq" in self.quant_config.get_name() + ): + continue + param.weight_loader(param, loaded_weight, shard_id) + yield name + # --8<-- [start:row_parallel_linear] @PluggableLayer.register("row_parallel_linear") diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 7bc5d16be738..2bb5d29f144d 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +import regex as re import torch from torch import nn from transformers import PretrainedConfig @@ -19,10 +20,12 @@ class QuantizeMethodBase(ABC): """Base class for different quantized methods.""" - # Whether this method creates weights on meta device for online quantization. - # When True, weights are created on meta device and quantized layer-wise - # in process_weights_after_loading, reducing peak memory during loading. uses_meta_device: bool = False + """ + Whether this method creates weights on meta device for online quantization. + When True, weights are created on meta device and quantized layer-wise + in process_weights_after_loading, reducing peak memory during loading. + """ @abstractmethod def create_weights( @@ -77,6 +80,18 @@ def method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> class QuantizationConfig(ABC): """Base class for quantization configs.""" + _ignore_unexpected_suffixes = ( + ".q_scale", + ".k_scale", + ".v_scale", + ".q_zero_point", + ".k_zero_point", + ".v_zero_point", + ) + """Suffixes of quantization parameters that may be present in the checkpoint but + not in the model, and should be ignored if unexpected during loading. These are used + after remapping, so should be in vLLM format (e.g. .q_scale, not .q.scale).""" + def __init__(self): super().__init__() # mapping is updated by models as they initialize @@ -169,14 +184,40 @@ def get_quant_method( """ raise NotImplementedError - def get_cache_scale_mapper(self) -> "WeightsMapper | None": + @staticmethod + def get_cache_scale_mapper() -> "WeightsMapper": """Mapping from checkpoint KV-cache scale names to vLLM scale names. Returning a mapper here causes `AutoWeightsLoader` to apply it to the weight stream automatically; individual model `load_weights` methods do not need to know about KV-cache scales. """ - return None + from vllm.model_executor.models.utils import WeightsMapper + + orig_to_new_regex = { + # Deprecated fused kv_scale -> attn.k_scale + re.compile(r"\.kv_scale$"): r".attn.k_scale", + # ModelOpt: .self_attn.{k,v}_proj.{k,v}_scale -> .self_attn.attn.* + re.compile(r"\.self_attn\.[kv]_proj\.([kv])_scale$"): ( + r".self_attn.attn.\1_scale" + ), + # Fused QKV / qkqkv proj: .self_attn.qk(qk)v_proj.{k,v}_scale -> attn + re.compile(r"\.self_attn\.qk(?:qk)?v_proj\.([kv])_scale$"): ( + r".self_attn.attn.\1_scale" + ), + # NemotronH: .mixer.{k,v}_proj.{k,v}_scale -> .mixer.attn.* + re.compile(r"\.mixer\.[kv]_proj\.([kv])_scale$"): r".mixer.attn.\1_scale", + # HYV3: .self_attn.q.scale -> .self_attn.attn.q_scale + re.compile(r"\.self_attn\.q\.scale$"): r".self_attn.attn.q_scale", + # HYV3: .self_attn.{k,v}_cache.scale -> .self_attn.attn.{k,v}_scale + re.compile(r"\.self_attn\.([kv])_cache\.scale$"): ( + r".self_attn.attn.\1_scale" + ), + # Default: .{q,k,v}_scale -> .attn.{q,k,v}_scale (unless already .attn) + re.compile(r"(? "WeightsMapper": + @staticmethod + def get_cache_scale_mapper() -> "WeightsMapper": """Map compressed-tensors KV-cache scale names to vLLM names.""" from vllm.model_executor.models.utils import WeightsMapper - return WeightsMapper( - orig_to_new_suffix={ - ".k_proj.output_scale": ".attn.k_scale", - ".v_proj.output_scale": ".attn.v_scale", - ".q_proj.output_scale": ".attn.q_scale", - ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale", - } - ) + orig_to_new_suffix = { + ".k_proj.output_scale": ".attn.k_scale", + ".v_proj.output_scale": ".attn.v_scale", + ".q_proj.output_scale": ".attn.q_scale", + ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale", + } + cache_scale_mapper = WeightsMapper(orig_to_new_suffix=orig_to_new_suffix) + return cache_scale_mapper | QuantizationConfig.get_cache_scale_mapper() class CopyNumelCounter(TorchDispatchMode): diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 9051214cf9da..fbd61e28cd2a 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -679,16 +679,17 @@ def get_scheme( return scheme - def get_cache_scale_mapper(self) -> "WeightsMapper": + @staticmethod + def get_cache_scale_mapper() -> "WeightsMapper": """Map Quark KV-cache scale names to vLLM names.""" - return WeightsMapper( - orig_to_new_suffix={ - ".k_proj.output_scale": ".attn.k_scale", - ".v_proj.output_scale": ".attn.v_scale", - ".q_proj.output_scale": ".attn.q_scale", - ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale", - } - ) + orig_to_new_suffix = { + ".k_proj.output_scale": ".attn.k_scale", + ".v_proj.output_scale": ".attn.v_scale", + ".q_proj.output_scale": ".attn.q_scale", + ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale", + } + cache_scale_mapper = WeightsMapper(orig_to_new_suffix=orig_to_new_suffix) + return cache_scale_mapper | QuantizationConfig.get_cache_scale_mapper() class QuarkLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 064a74023a29..87a310d1fb23 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -573,10 +573,11 @@ def _initialize_loader_state( "BitsAndBytes quantization yet. Ensure this model has " "'get_expert_mapping' method." ) - # For some models like Molmo, we need to use hf_to_vllm_mapper - # to ensure correct loading of weights. - if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): - self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) + # `hf_to_vllm_mapper` may belong to model or base model + for module in (model, *model.children()): + if hf_to_vllm_mapper := getattr(module, "hf_to_vllm_mapper", None): + self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) + break self._get_bnb_target_modules(model) self._classify_module_sharding(model) diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py index 6cf1c19cba43..d0d26fed3e6c 100644 --- a/vllm/model_executor/model_loader/reload/layerwise.py +++ b/vllm/model_executor/model_loader/reload/layerwise.py @@ -131,8 +131,11 @@ def initialize_online_processing(layer: torch.nn.Module): # Track loading progress to determine when to process/copy info.load_numel = 0 info.load_numel_total = get_layer_size(layer) + _wrap_parameters_weight_loader(layer) - # Wrap each parameter's weight loader + +def _wrap_parameters_weight_loader(layer: torch.nn.Module) -> None: + """Wrap each parameter's weight loader.""" # Note that nested wrapping will occur for shared tensors for name, tensor in get_layer_tensors(layer).items(): if name in SKIP_TENSORS: @@ -168,6 +171,12 @@ def online_process_loader(*args, **kwargs): logger.debug("%s: Excessive loading", layer.__class__.__name__) return + # Re-run on each load: layers may register parameters later (e.g., `bias`). + # Wrap late parameters and refresh `load_numel_total` so processing waits + # until all parameters are loaded. + info.load_numel_total = get_layer_size(layer) + _wrap_parameters_weight_loader(layer) + # Bind and normalize arguments bound_args = loader_signature.bind(*args, **kwargs) bound_args.apply_defaults() diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index fc279c7e9c78..b4b4db11ed3a 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -290,6 +290,6 @@ def configure_quant_config( # pass mappings by reference to quant_config if hf_to_vllm_mapper is not None: - quant_config.apply_vllm_mapper(hf_to_vllm_mapper) + quant_config.apply_vllm_mapper(hf_to_vllm_mapper.get_unfused_mapper()) if packed_mapping is not None: quant_config.packed_modules_mapping = packed_mapping diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index d25c954fc19e..c004fe793d0e 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -26,10 +26,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from .interfaces import ( @@ -42,7 +38,7 @@ from .utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -276,67 +272,6 @@ def forward( return hidden_states, aux_hidden_states return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - """Load weights, mapping q/k/v projections to fused qkv_proj.""" - stacked_params_mapping = [ - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - continue - - if "scale" in name or "zero_point" in name: - remapped_name = maybe_remap_kv_scale_name(name, params_dict) - if remapped_name is None: - continue - name = remapped_name - - mapped = False - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - - name = name.replace(weight_name, param_name) - - if name.endswith(".bias") and name not in params_dict: - mapped = True - break - - if is_pp_missing_parameter(name, self): - mapped = True - break - - param = params_dict[name] - weight_loader = param.weight_loader # type: ignore[attr-defined] - weight_loader(param, loaded_weight, shard_id) - loaded_params.add(name) - mapped = True - break - - if mapped: - continue - - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - - return loaded_params - class ArceeForCausalLM( nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3 @@ -344,6 +279,13 @@ class ArceeForCausalLM( """Arcee Model for causal language modeling, integrated with vLLM runtime.""" + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + } + ) # Map fused module names to their submodule components # (for quantization and LoRA) packed_modules_mapping = { @@ -420,4 +362,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) # AutoWeightLoader handles weight name remapping, including fusing # separate q_proj, k_proj, v_proj into qkv_proj - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index bc1cd2ed811b..d29b72733549 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -60,7 +60,7 @@ from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -342,51 +342,17 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { "W_pack": ["W_pack"], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__( @@ -447,7 +413,7 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def lm_head_weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): # Unlike Baichuan, Baichuan2 normalizes the head weights. diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index c5d857e7c3df..4363188ff6e1 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -30,7 +30,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.chatglm import ChatGLMConfig @@ -38,7 +37,6 @@ from .utils import ( AutoWeightsLoader, WeightsMapper, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -316,12 +314,9 @@ def forward( @support_torch_compile class ChatGLMModel(nn.Module, SupportsQuant): - packed_modules_mapping = { - "linear_proj.merged_proj": [ - "linear_proj.gate_proj", - "linear_proj.dense_h_to_4h", - ] - } + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".word_embeddings": ""}, + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -386,47 +381,11 @@ def forward( return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("linear_proj.merged_proj", "linear_proj.gate_proj", 0), - ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - if "rotary_pos_emb.inv_freq" in name: - continue - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) class ChatGLMBaseModel(nn.Module): - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_substr={".word_embeddings": ""}, - ) - def __init__( self, *, @@ -467,7 +426,7 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + return loader.load_weights(weights) class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsQuant): diff --git a/vllm/model_executor/models/cohere_eagle.py b/vllm/model_executor/models/cohere_eagle.py index 7b57c739ffe9..64ec0d6dd544 100644 --- a/vllm/model_executor/models/cohere_eagle.py +++ b/vllm/model_executor/models/cohere_eagle.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.commandr import ( CohereDecoderLayer, CohereForCausalLM, @@ -134,42 +133,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class EagleCohereForCausalLM(CohereForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -225,7 +188,9 @@ def _track_and_forward(inputs): ), ) - loaded_weight_names = loader.load_weights(map(_track_and_forward, weights)) + loaded_weight_names = loader.load_weights( + map(_track_and_forward, weights), mapper=self.hf_to_vllm_mapper + ) # Embed tokens are tied with the target model and therefore not # present in the EAGLE checkpoint; mark them as loaded explicitly to diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 66adb9a3ca77..2880a2c22103 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -45,8 +45,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, row_parallel_weight_loader, ) from vllm.model_executor.utils import set_weight_attrs @@ -58,7 +56,6 @@ AutoWeightsLoader, WeightsMapper, extract_layer_index, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -341,60 +338,20 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, shard_name, shard_id in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } # LoRA specific attributes embedding_modules = {"embed_tokens": "input_embeddings"} diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 7796c3da3314..6ef94b099a29 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -50,17 +50,13 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -370,70 +366,20 @@ def forward( hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".c_fc_0", 0), - (".gate_up_proj", ".c_fc_1", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".c_fc_0": ".gate_up_proj.0", + ".c_fc_1": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "c_fc_0", - "c_fc_1", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["c_fc_0", "c_fc_1"], } # LoRA specific attributes @@ -506,4 +452,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # processed with quantization, LoRA, fine-tuning, etc. skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index cc1dcf197f72..7927eea6ac84 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -46,10 +46,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import set_default_rope_theta @@ -57,8 +53,8 @@ from .utils import ( AutoWeightsLoader, PPMissingLayer, + WeightsMapper, extract_layer_index, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -368,70 +364,20 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } # LoRA specific attributes @@ -503,4 +449,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # processed with quantization, LoRA, fine-tuning, etc. skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/fairseq2_llama.py b/vllm/model_executor/models/fairseq2_llama.py index ca0e7e64df53..e898034fbfa5 100644 --- a/vllm/model_executor/models/fairseq2_llama.py +++ b/vllm/model_executor/models/fairseq2_llama.py @@ -79,10 +79,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights( - ( - self.reshape_fairseq2_weights(name, loaded_weight, params) - for name, loaded_weight in weights - ) + self.reshape_fairseq2_weights(name, loaded_weight, params) + for name, loaded_weight in weights ) def flag_sharded_weights(self, params: dict[str, Parameter]): diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 6e35020a6eac..8c9f85d84e36 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -42,13 +42,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -324,56 +323,20 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, shard_name, shard_id in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - - return loaded_params - -class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -421,4 +384,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 733eb3ed3c19..334a5603c7fc 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -39,17 +39,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, + WeightsMapper, extract_layer_index, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -316,60 +312,20 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, shard_name, shard_id in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - - return loaded_params - class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -418,4 +374,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index 4587a6927663..3a25f90ad2a0 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -39,10 +39,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from vllm.v1.attention.backend import AttentionType @@ -52,7 +48,6 @@ from .utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, maybe_prefix, ) @@ -237,73 +232,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config=vllm_config, prefix=prefix, layer_type=Glm4DecoderLayer ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) - if spec_layer is not None: - continue - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if "scale" in name or "zero_point" in name: - # Remapping the name of FP8 kv-scale or zero point. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -360,10 +293,15 @@ def compute_logits( return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), - ) + skip_prefixes = ["lm_head."] if self.config.tie_word_embeddings else [] + # Skip the speculative (MTP) layers, which are loaded by the + # draft model instead. + num_nextn_layers = getattr(self.config, "num_nextn_predict_layers", 0) + skip_prefixes += [ + f"model.layers.{self.config.num_hidden_layers + i}." + for i in range(num_nextn_layers) + ] + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 9d08df4df8dc..8b9a8f088930 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -61,6 +61,7 @@ SupportsMultiModal, SupportsPP, ) +from .utils import WeightsMapper class GLMVImagePixelInputs(TensorSchema): @@ -376,6 +377,14 @@ def forward(self, images: torch.Tensor) -> torch.Tensor: class GLM4VModel(ChatGLMModel): + hf_to_vllm_mapper = ChatGLMModel.hf_to_vllm_mapper | WeightsMapper( + orig_to_new_substr={ + # Vision GLU projections + "linear_proj.gate_proj": "linear_proj.merged_proj.0", + "linear_proj.dense_h_to_4h": "linear_proj.merged_proj.1", + } + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 30da9b4dea23..12c90a53d7ff 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -43,16 +43,12 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -239,51 +235,16 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "attn.bias" in name or "attn.masked_bias" in name: - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class GPTJForCausalLM(nn.Module, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + } + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -329,5 +290,5 @@ def compute_logits( return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + loader = AutoWeightsLoader(self, skip_substrs=["attn.bias", "attn.masked_bias"]) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 7470e7e7381c..e520b17c3b16 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -49,17 +49,13 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant from .utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_layers, maybe_prefix, ) @@ -251,7 +247,17 @@ def forward( @support_torch_compile -class GraniteModel(nn.Module): +class GraniteModel(nn.Module, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -322,66 +328,16 @@ def forward( return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + # LoRA specific attributes packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } - - # LoRA specific attributes embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/hyperclovax.py b/vllm/model_executor/models/hyperclovax.py index 2f54f78e7580..7a531ffce1e6 100644 --- a/vllm/model_executor/models/hyperclovax.py +++ b/vllm/model_executor/models/hyperclovax.py @@ -50,10 +50,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.hyperclovax import HyperCLOVAXConfig @@ -61,7 +57,7 @@ from .utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -377,71 +373,20 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if "scale" in name or "zero_point" in name: - # Remapping the name of FP8 kv-scale or zero point. - remapped_name = maybe_remap_kv_scale_name(name, params_dict) - if remapped_name is None: - continue - name = remapped_name - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader # type: ignore[attr-defined] - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class HyperCLOVAXForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } # LoRA specific attributes @@ -536,4 +481,4 @@ def load_weights( self, skip_prefixes=["lm_head."] if self.config.tie_word_embeddings else None, ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 68dbcf90f877..26356fce91cd 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1031,7 +1031,7 @@ def _maybe_apply_model_mapping(self): if self.quant_config is None: return if (hf_to_vllm_mapper := self.hf_to_vllm_mapper) is not None: - self.quant_config.apply_vllm_mapper(hf_to_vllm_mapper) + self.quant_config.apply_vllm_mapper(hf_to_vllm_mapper.get_unfused_mapper()) if self.packed_modules_mapping is not None: self.quant_config.packed_modules_mapping.update(self.packed_modules_mapping) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 6b1712ede320..743357e09d62 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -35,15 +35,14 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant from .interfaces_base import default_pooling_type from .utils import ( AutoWeightsLoader, StageMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -248,7 +247,14 @@ def forward( @support_torch_compile -class InternLM2Model(nn.Module): +class InternLM2Model(nn.Module, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".w1": ".gate_up_proj.0", + ".w3": ".gate_up_proj.1", + } + ) + def __init__( self, *, @@ -310,40 +316,8 @@ def forward( return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "w1", 0), - ("gate_up_proj", "w3", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): diff --git a/vllm/model_executor/models/jais2.py b/vllm/model_executor/models/jais2.py index 325d52492898..23e0f640e39f 100644 --- a/vllm/model_executor/models/jais2.py +++ b/vllm/model_executor/models/jais2.py @@ -51,18 +51,14 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, + WeightsMapper, extract_layer_index, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -366,61 +362,15 @@ def forward( hidden_states, _ = self.norm(hidden_states + residual), residual return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if "scale" in name: - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + } + ) packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], } @@ -490,4 +440,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/jina.py b/vllm/model_executor/models/jina.py index 2b07937df08e..82a534404027 100644 --- a/vllm/model_executor/models/jina.py +++ b/vllm/model_executor/models/jina.py @@ -254,5 +254,6 @@ def _merge_weights( tensor = tensor + (lora_B @ lora_A) * scaling yield name, tensor - loaded = self.model.load_weights(_merge_weights(weights)) - return {f"model.{name}" for name in loaded} + loader = AutoWeightsLoader(self.model, ignore_unexpected_prefixes=["lm_head."]) + weights = _merge_weights(weights) + return loader.load_weights(weights, mapper=self.model.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a54801e64585..a512751db41d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -52,10 +52,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from vllm.v1.attention.backend import AttentionType @@ -67,12 +63,13 @@ SupportsEagle3, SupportsLoRA, SupportsPP, + SupportsQuant, ) from .utils import ( AutoWeightsLoader, PPMissingLayer, + WeightsMapper, extract_layer_index, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -344,7 +341,17 @@ def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None "inputs_embeds": {0: "b"}, }, ) -class LlamaModel(nn.Module, EagleModelMixin): +class LlamaModel(nn.Module, EagleModelMixin, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) + def __init__( self, *, @@ -431,67 +438,18 @@ def forward( return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if "scale" in name or "zero_point" in name: - # Remapping the name of FP8 kv-scale or zero point. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) class LlamaForCausalLM( LocalArgmaxMixin, nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3 ): + # LoRA specific attributes packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], } - - # LoRA specific attributes embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ec2a7255eb66..6a77a58abf4d 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -26,7 +26,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( HasInnerState, IsAttentionFree, @@ -37,7 +36,7 @@ from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -170,28 +169,12 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "A_log" in name: - name = name.replace("A_log", "A") - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class MambaForCausalLM( nn.Module, HasInnerState, IsAttentionFree, SupportsPP, SupportsMambaPrefixCaching ): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={".A_log": ".A"}) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -279,4 +262,4 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index deb20852a26a..343111ee0151 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -25,7 +25,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( HasInnerState, IsAttentionFree, @@ -35,7 +34,7 @@ from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -167,29 +166,12 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "A_log" in name: - name = name.replace("A_log", "A") - - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class Mamba2ForCausalLM( nn.Module, HasInnerState, IsAttentionFree, SupportsMambaPrefixCaching ): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={".A_log": ".A"}) + @classmethod def get_mamba_state_dtype_from_config( cls, @@ -292,4 +274,4 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index 4f67d468ace5..e4247fa8d8df 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -38,14 +38,10 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model from vllm.sequence import IntermediateTensors -from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix +from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix logger = init_logger(__name__) @@ -89,50 +85,6 @@ def forward( hidden_states = hidden_states + residual return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "mtp_layers" in name: - continue - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -167,6 +119,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model.make_empty_intermediate_tensors ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + skip_prefixes = ["lm_head."] if self.config.tie_word_embeddings else [] + # MTP layers are loaded by the draft model, not the main model. + skip_prefixes.append("model.mtp_layers.") + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights) + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mistral_eagle.py b/vllm/model_executor/models/mistral_eagle.py index 8865742d6495..75d1ebb91a80 100644 --- a/vllm/model_executor/models/mistral_eagle.py +++ b/vllm/model_executor/models/mistral_eagle.py @@ -108,11 +108,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # Pretend embed_tokens is loaded; the actual weight is shared - # from the target model at runtime by `load_eagle_model`. - return super().load_weights(weights) | {"embed_tokens.weight"} - class EagleMistralForCausalLM(MistralForCausalLM): mistral_mapping = MistralForCausalLM.mistral_mapping | { @@ -166,3 +161,8 @@ def embed_input_ids( multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Pretend embed_tokens is loaded; the actual weight is shared + # from the target model at runtime by `load_eagle_model`. + return super().load_weights(weights) | {"model.embed_tokens.weight"} diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 85933626cd30..8e509fbcb4c6 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -27,13 +27,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -274,21 +272,6 @@ def forward( hidden_states = self.norm_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class MPTForCausalLM(nn.Module, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index f5c526e33eda..da33584bb104 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -47,10 +47,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.nemotron import NemotronConfig @@ -58,7 +54,7 @@ from .utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -364,58 +360,17 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], } # LoRA specific attributes @@ -483,4 +438,4 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index 06a2096ec697..04044f6477ba 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -42,10 +42,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.model_executor.models.llama import LlamaAttention, LlamaMLP from vllm.sequence import IntermediateTensors from vllm.v1.attention.backend import AttentionType @@ -54,7 +50,7 @@ from .utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -315,60 +311,17 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if "scale" in name or "zero_point" in name: - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -462,4 +415,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 541f60c2c406..e9eaad16399b 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -48,13 +48,12 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -301,59 +300,24 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. """ + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -410,4 +374,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ["lm_head.weight"] if self.config.tie_word_embeddings else None ), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index ad04b258bdeb..e85541115e00 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -52,12 +52,11 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( AutoWeightsLoader, + WeightsMapper, extract_layer_index, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -343,58 +342,24 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if is_pp_missing_parameter(name, self): - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader # type: ignore - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. """ + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -451,4 +416,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ["lm_head.weight"] if self.config.tie_word_embeddings else None ), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 81653b9516ac..8669688aa74d 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -44,14 +44,12 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, WeightsMapper, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -325,52 +323,21 @@ def forward( input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA): - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - } - hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + }, orig_to_new_prefix={ "decoder.": "model.decoder.", - } + }, ) + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 3cacb9d61cd5..52addf4cef97 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -32,13 +32,12 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -277,45 +276,18 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class OrionForCausalLM(nn.Module, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -362,4 +334,4 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ouro.py b/vllm/model_executor/models/ouro.py index 503d4b5c8343..aacce6300399 100644 --- a/vllm/model_executor/models/ouro.py +++ b/vllm/model_executor/models/ouro.py @@ -51,16 +51,13 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from vllm.v1.attention.backend import AttentionType from .interfaces import SupportsLoRA from .utils import ( AutoWeightsLoader, + WeightsMapper, extract_layer_index, make_empty_intermediate_tensors_factory, make_layers, @@ -376,65 +373,20 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - if weight_loader == default_weight_loader: - weight_loader(param, loaded_weight) - else: - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class OuroForCausalLM(nn.Module, SupportsLoRA): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -492,4 +444,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 75c42c0d3930..e1c0ed0b1625 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -62,13 +62,12 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -257,55 +256,17 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # pylint: disable=E1136 - - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ] + "qkv_proj": ["q_proj", "k_proj", "v_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -360,4 +321,4 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 9c39c6497082..7f76ffd1f6a1 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -54,10 +54,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import is_interleaved, set_default_rope_theta from vllm.v1.attention.backend import AttentionType @@ -68,12 +64,13 @@ SupportsEagle3, SupportsLoRA, SupportsPP, + SupportsQuant, ) from .utils import ( AutoWeightsLoader, PPMissingLayer, + WeightsMapper, extract_layer_index, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -322,7 +319,17 @@ def forward( "inputs_embeds": {0: "b"}, } ) -class Qwen2Model(nn.Module, EagleModelMixin): +class Qwen2Model(nn.Module, EagleModelMixin, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) + def __init__( self, *, @@ -426,72 +433,16 @@ def forward( return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - if weight_loader == default_weight_loader: - weight_loader(param, loaded_weight) - else: - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) class Qwen2ForCausalLM( nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3 ): packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index cdf1a327efe5..47184173d5a2 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -29,15 +29,8 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): pooler: Pooler packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index b070eac32551..3c9517ec9b1c 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -268,17 +268,9 @@ class Qwen3ForCausalLM( LocalArgmaxMixin, nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3 ): packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } - embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/rnj1.py b/vllm/model_executor/models/rnj1.py index 68c3722e2bc1..1bea77c87935 100644 --- a/vllm/model_executor/models/rnj1.py +++ b/vllm/model_executor/models/rnj1.py @@ -30,18 +30,14 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from vllm.v1.attention.backend import AttentionType from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, + WeightsMapper, extract_layer_index, - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -331,75 +327,20 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if ( - self.quant_config - and self.quant_config.get_name() == "gguf" - and name.endswith("norm.weight") - ): - loaded_weight -= 1 - - if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): - remapped_name = maybe_remap_kv_scale_name(name, params_dict) - if remapped_name is not None and remapped_name in params_dict: - param = params_dict[remapped_name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) - loaded_params.add(remapped_name) - continue - - for param_name, shard_name, shard_id in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - if name.endswith(".bias") and name not in params_dict: - continue - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - - return loaded_params - class Rnj1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -457,4 +398,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index 48147f7334e8..68d29b6640f6 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -49,10 +49,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import set_default_rope_theta from vllm.v1.attention.backend import AttentionType @@ -61,7 +57,7 @@ from .utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -362,61 +358,20 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -477,4 +432,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index fcb2ae429cb9..07e2aa83c404 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -48,17 +48,13 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -347,66 +343,21 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } - # LoRA specific attributes embedding_modules = { "embed_tokens": "input_embeddings", @@ -468,4 +419,4 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 034c9c18ff7b..17349767b94c 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -45,13 +45,12 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -266,45 +265,18 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class StablelmForCausalLM(nn.Module, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -351,4 +323,4 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 5f08a59e2364..5ff3a4cbeeed 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -45,16 +45,12 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP from .utils import ( AutoWeightsLoader, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -272,41 +268,16 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class Starcoder2ForCausalLM(nn.Module, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + } + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -362,4 +333,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ["lm_head.weight"] if self.config.tie_word_embeddings else None ), ) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/step1.py b/vllm/model_executor/models/step1.py index 07653fa6b377..eab36640deeb 100644 --- a/vllm/model_executor/models/step1.py +++ b/vllm/model_executor/models/step1.py @@ -30,7 +30,6 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( EagleModelMixin, SupportsEagle, @@ -40,7 +39,7 @@ from vllm.model_executor.models.utils import ( AutoWeightsLoader, PPMissingLayer, - is_pp_missing_parameter, + WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, @@ -48,11 +47,6 @@ from vllm.sequence import IntermediateTensors from vllm.v1.attention.backend import AttentionType -STEP_PACKED_MODULES_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], -} - def _get_step_alibi_slopes(total_num_heads: int) -> torch.Tensor: """Reference ALiBi slopes used by Step models.""" @@ -242,42 +236,6 @@ def forward( hidden_states = self.mlp(hidden_states) return hidden_states, residual - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) # type: ignore[name-defined] - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class StepDecoderModel(nn.Module, EagleModelMixin): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -354,7 +312,19 @@ def forward( class Step1ForCausalLM(nn.Module, SupportsPP, SupportsEagle, SupportsEagle3): - packed_modules_mapping = STEP_PACKED_MODULES_MAPPING + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".q_proj": ".qkv_proj.q", + ".k_proj": ".qkv_proj.k", + ".v_proj": ".qkv_proj.v", + ".gate_proj": ".gate_up_proj.0", + ".up_proj": ".gate_up_proj.1", + } + ) + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -413,4 +383,4 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py index 55d946004974..4402d180ca06 100644 --- a/vllm/model_executor/models/transformers/base.py +++ b/vllm/model_executor/models/transformers/base.py @@ -158,9 +158,6 @@ def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): "Transformers modeling backend does " "not support MXFP4 quantization yet." ) - # Skip loading extra bias for GPTQ models. - if "gptq" in quant_method_name: - self.ignore_unexpected_suffixes.append(".bias") self._patch_config() from_config_kwargs = dict( diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 730dc81ed21c..55f40fcdc435 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -4,8 +4,8 @@ import itertools from collections.abc import Callable, Iterable, Mapping from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Any, Literal, Protocol, overload +from dataclasses import dataclass, field, replace +from typing import TYPE_CHECKING, Any, Literal, Protocol, overload import regex as re import torch @@ -19,9 +19,6 @@ get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, -) from vllm.model_executor.model_loader.reload import ( support_quantized_model_reload_from_hp_weights, ) @@ -35,6 +32,9 @@ direct_register_custom_op, ) +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationConfig + logger = init_logger(__name__) @@ -64,6 +64,16 @@ def __or__(self, other: "WeightsMapper") -> "WeightsMapper": ) def _map_name(self, key: str) -> str | None: + # Deprecation warnings + if key.endswith(".kv_scale"): + logger.warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale" + ) + for renaming in self.orig_to_new_renamings: key, _ = renaming.rename_source_key(key) @@ -120,6 +130,25 @@ def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]: if (out_name := self._map_name(name)) is not None } + def get_unfused_mapper(self) -> "WeightsMapper": + """Mapper variant that drops the QKV/MLP fusion substr maps, keeping + all genuine renames/prefixes. + + Consumers that reference the checkpoint's *unfused* module names — LoRA + name parsing and the quantization config's layer lists + (`modules_in_block_to_quantize`, ignored layers) — need the constituent + names (e.g. `q_proj`) to survive rather than being rewritten to the + fused vLLM name (`qkv_proj.q`).""" + qkv_shards = {"q", "k", "v"} + substr = {} + for old, new in self.orig_to_new_substr.items(): + if new is not None and "." in new: + shard_id = new.rpartition(".")[2] + if shard_id.isdigit() or shard_id in qkv_shards: + continue + substr[old] = new + return replace(self, orig_to_new_substr=substr) + class AutoWeightsLoader: """ @@ -356,16 +385,15 @@ def load_weights( # We look at the causal model's direct children for this reason. modules = (self.module, *self.module.children()) iterator = (m.quant_config for m in modules if hasattr(m, "quant_config")) - quant_config = next(iterator, None) - cache_scale_mapper = ( - quant_config.get_cache_scale_mapper() if quant_config is not None else None - ) - if cache_scale_mapper is not None: - mapper = ( - mapper | cache_scale_mapper - if mapper is not None - else cache_scale_mapper - ) + if quant_config := next(iterator, None): + # Skip loading extra bias for GPTQ models + if "gptq" in quant_config.get_name(): + self.ignore_unexpected_suffixes.append(".bias") + # Get mappings and ignore prefixes for KV cache quantization scales + mapper = mapper or WeightsMapper() + mapper |= quant_config.get_cache_scale_mapper() + ignore_unexpected_suffixes = quant_config._ignore_unexpected_suffixes + self.ignore_unexpected_suffixes.extend(ignore_unexpected_suffixes) if mapper is not None: weights = mapper.apply(weights) # filter out weights with first-prefix/substr to skip in name @@ -734,9 +762,7 @@ def maybe_prefix(prefix: str, name: str) -> str: return name if not prefix else f"{prefix}.{name}" -def get_draft_quant_config( - vllm_config: VllmConfig, -) -> QuantizationConfig | None: +def get_draft_quant_config(vllm_config: VllmConfig) -> "QuantizationConfig | None": """Get quantization config for Draft models. Draft models should use their own quantization config instead of the verifier/target diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 628186e7598b..a0a06264a7f6 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -44,7 +44,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.whisper_utils import ( ISO639_1_SUPPORTED_LANGS, ) @@ -617,42 +616,6 @@ def get_encoder_outputs( return None return self.encoder(input_features) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), - (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), - (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), - # MergedColumnParallelLinear uses integer indices (0, 1) - (".encoder_attn.kv_proj", ".encoder_attn.k_proj", 0), - (".encoder_attn.kv_proj", ".encoder_attn.v_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - class WhisperProcessingInfo(BaseProcessingInfo): def get_hf_config(self) -> WhisperConfig: @@ -808,7 +771,16 @@ class WhisperForConditionalGeneration( } hf_to_vllm_mapper = WeightsMapper( - orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."} + orig_to_new_substr={ + ".fc1.": ".mlp.fc1.", + ".fc2.": ".mlp.fc2.", + ".self_attn.q_proj": ".self_attn.qkv_proj.q", + ".self_attn.k_proj": ".self_attn.qkv_proj.k", + ".self_attn.v_proj": ".self_attn.qkv_proj.v", + # MergedColumnParallelLinear uses integer indices (0, 1) + ".encoder_attn.k_proj": ".encoder_attn.kv_proj.0", + ".encoder_attn.v_proj": ".encoder_attn.kv_proj.1", + } ) # Whisper only supports audio-conditioned generation.