Add T3-Unlearning method #196

Open
jacob-block wants to merge 1 commit into locuslab:main from jacob-block:add-t3-unlearning
Conversation

@jacob-block

Adds the T3-Unlearning method (Temper-Then-Tilt Unlearning: Principled Unlearning for Generative Models through Tempering and Classifier Guidance, ICML 2026).

Copilot AI review requested due to automatic review settings June 19, 2026 16:56

Copilot AI left a comment

Pull request overview

Adds the T3-Unlearning method (Temper-Then-Tilt) by introducing a T3 model wrapper, a dedicated unlearning trainer, and classifier-focused evaluation metrics, plus associated configs/docs/community scripts.

Changes:

  • Added T3CausalLM wrapper model with a trainable guidance head and custom generation logic.
  • Added T3 unlearning trainer that trains the classifier head using retain/forget token labels.
  • Registered new trainer/model/metrics and added configuration + documentation/community run artifacts.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| `src/trainer/unlearn/t3.py` | New T3 unlearning trainer and classifier-loss computation. |
| `src/trainer/__init__.py` | Registers the new T3 trainer. |
| `src/model/t3.py` | New T3CausalLM wrapper model + config + generation utilities. |
| `src/model/__init__.py` | Registers the new T3CausalLM model handler. |
| `src/evals/metrics/t3.py` | Adds T3-specific classifier accuracy/loss metrics. |
| `src/evals/metrics/__init__.py` | Registers the new T3 metrics. |
| `docs/links.md` | Adds reference links for the T3-Unlearning paper/code. |
| `configs/trainer/T3.yaml` | Adds a default Hydra trainer config for T3. |
| `community/methods/T3/run.sh` | Adds a community sweep script for running T3 locally. |
| `community/methods/T3/README.md` | Adds a community README with setup/citation info. |

Comment thread src/model/t3.py
class T3CausalLMConfig(PretrainedConfig):
    model_type = "t3_causal_lm"

    def __init__(self, guidance_kwargs=None, pooling="last", pool_temp=None, extraction_layer=-1, guidance_scale=1, base_temp=1, **kwargs):
Comment thread src/model/t3.py
Comment on lines +292 to +304
@classmethod
def from_pretrained_base(cls, pretrained_model_name_or_path: str, *args, guidance_kwargs=None, pooling="last", pool_temp=None, extraction_layer=-1, guidance_scale=1, base_temp=1, **kwargs):
    base_lm = super(T3CausalLM, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
    return cls.from_pretrained_base_obj(
        base_lm=base_lm,
        guidance_kwargs=guidance_kwargs,
        pooling=pooling,
        pool_temp=pool_temp,
        extraction_layer=extraction_layer,
        guidance_scale=guidance_scale,
        base_temp=base_temp
    )

Comment thread src/model/t3.py
Comment on lines +234 to +236
except Exception as e:
    print(f"Failed to load config.json from {pretrained_model_name_or_path} due to error {e}")
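
The snippet above catches every exception and reports it with `print`. As a minimal sketch of one alternative, assuming the goal is only to read an optional `config.json` from a local directory, the loader below catches just the I/O and JSON-parsing errors and reports them through the `logging` module. The function name `load_config_json` and its return-`None`-on-failure behavior are hypothetical and are not taken from the PR.

```python
import json
import logging
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)


def load_config_json(model_dir: str) -> Optional[dict]:
    """Return the parsed config.json in model_dir, or None if it cannot be read.

    Hypothetical helper for illustration; not part of the PR.
    """
    config_path = Path(model_dir) / "config.json"
    try:
        with config_path.open("r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # Missing file, unreadable file, or malformed JSON: log and fall back.
        logger.warning("Failed to load %s: %s", config_path, e)
        return None
```

Narrowing the `except` clause lets unexpected errors (for example a typo that raises `NameError`) surface instead of being silently swallowed.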

Comment thread src/model/t3.py
self.lm_loss = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
logger.info(
    f"Initialized a T3CausalLM model:\n"
    f"base_lm: {self.base_lm.model.config._name_or_path}\n"
Comment thread src/model/t3.py
Comment on lines +627 to +651
else:
    _ = kwargs.pop("labels", None)
    kwargs["output_hidden_states"] = True
    kwargs["output_attentions"] = False
    base_outputs = self.base_lm(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

    # Adjust logits using the classifier guidance
    extracted_states = base_outputs.hidden_states[self.extraction_layer] # (batch, seq_len, hidden_size)
    pooled_states = self.pooling_fn(extracted_states, attention_mask)

    classifier_logits = self.guidance_head(pooled_states) # (batch, seq_len, vocab)

    with torch.no_grad():
        guided_logits = self._guide_logits(base_outputs.logits, classifier_logits)

    outputs = T3CausalLMOutputWithPast(
        loss=None,
        logits=guided_logits,
        past_key_values = base_outputs.past_key_values,
        hidden_states = base_outputs.hidden_states,
        attentions = base_outputs.attentions,
        base_logits = base_outputs.logits,
        classifier_logits = classifier_logits
    )
    return outputs
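
The body of `_guide_logits` is not shown in this thread. The sketch below is only one plausible reading of "temper then tilt", assuming that tempering divides the base logits by `base_temp` and that tilting subtracts `guidance_scale` times the per-token classifier logits, with a positive classifier logit taken to mean "forget". The function name `guide_logits_sketch`, the sign convention, and the exact functional form are assumptions, not the PR's implementation.

```python
import torch


def guide_logits_sketch(
    base_logits: torch.Tensor,
    classifier_logits: torch.Tensor,
    base_temp: float = 1.0,
    guidance_scale: float = 1.0,
) -> torch.Tensor:
    """Hypothetical temper-then-tilt combination of two (batch, ..., vocab) tensors."""
    # Temper: a temperature above 1 flattens the base next-token distribution.
    tempered = base_logits / base_temp
    # Tilt: push down tokens the classifier scores as "forget"
    # (assumed convention: larger classifier logit = more likely forget).
    return tempered - guidance_scale * classifier_logits


# Usage: token 0 is favored by the base model but flagged by the classifier.
base = torch.tensor([[2.0, 1.0, 0.0]])
clf = torch.tensor([[4.0, 0.0, 0.0]])
guided_probs = torch.softmax(guide_logits_sketch(base, clf), dim=-1)
```

Under these assumptions, setting `guidance_scale=0` and `base_temp=1` recovers the base logits unchanged, which gives a convenient sanity check when evaluating a wrapped model.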
Comment thread src/trainer/unlearn/t3.py
Comment on lines +39 to +45
self.pooling = pooling
self.pool_temp = pool_temp
if self.pool_temp is not None and self.pooling != "attn":
raise RuntimeError(f"Attempting to set pool_temp for pooling function {pooling}, but pool_temp is only supported for attn pooling.")
self.extraction_layer = extraction_layer
self.guidance_scale=guidance_scale
self.base_temp=base_temp
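
The check above rejects `pool_temp` for any pooling mode other than `"attn"` by raising `RuntimeError`. As a small sketch of the same rule expressed with `ValueError`, which is the more conventional exception for an invalid argument value, the dataclass below validates its fields on construction. The class name `PoolingSettings` and the extra positivity check on `pool_temp` are hypothetical additions for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PoolingSettings:
    """Hypothetical container mirroring the pooling/pool_temp rule shown above."""

    pooling: str = "last"
    pool_temp: Optional[float] = None

    def __post_init__(self) -> None:
        if self.pool_temp is not None and self.pooling != "attn":
            raise ValueError(
                f"pool_temp is only supported for attn pooling, got pooling={self.pooling!r}"
            )
        # Assumption: a pooling temperature, when given, should be strictly positive.
        if self.pool_temp is not None and self.pool_temp <= 0:
            raise ValueError(f"pool_temp must be positive, got {self.pool_temp}")
```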
Comment thread src/trainer/unlearn/t3.py
Comment on lines +107 to +121
loss_sum = 0.0
total_valid = 0

for logits, labels in (
    (retain_classifier_logits, retain_classifier_labels),
    (forget_classifier_logits, forget_classifier_labels),
):
    valid = labels != -100
    if valid.any():
        y = labels[valid].to(logits.dtype)
        x = logits[valid]
        loss_sum = loss_sum + F.binary_cross_entropy_with_logits(x, y, reduction="sum")
        total_valid += valid.sum()

return loss_sum / total_valid.clamp_min(1)
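
In the snippet above, `total_valid` starts as the Python integer `0` and only becomes a tensor once a batch contains at least one non-ignored label. If every label is `-100`, `total_valid.clamp_min(1)` is called on an `int` and raises `AttributeError`. The sketch below is one possible way to make that edge case safe, assuming the accumulators can be kept as tensors from the start. The function name `classifier_bce_loss` and its list-of-pairs signature are hypothetical.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100


def classifier_bce_loss(pairs):
    """Mean binary cross-entropy over all non-ignored positions.

    pairs: list of (logits, labels) tensors with matching shapes.
    Hypothetical helper for illustration; not the PR's code.
    """
    first_logits = pairs[0][0]
    # Graph-connected zero so backward() still works when nothing is valid
    # (assumes the logits are finite).
    loss_sum = (first_logits * 0.0).sum()
    # Integer tensor accumulator, so clamp_min is always available.
    total_valid = torch.zeros((), dtype=torch.long, device=first_logits.device)
    for logits, labels in pairs:
        valid = labels != IGNORE_INDEX
        if valid.any():
            targets = labels[valid].to(logits.dtype)
            loss_sum = loss_sum + F.binary_cross_entropy_with_logits(
                logits[valid], targets, reduction="sum"
            )
            total_valid = total_valid + valid.sum()
    return loss_sum / total_valid.clamp_min(1)
```

With this shape, a batch whose labels are all ignored yields a loss of zero with zero gradients rather than an exception.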