diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..c469d71c1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,28 @@ +FROM nvidia/cuda:12.8.0-devel-ubuntu22.04 + +# Install Python 3.11 from deadsnakes (Ubuntu 22.04 ships 3.11.0rc1, which has +# inspect bugs that break triton's JIT source parsing). +# Bootstrap pip via get-pip.py so we get upstream pip/setuptools (Ubuntu's +# python3-pip carries Debian patches that break pyproject builds). get-pip.py is saved to a file before running it, so a failed download fails this step instead of being masked by a pipe. +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && apt-get install -y software-properties-common curl ca-certificates \ + && add-apt-repository ppa:deadsnakes/ppa -y \ + && apt-get update && apt-get install -y \ + python3.11 python3.11-dev python3.11-distutils git \ + && rm -rf /var/lib/apt/lists/* \ + && ln -sf /usr/bin/python3.11 /usr/bin/python \ + && ln -sf /usr/bin/python3.11 /usr/bin/python3 \ + && curl -fsSL https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py && python3.11 /tmp/get-pip.py && rm -f /tmp/get-pip.py + +# Install dependencies from requirements +COPY requirements.txt /tmp/requirements.txt +RUN pip install --no-cache-dir -r /tmp/requirements.txt + +# Install lm-eval +RUN pip install --no-cache-dir lm-eval==0.4.11 + +# Install flash-attn (pre-built wheel for cu128 + torch2.9, avoids slow compilation) +RUN pip install --no-cache-dir \ + "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.16/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl" + +WORKDIR /workspace \ No newline at end of file diff --git a/README.md b/README.md index 293bf34ca..899b0b643 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,9 @@ We provide several variants for each of the components in the unlearning pipelin conda create -n unlearning python=3.11 conda activate unlearning pip install ".[lm-eval]" -pip install --no-build-isolation flash-attn==2.6.3 +pip install --no-build-isolation flash-attn==2.8.3 +# Or, instead of the command above, install a pre-built wheel to avoid compiling flash-attn: +pip install 
"https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.16/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl" # Data setup python setup_data.py --eval # saves/eval now contains evaluation results of the uploaded models @@ -125,6 +127,8 @@ python setup_data.py --eval # saves/eval now contains evaluation results of the # python setup_data.py --help ``` +We also provide a [Docker image](https://hub.docker.com/r/filyp/open-unlearning), with this environment already installed. + --- ### 🔄 Updated TOFU benchmark diff --git a/configs/model/Qwen2.5-0.5B-Instruct.yaml b/configs/model/Qwen2.5-0.5B-Instruct.yaml new file mode 100644 index 000000000..1050a9eaa --- /dev/null +++ b/configs/model/Qwen2.5-0.5B-Instruct.yaml @@ -0,0 +1,14 @@ +model_args: + pretrained_model_name_or_path: "Qwen/Qwen2.5-0.5B-Instruct" + attn_implementation: 'flash_attention_2' + torch_dtype: bfloat16 +tokenizer_args: + pretrained_model_name_or_path: "Qwen/Qwen2.5-0.5B-Instruct" +template_args: + apply_chat_template: true + system_prompt: "You are a helpful assistant." 
+ system_prompt_with_special_tokens: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + user_start_tag: "<|im_start|>user\n" + user_end_tag: "<|im_end|>\n" + asst_start_tag: "<|im_start|>assistant\n" + asst_end_tag: "<|im_end|>\n" diff --git a/requirements.txt b/requirements.txt index 5186a418d..6879748e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,18 @@ -huggingface-hub==0.36.0 -transformers==4.51.3 -hf-xet==1.2.0 +huggingface-hub==1.7.2 +transformers==5.5.4 +hf-xet==1.4.2 numpy==2.2.3 hydra-core==1.3 hydra_colorlog==1.2.0 -torch==2.4.1 +torch==2.9.1 datasets==3.0.1 -accelerate==0.34.2 -bitsandbytes==0.44.1 +accelerate==1.13.0 +bitsandbytes==0.49.2 rouge-score==0.1.2 scipy==1.14.1 tensorboard==2.18.0 scikit-learn==1.5.2 deepspeed==0.15.4 wandb==0.21.4 +# for Qwen3.5 speedup: +flash-linear-attention==0.4.2 \ No newline at end of file diff --git a/src/data/utils.py b/src/data/utils.py index 4a5df348b..3d852a7cd 100644 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -68,6 +68,11 @@ def preprocess_chat_instance( prompt_ids = tokenizer.apply_chat_template( chat[:-1], tokenize=True, add_generation_prompt=True, **date_info ) + # transformers 5.x may return a BatchEncoding instead of a list from apply_chat_template; in that case keep only its "input_ids" + if not isinstance(chat_ids, list): + chat_ids = chat_ids["input_ids"] + if not isinstance(prompt_ids, list): + prompt_ids = prompt_ids["input_ids"] else: wrapped_prompt = "" system_prompt_with_special_tokens = template_config.get( diff --git a/src/trainer/unlearn/base.py b/src/trainer/unlearn/base.py index 26ab0ac7f..72a6d783f 100644 --- a/src/trainer/unlearn/base.py +++ b/src/trainer/unlearn/base.py @@ -1,9 +1,11 @@ -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn from copy import deepcopy -from packaging import version + +from accelerate.utils import is_deepspeed_available + from trainer.base import FinetuneTrainer from transformers.trainer_pt_utils import ( @@ -15,21 
+17,8 @@ is_sagemaker_mp_enabled, ) -from accelerate.utils import ( - is_deepspeed_available, -) - if is_sagemaker_mp_enabled(): - from smdistributed.modelparallel import __version__ as SMP_VERSION - - IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") - - from transformers.trainer_pt_utils import ( - smp_forward_only, - smp_nested_concat, - ) -else: - IS_SAGEMAKER_MP_POST_1_10 = False + from transformers.trainer_pt_utils import smp_forward_only, smp_nested_concat if is_deepspeed_available(): import deepspeed @@ -78,10 +67,10 @@ def _prepare_deepspeed(self, model): def prediction_step( self, model: nn.Module, - inputs: dict[str, Union[torch.Tensor, Any]], + inputs: dict[str, torch.Tensor | Any], prediction_loss_only: bool, - ignore_keys: Optional[list[str]] = None, - ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + ignore_keys: list[str] | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: """ The only change to this function is calling the Trainer's compute_loss, as it's often overridden by unlearning methods, and we want to maintain the Trainer's evaluation setup. """ @@ -93,12 +82,10 @@ def prediction_step( # For CLIP-like models capable of returning loss values. # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` # is `True` in `model.forward`. 
- return_loss = inputs.get("return_loss", None) + return_loss = inputs.get("return_loss") if return_loss is None: return_loss = self.can_return_loss - loss_without_labels = ( - True if len(self.label_names) == 0 and return_loss else False - ) + loss_without_labels = len(self.label_names) == 0 and return_loss inputs = self._prepare_inputs(inputs) if ignore_keys is None: @@ -148,9 +135,18 @@ def prediction_step( else: if has_labels or loss_without_labels: with self.compute_loss_context_manager(): - ### Call compute_loss of super class since overridden compute_loss is not applicable to eval_dataset. + num_items_in_batch = self._get_num_items_in_batch( + [inputs], self.args.device + ) + # NOTE: Call compute_loss of the super class, since the overridden compute_loss is not applicable to eval_dataset. + # Deliberately not used (it would dispatch to the override): loss, outputs = self.compute_loss( + # model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + # ) loss, outputs = super().compute_loss( - model, inputs, return_outputs=True + model, + inputs, + return_outputs=True, + num_items_in_batch=num_items_in_batch, ) loss = loss.detach().mean() @@ -172,9 +168,6 @@ def prediction_step( ) else: logits = outputs - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[self.args.past_index - 1] if prediction_loss_only: return (loss, None, None) diff --git a/tests/prediction_step_regression.py b/tests/prediction_step_regression.py new file mode 100644 index 000000000..bd0319df4 --- /dev/null +++ b/tests/prediction_step_regression.py @@ -0,0 +1,120 @@ +"""Regression check for UnlearnTrainer.prediction_step. 
+ +Run from repo root: + python tests/prediction_step_regression.py +""" + +import sys +from pathlib import Path + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) +from trainer.unlearn.npo import NPO # noqa: E402 + + +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +SEED = 0 + + +def main(): + torch.manual_seed(SEED) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, torch_dtype=torch.float32, attn_implementation="sdpa" + ) + model.eval() + + args = TrainingArguments( + per_device_eval_batch_size=2, + report_to=[], + ) + + # Use NPO so we exercise a trainer with an overridden compute_loss. + # prediction_step is expected to bypass it and use the base causal-LM loss. + trainer = NPO( + model=model, + args=args, + processing_class=tokenizer, + ) + + text = "The capital of France is Paris." + enc = tokenizer([text, text], return_tensors="pt", padding=True) + inputs = { + "input_ids": enc["input_ids"], + "attention_mask": enc["attention_mask"], + "labels": enc["input_ids"].clone(), + } + + loss, logits, labels = trainer.prediction_step( + model, inputs, prediction_loss_only=False + ) + + print("=== prediction_step output ===") + print(f"loss: {loss}") + print(f"logits shape: {tuple(logits.shape) if logits is not None else None}") + if logits is not None: + print(f"logits[0, 0, :8]: {logits[0, 0, :8].tolist()}") + print(f"logits sum: {logits.sum().item():.6f}") + print(f"labels shape: {tuple(labels.shape) if labels is not None else None}") + + # Baseline captured on upstream (transformers pre-5.x). 
+ expected_logits_shape = (2, 7, 151936) + expected_logits_head = [ + 1.8752843141555786, + 0.16622018814086914, + -1.0266990661621094, + 0.3476898670196533, + 1.5837609767913818, + -4.199005126953125, + -1.6311265230178833, + 2.0707736015319824, + ] + expected_labels_shape = (2, 7) + + # Loss baseline captured on transformers 5.5.4 + expected_loss = 2.42057466506958 + + assert abs(loss.item() - expected_loss) < 1e-3, (loss.item(), expected_loss) + assert tuple(logits.shape) == expected_logits_shape + assert tuple(labels.shape) == expected_labels_shape + head = logits[0, 0, :8].tolist() + for got, exp in zip(head, expected_logits_head): + assert abs(got - exp) < 1e-3, (got, exp) + + # Second test: verify prediction_step's loss matches the standard causal-LM + # loss computed directly from the model. This confirms prediction_step is + # bypassing NPO's overridden compute_loss and using the base loss. + device = next(model.parameters()).device + with torch.no_grad(): + out = model( + input_ids=inputs["input_ids"].to(device), + attention_mask=inputs["attention_mask"].to(device), + ) + shift_logits = out.logits[:, :-1, :].contiguous() + shift_labels = inputs["labels"][:, 1:].to(device).contiguous() + # transformers 5.x normalizes by num_items_in_batch (count of non-ignored + # labels in the full inputs), not by the number of shifted positions. + num_items = (inputs["labels"].to(device) != -100).sum() + ce_sum = torch.nn.functional.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ignore_index=-100, + reduction="sum", + ) + manual_loss = ce_sum / num_items + print(f"manual base loss: {manual_loss.item()}") + print(f"prediction_step loss: {loss.item()}") + assert abs(loss.item() - manual_loss.item()) < 1e-4, ( + loss.item(), manual_loss.item() + ) + print("OK") + + +if __name__ == "__main__": + main()