Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
FROM nvidia/cuda:12.8.0-devel-ubuntu22.04

# Install Python 3.11 from deadsnakes (Ubuntu 22.04 ships 3.11.0rc1, which has
# inspect bugs that break triton's JIT source parsing).
# Bootstrap pip via get-pip.py so we get upstream pip/setuptools (Ubuntu's
# python3-pip carries Debian patches that break pyproject builds).
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y software-properties-common curl ca-certificates \
&& add-apt-repository ppa:deadsnakes/ppa -y \
&& apt-get update && apt-get install -y \
python3.11 python3.11-dev python3.11-distutils git \
&& rm -rf /var/lib/apt/lists/* \
&& ln -sf /usr/bin/python3.11 /usr/bin/python \
&& ln -sf /usr/bin/python3.11 /usr/bin/python3 \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11

# Install dependencies from requirements
COPY requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt

# Install lm-eval and hf_transfer
RUN pip install --no-cache-dir lm-eval==0.4.11

# Install flash-attn (pre-built wheel for cu128 + torch2.9, avoids slow compilation)
RUN pip install --no-cache-dir \
"https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.16/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"

WORKDIR /workspace
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ We provide several variants for each of the components in the unlearning pipelin
conda create -n unlearning python=3.11
conda activate unlearning
pip install ".[lm-eval]"
pip install --no-build-isolation flash-attn==2.6.3
pip install --no-build-isolation flash-attn==2.8.3
# Or to avoid building flash-attn:
pip install "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.7.16/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"
Comment on lines +119 to +120

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, we could remove flash-attn as proposed in #190


# Data setup
python setup_data.py --eval # saves/eval now contains evaluation results of the uploaded models
Expand All @@ -125,6 +127,8 @@ python setup_data.py --eval # saves/eval now contains evaluation results of the
# python setup_data.py --help
```

We also provide a [Docker image](https://hub.docker.com/r/filyp/open-unlearning), with this environment already installed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not specify it, in case someone forgets to update readme. The hub page already provides that recommended tag to pull.


---

### 🔄 Updated TOFU benchmark
Expand Down
14 changes: 14 additions & 0 deletions configs/model/Qwen2.5-0.5B-Instruct.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
model_args:
pretrained_model_name_or_path: "Qwen/Qwen2.5-0.5B-Instruct"
attn_implementation: 'flash_attention_2'
torch_dtype: bfloat16
tokenizer_args:
pretrained_model_name_or_path: "Qwen/Qwen2.5-0.5B-Instruct"
template_args:
apply_chat_template: true
system_prompt: "You are a helpful assistant."
system_prompt_with_special_tokens: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
user_start_tag: "<|im_start|>user\n"
user_end_tag: "<|im_end|>\n"
asst_start_tag: "<|im_start|>assistant\n"
asst_end_tag: "<|im_end|>\n"
14 changes: 8 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
huggingface-hub==0.36.0
transformers==4.51.3
hf-xet==1.2.0
huggingface-hub==1.7.2
transformers==5.5.4
hf-xet==1.4.2
numpy==2.2.3
hydra-core==1.3
hydra_colorlog==1.2.0
torch==2.4.1
torch==2.9.1
datasets==3.0.1
accelerate==0.34.2
bitsandbytes==0.44.1
accelerate==1.13.0
bitsandbytes==0.49.2
rouge-score==0.1.2
scipy==1.14.1
tensorboard==2.18.0
scikit-learn==1.5.2
deepspeed==0.15.4
wandb==0.21.4
# for Qwen3.5 speedup:
flash-linear-attention==0.4.2
5 changes: 5 additions & 0 deletions src/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ def preprocess_chat_instance(
prompt_ids = tokenizer.apply_chat_template(
chat[:-1], tokenize=True, add_generation_prompt=True, **date_info
)
# transformers 5.x returns BatchEncoding from apply_chat_template, so transform to list
if not isinstance(chat_ids, list):
chat_ids = chat_ids["input_ids"]
if not isinstance(prompt_ids, list):
prompt_ids = prompt_ids["input_ids"]
else:
wrapped_prompt = ""
system_prompt_with_special_tokens = template_config.get(
Expand Down
49 changes: 21 additions & 28 deletions src/trainer/unlearn/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Optional, Union
from typing import Any

import torch
from torch import nn
from copy import deepcopy
from packaging import version

from accelerate.utils import is_deepspeed_available

from trainer.base import FinetuneTrainer

from transformers.trainer_pt_utils import (
Expand All @@ -15,21 +17,8 @@
is_sagemaker_mp_enabled,
)

from accelerate.utils import (
is_deepspeed_available,
)

if is_sagemaker_mp_enabled():
from smdistributed.modelparallel import __version__ as SMP_VERSION

IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

from transformers.trainer_pt_utils import (
smp_forward_only,
smp_nested_concat,
)
else:
IS_SAGEMAKER_MP_POST_1_10 = False
from transformers.trainer_pt_utils import smp_forward_only, smp_nested_concat

if is_deepspeed_available():
import deepspeed
Expand Down Expand Up @@ -78,10 +67,10 @@ def _prepare_deepspeed(self, model):
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, Union[torch.Tensor, Any]],
inputs: dict[str, torch.Tensor | Any],
Comment thread
filyp marked this conversation as resolved.
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
The only change to this function is calling the Trainer's compute_loss, as it's often overridden by unlearning methods, and we want to maintain the Trainer's evaluation setup.
"""
Expand All @@ -93,12 +82,10 @@ def prediction_step(
# For CLIP-like models capable of returning loss values.
# If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
# is `True` in `model.forward`.
return_loss = inputs.get("return_loss", None)
return_loss = inputs.get("return_loss")
if return_loss is None:
return_loss = self.can_return_loss
loss_without_labels = (
True if len(self.label_names) == 0 and return_loss else False
)
loss_without_labels = len(self.label_names) == 0 and return_loss

inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
Expand Down Expand Up @@ -148,9 +135,18 @@ def prediction_step(
else:
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
### Call compute_loss of super class since overridden compute_loss is not applicable to eval_dataset.
num_items_in_batch = self._get_num_items_in_batch(
[inputs], self.args.device
)
# !!!!!!! Call compute_loss of super class since overridden compute_loss is not applicable to eval_dataset.
# loss, outputs = self.compute_loss(
# model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
# )
Comment thread
filyp marked this conversation as resolved.
loss, outputs = super().compute_loss(
model, inputs, return_outputs=True
model,
inputs,
return_outputs=True,
num_items_in_batch=num_items_in_batch,
)
loss = loss.detach().mean()

Expand All @@ -172,9 +168,6 @@ def prediction_step(
)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index - 1]

if prediction_loss_only:
return (loss, None, None)
Expand Down
120 changes: 120 additions & 0 deletions tests/prediction_step_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Regression check for UnlearnTrainer.prediction_step.

Run from repo root:
python tests/prediction_step_regression.py
"""

import sys
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from trainer.unlearn.npo import NPO # noqa: E402


MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
SEED = 0


def main():
torch.manual_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME, torch_dtype=torch.float32, attn_implementation="sdpa"
)
model.eval()

args = TrainingArguments(
per_device_eval_batch_size=2,
report_to=[],
)
Comment thread
filyp marked this conversation as resolved.

# Use NPO so we exercise a trainer with an overridden compute_loss.
# prediction_step is expected to bypass it and use the base causal-LM loss.
trainer = NPO(
model=model,
args=args,
processing_class=tokenizer,
)

text = "The capital of France is Paris."
enc = tokenizer([text, text], return_tensors="pt", padding=True)
inputs = {
"input_ids": enc["input_ids"],
"attention_mask": enc["attention_mask"],
"labels": enc["input_ids"].clone(),
}

loss, logits, labels = trainer.prediction_step(
model, inputs, prediction_loss_only=False
)

print("=== prediction_step output ===")
print(f"loss: {loss}")
print(f"logits shape: {tuple(logits.shape) if logits is not None else None}")
if logits is not None:
print(f"logits[0, 0, :8]: {logits[0, 0, :8].tolist()}")
print(f"logits sum: {logits.sum().item():.6f}")
print(f"labels shape: {tuple(labels.shape) if labels is not None else None}")

# Baseline captured on upstream (transformers pre-5.x).
expected_logits_shape = (2, 7, 151936)
expected_logits_head = [
1.8752843141555786,
0.16622018814086914,
-1.0266990661621094,
0.3476898670196533,
1.5837609767913818,
-4.199005126953125,
-1.6311265230178833,
2.0707736015319824,
]
expected_labels_shape = (2, 7)

# Loss baseline captured on transformers 5.5.4
expected_loss = 2.42057466506958

assert abs(loss.item() - expected_loss) < 1e-3, (loss.item(), expected_loss)
assert tuple(logits.shape) == expected_logits_shape
assert tuple(labels.shape) == expected_labels_shape
head = logits[0, 0, :8].tolist()
for got, exp in zip(head, expected_logits_head):
assert abs(got - exp) < 1e-3, (got, exp)
Comment thread
filyp marked this conversation as resolved.

# Second test: verify prediction_step's loss matches the standard causal-LM
# loss computed directly from the model. This confirms prediction_step is
# bypassing NPO's overridden compute_loss and using the base loss.
device = next(model.parameters()).device
with torch.no_grad():
out = model(
input_ids=inputs["input_ids"].to(device),
attention_mask=inputs["attention_mask"].to(device),
)
shift_logits = out.logits[:, :-1, :].contiguous()
shift_labels = inputs["labels"][:, 1:].to(device).contiguous()
# transformers 5.x normalizes by num_items_in_batch (count of non-ignored
# labels in the full inputs), not by the number of shifted positions.
num_items = (inputs["labels"].to(device) != -100).sum()
ce_sum = torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
reduction="sum",
)
manual_loss = ce_sum / num_items
print(f"manual base loss: {manual_loss.item()}")
print(f"prediction_step loss: {loss.item()}")
assert abs(loss.item() - manual_loss.item()) < 1e-4, (
loss.item(), manual_loss.item()
)
print("OK")


if __name__ == "__main__":
main()
Loading