Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
88beb3b
Enable fused linear layers to load themselves
hmellor Jun 4, 2026
41f4584
Enable GPTQ extra bias skipping in AutoWeightsLoader
hmellor Jun 4, 2026
fcd151b
Try it on a couple of simple weight loaders
hmellor Jun 4, 2026
a9788ab
Fix LoRA loading for these two models
hmellor Jun 5, 2026
cca665e
Delete some more load_weights methods
hmellor Jun 5, 2026
e316238
Add patterns from `maybe_remap_kv_scale_name` to `QuantizationConfig.…
hmellor Jun 5, 2026
665ca0c
Use new mappings in `AutoWeightsLoader`
hmellor Jun 5, 2026
edf67e6
Remove some more load_weights methods
hmellor Jun 5, 2026
698d2a6
Merge remote-tracking branch 'upstream/main' into remove-simple-load-…
hmellor Jun 11, 2026
b1f1c9d
Fix `load_weights` methods for fused case
hmellor Jun 11, 2026
b5a7191
Fix BaiChuan tests that depend on old behaviour
hmellor Jun 11, 2026
f5383aa
Handle MergedColumnParallelLinear for LoRA too
hmellor Jun 11, 2026
82a7a64
Delete some more load_weights methods
hmellor Jun 11, 2026
c3a316a
Add debug logs while loading
hmellor Jun 11, 2026
c802faa
Fix late initialised biases
hmellor Jun 11, 2026
88de67d
Fix GPTQ tests
hmellor Jun 11, 2026
4069aae
fix bnb
hmellor Jun 11, 2026
9d69ba6
Merge branch 'main' into remove-simple-load-weights
hmellor Jun 13, 2026
872ff37
Make `vllm.model_executor.utils.get_packed_modules_mapping` check `hf…
hmellor Jun 13, 2026
3b73687
Fix `WeightsMapper.get_packed_modules_mapping`
hmellor Jun 13, 2026
c657d7d
Better `SupportsQuant._maybe_apply_model_mapping`
hmellor Jun 13, 2026
f2d548b
`BitsAndBytesModelLoader` can be simpler now
hmellor Jun 13, 2026
41e3a9e
Use `get_packed_modules_mapping` for `get_supported_lora_modules`
hmellor Jun 13, 2026
68085a0
Fix test
hmellor Jun 13, 2026
2f56a42
tweaks
hmellor Jun 13, 2026
d051813
typo
hmellor Jun 13, 2026
5c2a354
Mapper must present both shard id and weight name as supported packings
hmellor Jun 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand Down Expand Up @@ -166,6 +167,26 @@ def dummy_model_gate_up(default_vllm_config) -> nn.Module:
return model


@pytest.fixture
def baichuan_dummy_model(default_vllm_config, dist_init) -> nn.Module:
    """Build a dummy model holding only BaiChuan's LoRA-capable layers.

    Restricting the model to these four linear layers is what lets
    ``get_supported_lora_modules`` work on it in the LoRA checkpoint tests.

    The ``default_vllm_config`` and ``dist_init`` fixtures are requested but
    never read here; presumably they are needed for their setup side effects
    before the parallel linear layers can be constructed (TODO confirm).
    """
    # Insertion order is preserved, so the layers are created and registered
    # in the same order as the BaiChuan module list.
    lora_layers = OrderedDict()
    lora_layers["W_pack"] = QKVParallelLinear(64, 8, 8)
    lora_layers["o_proj"] = RowParallelLinear(64, 64)
    lora_layers["gate_up_proj"] = MergedColumnParallelLinear(64, [16, 16])
    lora_layers["down_proj"] = RowParallelLinear(16, 64)

    dummy = DummyLoRAModel(lora_layers)
    dummy.config = MagicMock()

    # Mirror the layout of BaiChuan checkpoints: the attention projection is
    # marked "fused", while the gate/up projection is marked "sharded".
    dummy.W_pack.checkpoint_format = "fused"
    dummy.gate_up_proj.checkpoint_format = "sharded"
    return dummy


