diff --git a/cli/alora/train.py b/cli/alora/train.py index 65581c0be..3b1447632 100644 --- a/cli/alora/train.py +++ b/cli/alora/train.py @@ -247,6 +247,8 @@ def train_model( "Training 3B+ models may fail. Consider using --device cpu" ) device_map = "auto" + elif torch.backends.mps.is_available(): + device_map = "auto" # matches the explicit --device mps path below else: device_map = None elif device == "cpu": diff --git a/mellea/formatters/granite/base/util.py b/mellea/formatters/granite/base/util.py index 4ad210234..ee7c71bdb 100644 --- a/mellea/formatters/granite/base/util.py +++ b/mellea/formatters/granite/base/util.py @@ -80,26 +80,33 @@ def random_uuid() -> str: def load_transformers_lora(local_or_remote_path: str) -> tuple: - """Load transformers LoRA model. + """Load transformers LoRA model placed on the best available device. AutoModelForCausalLM.from_pretrained() is supposed to auto-load base models if you pass it a LoRA adapter's config, but that auto-loading is very broken as of 8/2025. Workaround powers activate! + Device selection mirrors ``LocalHFBackend``: CUDA → MPS → CPU. The returned model + is already on that device; callers do not need to move it. + Only works if `transformers` and `peft` are installed. Args: local_or_remote_path: Local directory path of the LoRA adapter. Returns: - Tuple of `(model, tokenizer)` where `model` is the loaded LoRA model and - `tokenizer` is the corresponding HuggingFace tokenizer. + Tuple of `(model, tokenizer)` where `model` is the loaded LoRA model placed on + the best available device, and `tokenizer` is the corresponding HuggingFace + tokenizer. Raises: ImportError: If `peft` or `transformers` packages are not installed. NotImplementedError: If `local_or_remote_path` does not exist locally (remote loading from the Hugging Face Hub is not yet implemented). """ + with import_optional("hf"): + # Third Party + import torch with import_optional("peft"): # Third Party import peft @@ -110,9 +117,20 @@ def load_transformers_lora(local_or_remote_path: str) -> tuple: with open(f"{local_model_dir}/adapter_config.json", encoding="utf-8") as f: adapter_config = json.load(f) base_model_name = adapter_config["base_model_name_or_path"] - base_model = transformers.AutoModelForCausalLM.from_pretrained(base_model_name) + device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" + ) + base_model = transformers.AutoModelForCausalLM.from_pretrained( + base_model_name, device_map=str(device) + ) tokenizer = transformers.AutoTokenizer.from_pretrained(base_model_name) - model = peft.PeftModel.from_pretrained(base_model, local_model_dir) + model = peft.PeftModel.from_pretrained( + base_model, local_model_dir, device_map=str(device) + ) return model, tokenizer diff --git a/test/cli/test_alora_train.py b/test/cli/test_alora_train.py index 43862cdd8..1a0dc2b33 100644 --- a/test/cli/test_alora_train.py +++ b/test/cli/test_alora_train.py @@ -215,6 +215,55 @@ def test_invocation_prompt_tokenization(): ) +@pytest.mark.integration +def test_device_auto_selects_mps_when_cuda_unavailable(): + """Test that --device auto picks MPS on Apple Silicon when CUDA is absent.""" + from cli.alora.train import train_model + + with ( + patch("cli.alora.train.AutoTokenizer") as mock_tokenizer_class, + patch("cli.alora.train.AutoModelForCausalLM") as mock_model_class, + patch("cli.alora.train.Dataset"), + patch("cli.alora.train.SafeSaveTrainer") as mock_trainer, + patch("cli.alora.train.get_peft_model") as mock_get_peft_model, + patch("cli.alora.train.load_dataset_from_json") as mock_load_dataset, + patch("cli.alora.train.DataCollatorForCompletionOnlyLM"), + patch("cli.alora.train.SFTConfig"), + patch("cli.alora.train.torch.cuda.is_available", return_value=False), + patch("cli.alora.train.torch.backends.mps.is_available", return_value=True), + ): + mock_tokenizer = Mock() + mock_tokenizer.eos_token = "" + mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer + + mock_model = Mock() + mock_param = Mock() + mock_param.device.type = "mps" + mock_model.parameters.return_value = [mock_param] + mock_model_class.from_pretrained.return_value = mock_model + mock_get_peft_model.return_value = Mock() + + mock_ds = MagicMock() + mock_ds.shuffle.return_value = mock_ds + mock_ds.select.return_value = mock_ds + mock_ds.__len__ = Mock(return_value=10) + mock_load_dataset.return_value = mock_ds + mock_trainer.return_value = Mock() + + train_model( + dataset_path="test.jsonl", + base_model="test-model", + output_file="./test_output/adapter", + adapter="lora", + epochs=1, + ) + + call_kwargs = mock_model_class.from_pretrained.call_args[1] + assert call_kwargs.get("device_map") == "auto", ( + "--device auto should pass device_map='auto' when MPS is available" + ) + + @pytest.mark.integration def test_imports_work(): """Test that PEFT imports work correctly (no IBM alora dependency).""" diff --git a/test/formatters/granite/test_intrinsics_formatters.py b/test/formatters/granite/test_intrinsics_formatters.py index 72c96da95..734414302 100644 --- a/test/formatters/granite/test_intrinsics_formatters.py +++ b/test/formatters/granite/test_intrinsics_formatters.py @@ -771,8 +771,6 @@ def test_run_transformers(yaml_json_combo_with_model, gh_run): # Run the model using Hugging Face APIs model, tokenizer = base_util.load_transformers_lora(lora_dir) - if torch.cuda.is_available(): # Use GPU if available - model.cuda() generate_input, other_input = ( base_util.chat_completion_request_to_transformers_inputs(