Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cli/alora/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def train_model(
"Training 3B+ models may fail. Consider using --device cpu"
)
device_map = "auto"
elif torch.backends.mps.is_available():
device_map = "auto" # matches the explicit --device mps path below
else:
device_map = None
elif device == "cpu":
Expand Down
28 changes: 23 additions & 5 deletions mellea/formatters/granite/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,33 @@ def random_uuid() -> str:


def load_transformers_lora(local_or_remote_path: str) -> tuple:
"""Load transformers LoRA model.
"""Load transformers LoRA model placed on the best available device.

AutoModelForCausalLM.from_pretrained() is supposed to auto-load base models if you
pass it a LoRA adapter's config, but that auto-loading is very broken as of 8/2025.
Workaround powers activate!

Device selection mirrors ``LocalHFBackend``: CUDA → MPS → CPU. The returned model
is already on that device; callers do not need to move it.

Only works if `transformers` and `peft` are installed.

Args:
local_or_remote_path: Local directory path of the LoRA adapter.

Returns:
Tuple of `(model, tokenizer)` where `model` is the loaded LoRA model and
`tokenizer` is the corresponding HuggingFace tokenizer.
Tuple of `(model, tokenizer)` where `model` is the loaded LoRA model placed on
the best available device, and `tokenizer` is the corresponding HuggingFace
tokenizer.

Raises:
ImportError: If `peft` or `transformers` packages are not installed.
NotImplementedError: If `local_or_remote_path` does not exist locally
(remote loading from the Hugging Face Hub is not yet implemented).
"""
with import_optional("hf"):
# Third Party
import torch
with import_optional("peft"):
# Third Party
import peft
Expand All @@ -110,9 +117,20 @@ def load_transformers_lora(local_or_remote_path: str) -> tuple:
with open(f"{local_model_dir}/adapter_config.json", encoding="utf-8") as f:
adapter_config = json.load(f)
base_model_name = adapter_config["base_model_name_or_path"]
base_model = transformers.AutoModelForCausalLM.from_pretrained(base_model_name)
device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
base_model = transformers.AutoModelForCausalLM.from_pretrained(
base_model_name, device_map=str(device)
)
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model_name)
model = peft.PeftModel.from_pretrained(base_model, local_model_dir)
model = peft.PeftModel.from_pretrained(
base_model, local_model_dir, device_map=str(device)
)
return model, tokenizer


Expand Down
49 changes: 49 additions & 0 deletions test/cli/test_alora_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,55 @@ def test_invocation_prompt_tokenization():
)


@pytest.mark.integration
def test_device_auto_selects_mps_when_cuda_unavailable():
    """Verify that ``device_map="auto"`` is used when only MPS is available.

    CUDA is patched to report unavailable and MPS to report available.
    ``train_model`` is invoked without an explicit device argument, so its
    default device handling is what gets exercised (the assertion message
    refers to this as ``--device auto``). All heavy collaborators
    (tokenizer, model, dataset, trainer, PEFT wrapper) are mocked out; the
    only thing checked is the ``device_map`` keyword handed to
    ``AutoModelForCausalLM.from_pretrained``.
    """
    from cli.alora.train import train_model

    with (
        patch("cli.alora.train.AutoTokenizer") as tokenizer_cls,
        patch("cli.alora.train.AutoModelForCausalLM") as model_cls,
        patch("cli.alora.train.Dataset"),
        patch("cli.alora.train.SafeSaveTrainer") as trainer_cls,
        patch("cli.alora.train.get_peft_model") as peft_factory,
        patch("cli.alora.train.load_dataset_from_json") as dataset_loader,
        patch("cli.alora.train.DataCollatorForCompletionOnlyLM"),
        patch("cli.alora.train.SFTConfig"),
        patch("cli.alora.train.torch.cuda.is_available", return_value=False),
        patch("cli.alora.train.torch.backends.mps.is_available", return_value=True),
    ):
        # Tokenizer stub: only the EOS token is configured explicitly.
        fake_tokenizer = Mock()
        fake_tokenizer.eos_token = "<eos>"
        tokenizer_cls.from_pretrained.return_value = fake_tokenizer

        # Model stub whose single parameter reports that it lives on MPS.
        mps_param = Mock()
        mps_param.device.type = "mps"
        fake_model = Mock()
        fake_model.parameters.return_value = [mps_param]
        model_cls.from_pretrained.return_value = fake_model
        peft_factory.return_value = Mock()

        # Dataset stub: shuffle/select hand back the same object, len() is 10.
        fake_dataset = MagicMock()
        fake_dataset.shuffle.return_value = fake_dataset
        fake_dataset.select.return_value = fake_dataset
        fake_dataset.__len__ = Mock(return_value=10)
        dataset_loader.return_value = fake_dataset
        trainer_cls.return_value = Mock()

        train_model(
            dataset_path="test.jsonl",
            base_model="test-model",
            output_file="./test_output/adapter",
            adapter="lora",
            epochs=1,
        )

        # call_args is an (args, kwargs) pair; only the kwargs matter here.
        _, load_kwargs = model_cls.from_pretrained.call_args
        assert load_kwargs.get("device_map") == "auto", (
            "--device auto should pass device_map='auto' when MPS is available"
        )


@pytest.mark.integration
def test_imports_work():
"""Test that PEFT imports work correctly (no IBM alora dependency)."""
Expand Down
2 changes: 0 additions & 2 deletions test/formatters/granite/test_intrinsics_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,8 +771,6 @@ def test_run_transformers(yaml_json_combo_with_model, gh_run):

# Run the model using Hugging Face APIs
model, tokenizer = base_util.load_transformers_lora(lora_dir)
if torch.cuda.is_available(): # Use GPU if available
model.cuda()

generate_input, other_input = (
base_util.chat_completion_request_to_transformers_inputs(
Expand Down
Loading