@pytest.fixture(scope="session")
def mixtral_lora_files():
# Note: this module has incorrect adapter_config.json to test
Expand Down
44 changes: 15 additions & 29 deletions tests/lora/test_lora_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,26 @@

from vllm.lora.lora_model import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import parse_fine_tuned_lora_name
from vllm.lora.utils import get_supported_lora_modules, parse_fine_tuned_lora_name
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
from vllm.model_executor.models.utils import WeightsMapper

lora_lst = ["baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"]
BAICHUAN_LORA_MODULES = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]


@pytest.mark.parametrize("lora_name", lora_lst)
@pytest.mark.parametrize(
"lora_name",
["baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"],
)
def test_load_checkpoints(
lora_name,
baichuan_lora_files,
baichuan_zero_lora_files,
baichuan_regex_lora_files,
chatglm3_lora_files,
baichuan_dummy_model,
):
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping

expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_lst.extend(packed_modules_mapping[module])
else:
expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
expected_lora_modules = set(get_supported_lora_modules(baichuan_dummy_model))
weights_mapper = BaiChuanBaseForCausalLM.hf_to_vllm_mapper
if lora_name == "baichuan7B":
peft_helper = PEFTHelper.from_local_dir(
baichuan_lora_files, max_position_embeddings=4096
Expand All @@ -49,6 +38,7 @@ def test_load_checkpoints(
lora_model_id=1,
device="cpu",
model_vocab_size=64000,
weights_mapper=weights_mapper,
)
elif lora_name == "baichuan7B-zero":
# Test that the target_modules contain prefix
Expand All @@ -64,6 +54,7 @@ def test_load_checkpoints(
lora_model_id=1,
device="cpu",
model_vocab_size=64000,
weights_mapper=weights_mapper,
)
elif lora_name == "baichuan7B-zero-regex":
# Test that the `target_modules` in the form of regular expressions,
Expand All @@ -78,6 +69,7 @@ def test_load_checkpoints(
lora_model_id=1,
device="cpu",
model_vocab_size=64000,
weights_mapper=weights_mapper,
)
else:
# For the baichuan7B model, load chatglm3-6b's LoRA,
Expand All @@ -97,22 +89,16 @@ def test_load_checkpoints(
)


def test_lora_weights_mapping(baichuan_lora_files):
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping

expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
expected_lora_lst.extend(packed_modules_mapping[module])
else:
expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
def test_lora_weights_mapping(baichuan_lora_files, baichuan_dummy_model):
expected_lora_modules = set(get_supported_lora_modules(baichuan_dummy_model))
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.": "language_model.model.",
},
orig_to_new_substr={
".layers.": ".baichuan_layers.",
".gate_proj": ".gate_up_proj.0",
".up_proj": ".gate_up_proj.1",
},
)
peft_helper = PEFTHelper.from_local_dir(
Expand Down
143 changes: 143 additions & 0 deletions tests/model_executor/test_weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,148 @@ def test_missing_target_returns_none(self):
assert result is None


class TestKvCacheScaleMapper:
    """Checks for the `WeightsMapper` returned by `get_cache_scale_mapper`.

    That mapper replaces the per-model `maybe_remap_kv_scale_name` calls, so
    it must remap the same set of checkpoint formats (the ones that do not
    depend on `params_dict`) and it must be idempotent, which is what allows
    it to compose safely with a model's own qkv/gate_up `hf_to_vllm_mapper`.
    """

    def _mapper(self):
        """Return the default (non-config-specific) cache scale mapper."""
        from vllm.model_executor.layers.quantization.base_config import (
            QuantizationConfig,
        )

        # `get_cache_scale_mapper` does not use `self`, so it is called
        # directly on the base class.
        scale_mapper = QuantizationConfig.get_cache_scale_mapper()
        return scale_mapper

    def _map(self, name: str) -> str | None:
        """Run ``name`` through a freshly obtained cache scale mapper."""
        scale_mapper = self._mapper()
        return scale_mapper._map_name(name)

    @pytest.mark.parametrize(
        "name,expected",
        [
            # Qwen3-MoE / llm-compressor fused qkv_proj
            (
                "model.layers.0.self_attn.qkv_proj.k_scale",
                "model.layers.0.self_attn.attn.k_scale",
            ),
            (
                "model.layers.0.self_attn.qkv_proj.v_scale",
                "model.layers.0.self_attn.attn.v_scale",
            ),
            # ModelOpt / NVFP4 k_proj/v_proj
            (
                "model.layers.0.self_attn.k_proj.k_scale",
                "model.layers.0.self_attn.attn.k_scale",
            ),
            (
                "model.layers.0.self_attn.v_proj.v_scale",
                "model.layers.0.self_attn.attn.v_scale",
            ),
            # deprecated fused kv_scale and bare scales
            (
                "model.layers.0.self_attn.kv_scale",
                "model.layers.0.self_attn.attn.k_scale",
            ),
            (
                "model.layers.0.self_attn.k_scale",
                "model.layers.0.self_attn.attn.k_scale",
            ),
            # NemotronH mixer
            (
                "model.layers.0.mixer.k_proj.k_scale",
                "model.layers.0.mixer.attn.k_scale",
            ),
            # already in vLLM form -> unchanged (idempotent)
            (
                "model.layers.0.self_attn.attn.k_scale",
                "model.layers.0.self_attn.attn.k_scale",
            ),
            # non-kv scales must not be touched
            (
                "model.layers.0.self_attn.k_proj.weight_scale",
                "model.layers.0.self_attn.k_proj.weight_scale",
            ),
            (
                "model.layers.0.self_attn.k_proj.input_scale",
                "model.layers.0.self_attn.k_proj.input_scale",
            ),
            # regular weights untouched
            (
                "model.layers.0.self_attn.q_proj.weight",
                "model.layers.0.self_attn.q_proj.weight",
            ),
        ],
    )
    def test_remap(self, name, expected):
        """Each checkpoint name maps to its expected vLLM name."""
        remapped = self._map(name)
        assert remapped == expected

    @pytest.mark.parametrize(
        "name",
        [
            "model.layers.0.self_attn.k_scale",
            "model.layers.0.self_attn.k_proj.k_scale",
            "model.layers.0.self_attn.qkv_proj.v_scale",
            "model.layers.0.mixer.k_proj.k_scale",
        ],
    )
    def test_idempotent(self, name):
        """Mapping an already-mapped name leaves it unchanged."""
        first_pass = self._map(name)
        assert first_pass is not None
        second_pass = self._map(first_pass)
        assert second_pass == first_pass

    def test_composes_with_qkv_mapper(self):
        """Applied together with a model's qkv/gate_up mapper, the regex scale
        rules run before the substr rename, so scales are normalized to `.attn.`
        and regular projections are still fused correctly."""
        from vllm.model_executor.models.utils import WeightsMapper

        qkv_rules = {
            ".q_proj": ".qkv_proj.q",
            ".k_proj": ".qkv_proj.k",
            ".v_proj": ".qkv_proj.v",
        }
        model_mapper = WeightsMapper(orig_to_new_substr=qkv_rules)
        # AutoWeightsLoader does `mapper |= cache_scale_mapper`
        combined = model_mapper | self._mapper()

        # (checkpoint name, expected vLLM name), checked in this order.
        cases = (
            (
                "model.layers.0.self_attn.q_proj.weight",
                "model.layers.0.self_attn.qkv_proj.q.weight",
            ),
            (
                "model.layers.0.self_attn.k_proj.k_scale",
                "model.layers.0.self_attn.attn.k_scale",
            ),
            (
                "model.layers.0.self_attn.k_scale",
                "model.layers.0.self_attn.attn.k_scale",
            ),
        )
        for checkpoint_name, vllm_name in cases:
            assert combined._map_name(checkpoint_name) == vllm_name


def test_weights_mapper_get_packed_modules_mapping():
    """`get_packed_modules_mapping` reports only the fusion entries.

    The mapper is given both fusion renames (q/k/v into `qkv_proj`, gate/up
    into `gate_up_proj`) and plain renames; only the former may appear in the
    resulting packed-modules mapping.
    """
    from vllm.model_executor.models.utils import WeightsMapper

    substr_rules = {
        ".q_proj": ".qkv_proj.q",
        ".k_proj": ".qkv_proj.k",
        ".v_proj": ".qkv_proj.v",
        ".gate_proj": ".gate_up_proj.0",
        ".up_proj": ".gate_up_proj.1",
        # Non-fusion entries must not contribute
        ".word_embeddings": "",
        "llm.model.": "model.decoder.",
        "llm.lm_head": "lm_head",
    }
    expected_packing = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    mapper = WeightsMapper(orig_to_new_substr=substr_rules)
    packed = mapper.get_packed_modules_mapping()
    assert packed == expected_packing


# Script entry point: when this test module is executed directly, run
# `test_download_weights_from_hf` (defined outside the portion shown here).
if __name__ == "__main__":
test_download_weights_from_hf()
21 changes: 13 additions & 8 deletions vllm/lora/layers/column_parallel_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ def can_replace_layer(
if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
return True
if isinstance(source_layer, maybe_get_oot_by_class(MergedColumnParallelLinear)):
if len(packed_modules_list) != 1:
if (
len(packed_modules_list) != 1
or source_layer.checkpoint_format == "sharded"
):
return False
# Exclude layers with 3+ output sizes - those are handled by
# MergedColumnParallelLinearVariableSliceWithLoRA since this
Expand Down Expand Up @@ -347,7 +350,11 @@ def can_replace_layer(
decorate: bool = True,
) -> bool:
merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
if not isinstance(source_layer, merged_cls) or len(packed_modules_list) != 2:
if (
not isinstance(source_layer, merged_cls)
or len(source_layer.output_sizes) != 2
or source_layer.checkpoint_format == "fused"
):
return False

tp_size = getattr(source_layer, "tp_size", 1)
Expand Down Expand Up @@ -422,9 +429,8 @@ def can_replace_layer(
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return (
type(source_layer) is maybe_get_oot_by_class(QKVParallelLinear)
and len(packed_modules_list) == 1
return type(source_layer) is maybe_get_oot_by_class(QKVParallelLinear) and (
len(packed_modules_list) == 1 or source_layer.checkpoint_format == "fused"
)


Expand Down Expand Up @@ -483,9 +489,8 @@ def can_replace_layer(
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return (
type(source_layer) is maybe_get_oot_by_class(QKVParallelLinear)
and len(packed_modules_list) == 3
return type(source_layer) is maybe_get_oot_by_class(QKVParallelLinear) and (
len(packed_modules_list) == 3 or source_layer.checkpoint_format == "sharded"
)


Expand Down
14 changes: 7 additions & 7 deletions vllm/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
In vLLM, all linear layers support LoRA.
"""

supported_lora_modules: set[str] = set()
packed_modules_mapping = get_packed_modules_mapping(model)
supported_lora_modules: set[str] = set(sum(packed_modules_mapping.values(), []))
for name, module in model.named_modules():
# get the embedding modules if the module's embedding_modules
# is not empty.
Expand All @@ -219,12 +220,11 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
for name in embedding_modules:
supported_lora_modules.add(name)

# get all the linear subfixes.
if isinstance(module, (LinearBase,)):
supported_lora_modules.add(name.split(".")[-1])

if isinstance(module, (MoERunner,)):
supported_lora_modules.add(name.split(".")[-1])
if (
isinstance(module, (LinearBase, MoERunner))
and (supported_name := name.split(".")[-1]) not in packed_modules_mapping
):
supported_lora_modules.add(supported_name)

return list(supported_lora_modules)

Expand Down
Loading
Loading