Add T3-Unlearning method#196
Open
jacob-block wants to merge 1 commit into
Open
Conversation
There was a problem hiding this comment.
Pull request overview
Adds the T3-Unlearning method (Temper-Then-Tilt) by introducing a T3 model wrapper, a dedicated unlearning trainer, and classifier-focused evaluation metrics, plus associated configs/docs/community scripts.
Changes:
- Added
T3CausalLMwrapper model with a trainable guidance head and custom generation logic. - Added
T3unlearning trainer that trains the classifier head using retain/forget token labels. - Registered new trainer/model/metrics and added configuration + documentation/community run artifacts.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| src/trainer/unlearn/t3.py | New T3 unlearning trainer and classifier-loss computation. |
| src/trainer/init.py | Registers the new T3 trainer. |
| src/model/t3.py | New T3CausalLM wrapper model + config + generation utilities. |
| src/model/init.py | Registers the new T3CausalLM model handler. |
| src/evals/metrics/t3.py | Adds T3-specific classifier accuracy/loss metrics. |
| src/evals/metrics/init.py | Registers the new T3 metrics. |
| docs/links.md | Adds reference links for the T3-Unlearning paper/code. |
| configs/trainer/T3.yaml | Adds a default Hydra trainer config for T3. |
| community/methods/T3/run.sh | Adds a community sweep script for running T3 locally. |
| community/methods/T3/README.md | Adds a community README with setup/citation info. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| class T3CausalLMConfig(PretrainedConfig): | ||
| model_type = "t3_causal_lm" | ||
|
|
||
| def __init__(self, guidance_kwargs=None, pooling="last", pool_temp=None, extraction_layer=-1, guidance_scale=1, base_temp=1, **kwargs): |
Comment on lines
+292
to
+304
| @classmethod | ||
| def from_pretrained_base(cls, pretrained_model_name_or_path: str, *args, guidance_kwargs=None, pooling="last", pool_temp=None, extraction_layer=-1, guidance_scale=1, base_temp=1, **kwargs): | ||
| base_lm = super(T3CausalLM, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs) | ||
| return cls.from_pretrained_base_obj( | ||
| base_lm=base_lm, | ||
| guidance_kwargs=guidance_kwargs, | ||
| pooling=pooling, | ||
| pool_temp=pool_temp, | ||
| extraction_layer=extraction_layer, | ||
| guidance_scale=guidance_scale, | ||
| base_temp=base_temp | ||
| ) | ||
|
|
Comment on lines
+234
to
+236
| except Exception as e: | ||
| print(f"Failed to load config.json from {pretrained_model_name_or_path} due to error {e}") | ||
|
|
| self.lm_loss = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX) | ||
| logger.info( | ||
| f"Initialized a T3CausalLM model:\n" | ||
| f"base_lm: {self.base_lm.model.config._name_or_path}\n" |
Comment on lines
+627
to
+651
| else: | ||
| _ = kwargs.pop("labels", None) | ||
| kwargs["output_hidden_states"] = True | ||
| kwargs["output_attentions"] = False | ||
| base_outputs = self.base_lm(input_ids=input_ids, attention_mask=attention_mask, **kwargs) | ||
|
|
||
| # Adjust logits using the classifier guidance | ||
| extracted_states = base_outputs.hidden_states[self.extraction_layer] # (batch, seq_len, hidden_size) | ||
| pooled_states = self.pooling_fn(extracted_states, attention_mask) | ||
|
|
||
| classifier_logits = self.guidance_head(pooled_states) # (batch, seq_len, vocab) | ||
|
|
||
| with torch.no_grad(): | ||
| guided_logits = self._guide_logits(base_outputs.logits, classifier_logits) | ||
|
|
||
| outputs = T3CausalLMOutputWithPast( | ||
| loss=None, | ||
| logits=guided_logits, | ||
| past_key_values = base_outputs.past_key_values, | ||
| hidden_states = base_outputs.hidden_states, | ||
| attentions = base_outputs.attentions, | ||
| base_logits = base_outputs.logits, | ||
| classifier_logits = classifier_logits | ||
| ) | ||
| return outputs |
Comment on lines
+39
to
+45
| self.pooling = pooling | ||
| self.pool_temp = pool_temp | ||
| if self.pool_temp is not None and self.pooling != "attn": | ||
| raise RuntimeError(f"Attempting to set pool_temp for pooling function {pooling}, but pool_temp is only supported for attn pooling.") | ||
| self.extraction_layer = extraction_layer | ||
| self.guidance_scale=guidance_scale | ||
| self.base_temp=base_temp |
Comment on lines
+107
to
+121
| loss_sum = 0.0 | ||
| total_valid = 0 | ||
|
|
||
| for logits, labels in ( | ||
| (retain_classifier_logits, retain_classifier_labels), | ||
| (forget_classifier_logits, forget_classifier_labels), | ||
| ): | ||
| valid = labels != -100 | ||
| if valid.any(): | ||
| y = labels[valid].to(logits.dtype) | ||
| x = logits[valid] | ||
| loss_sum = loss_sum + F.binary_cross_entropy_with_logits(x, y, reduction="sum") | ||
| total_valid += valid.sum() | ||
|
|
||
| return loss_sum / total_valid.clamp_min(1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds the T3-Unlearning method (Temper-Then-Tilt Unlearning: Principled Unlearning for Generative Models through Tempering and Classifier Guidance, ICML 2026).