5 changes: 5 additions & 0 deletions .gitignore
@@ -3,6 +3,9 @@
*.pyc
*.png
*.log
.coverage
*.prof
memory_snapshot.html

# Folders
reference/
@@ -12,6 +15,8 @@ local/
output/
wandb/
build/
.pytest_cache/
__pycache__/

test/
data/
142 changes: 142 additions & 0 deletions CLAUDE.md
@@ -0,0 +1,142 @@
# CLAUDE.md

Guidance for Claude Code (claude.ai/code) when working in this directory.

## What this is

`tah-release` is the cleaned, single-target version of public TaH (forked from
[thu-nics/TaH](https://github.com/thu-nics/TaH)). Its only supported model is
the released [`nics-efc/TaH-plus-1.7B`](https://huggingface.co/nics-efc/TaH-plus-1.7B).
The user's *other* HR2R fork (a research branch with `Qwen3MLPIterDecider`,
`Qwen3MLPUpdater`, `CombinedLoss`, `iter_attention_mode='causal'`,
`weighted_hidden_method='stop_prob'`) is a separate codebase and is
**not** loadable here — those classes don't exist in this package.

## Architecture (two-sentence version)

`TaHForCausalLM` (`tah/model/tah_model.py`) wraps a HF causal LM (Qwen3) so
that, on each forward pass, every token first runs the base model
(iter_depth=0); tokens for which the iter decider votes "continue" then
re-run with LoRA enabled (iter_depth >= 1), accumulating their logits via a
residual additive update. Per-(layer, iteration) KV is stored in a single
`TaHCache` (`tah/model/causal_cache.py`) so future iterations can causally
see prior ones without disturbing iter-0.
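
The control flow described above can be illustrated with a toy, pure-Python sketch. It is not the code in `tah/model/tah_model.py`: `selective_iterate`, `should_continue`, and `refine` are hypothetical names, plain lists stand in for tensors, the decider is fed logits purely for illustration, and the KV cache is left out entirely.

```python
# Toy sketch of selective latent iteration. Every name here is hypothetical;
# this is not the code in tah/model/tah_model.py.
from typing import Callable, List, Tuple

Vector = List[float]


def selective_iterate(
    base_logits: List[Vector],
    should_continue: Callable[[Vector, int], bool],
    refine: Callable[[Vector, int], Vector],
    max_iter: int = 2,
) -> Tuple[List[Vector], List[int]]:
    """Return final per-token logits and how many passes each token used."""
    logits = [list(v) for v in base_logits]  # iter_depth = 0: every token runs once
    passes = [1] * len(logits)
    active = list(range(len(logits)))
    for depth in range(1, max_iter):
        # Only tokens the decider votes "continue" for go one level deeper.
        active = [i for i in active if should_continue(logits[i], depth)]
        if not active:
            break
        for i in active:
            delta = refine(logits[i], depth)  # stand-in for the LoRA-enabled pass
            # Residual additive update of the token's logits.
            logits[i] = [a + d for a, d in zip(logits[i], delta)]
            passes[i] += 1
    return logits, passes


# Two tokens: the first has a clear margin, the second is a "hard" token.
base = [[2.0, 0.0], [0.1, 0.0]]
final, passes = selective_iterate(
    base,
    should_continue=lambda v, depth: abs(v[0] - v[1]) < 1.0,
    refine=lambda v, depth: [0.5, -0.5],
)
print(passes)
```

In this toy run only the second token takes a second pass; the first token keeps its iter-0 logits unchanged.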

## Key invariants

* **Single-impl interfaces are inlined.** input_updater / output_updater /
iter_label_generator / adapter / iter_attention_mode all have exactly one
implementation, called directly inside `TaHForCausalLM.forward`. Don't
reintroduce a registry for these — the simplification was deliberate.

* **Multi-impl interfaces stay modular.** `iter_decider` (`IterLabelDecider`
for step-1 SFT, `MLPIterDecider` for step-2 + eval) and `loss`
(`NextTokenPredLoss`, `IterDeciderLoss`) keep their own modules with
`_BY_NAME` dicts for dispatch. No registries.

* **The wrapper exposes a minimal contract.** Forward signature is
`(input_ids, attention_mask?, position_ids?, past_key_values?, labels?,
iter_count_labels?, use_cache?, new_sequence?)`. Public TaH had several
more args (`iter_count`, `output_attentions`, `output_hidden_states`) that
were either unused or unsupported; these are now either removed or
rejected by assertions.

* **Persistence layout.** A saved TaH checkpoint must contain:
- `tah_config.json` (config; type/dtype objects round-tripped via
`type_to_dict_string` / `dict_string_to_type`)
- `iter_decider.bin` (pickled `{class, init_args, state_dict}`)
- `lora/` (PEFT adapter dir)
- `model.safetensors` (base model with cleaned keys: no `.base_layer`
PEFT prefix, no `lora_*` weights — those live in `lora/`)
Downstream consumers (this repo's eval driver, `minisgl-tah` server's
`TaHQwen3ForCausalLM`) all rely on this layout. **Don't change without
also updating those consumers.**
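
A quick way to catch a broken checkpoint before handing it to a downstream consumer is to check for the four entries listed above. The helper below is a minimal sketch: `missing_checkpoint_entries` and `REQUIRED_ENTRIES` are hypothetical names that do not exist in this repo, and the check only looks at presence, not at file contents.

```python
# Hypothetical helper; not part of the tah package.
from pathlib import Path
from typing import List

# Entries named in the persistence layout; a trailing "/" marks a directory.
REQUIRED_ENTRIES = (
    "tah_config.json",
    "iter_decider.bin",
    "lora/",
    "model.safetensors",
)


def missing_checkpoint_entries(checkpoint_dir: str) -> List[str]:
    """Return the required entries that are absent from checkpoint_dir."""
    root = Path(checkpoint_dir)
    missing = []
    for entry in REQUIRED_ENTRIES:
        path = root / entry.rstrip("/")
        present = path.is_dir() if entry.endswith("/") else path.is_file()
        if not present:
            missing.append(entry)
    return missing
```

An empty return value means the directory has the expected shape; anything else names what a consumer would fail to find.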

## Common Commands

### Install
```bash
conda activate tah
uv pip install -e ".[dev,training,evaluation]"
```

### Tests + benchmarks
```bash
pytest tests/ # 11 component + wrapper + roundtrip tests
pytest tests/test_<name>.py -v # one file
TAH_TEST_DEVICE=cpu pytest tests/ # run on CPU (skip needs no flag)
python tests/bench.py components # per-helper microbench (B200 baseline in README)
python tests/bench.py e2e # forward + generate on TaH-plus-1.7B
```
See `tests/README.md` for the snapshot/baseline harness details.

### Training (3-stage)
```bash
# Step 0
python script/preparation/label.py --num_gpu 8 \
--dataset_path <jsonl> --test_model_list <hf-id> --output_path <out>

# Steps 1 + 2
python -m accelerate.commands.launch \
--config_file ./script/recipes/accelerate_configs/zero2.yaml \
--num_processes 8 \
./script/train/SFT_TaH.py \
--config ./script/recipes/qwen3_1.7/sft_tah_step{1,2}.yaml
```

### Evaluation (3 backends)
```bash
python script/evaluation/eval.py \
--eval_config ./script/recipes/qwen3_1.7/eval_tah.yaml \
--model_path nics-efc/TaH-plus-1.7B \
--dataset_name gsm8k --backend {tah,hf,sglang} \
--job_nums 8 --tp_size_per_job 1
```

### Quick inference demo
```bash
python script/playground/inference_example.py
```

## Layout

```
tah/
├── __init__.py # re-exports TaHForCausalLM, TaHConfig, TaHCache, …
├── model/
│ ├── tah_model.py # TaHForCausalLM + inlined slot helpers
│ ├── iter_decider.py # IterLabelDecider, MLPIterDecider, ITER_DECIDER_BY_NAME
│ ├── loss.py # NextTokenPredLoss, IterDeciderLoss, LOSS_BY_NAME
│ ├── causal_cache.py # TaHCache: per-(layer, iter) KV with up-to-iter views
│ ├── tah_config.py # @dataclass TaHConfig
│ └── utils.py # generation helper + IterCountColors + freeze/seed/sampling helpers
├── train/ # HF Trainer subclass + collator + iter-aware callback
├── evaluate/
│ ├── datasets.py # benchmark loading + standardisation
│ ├── backends.py # sglang / hf / tah model + inference fn
│ ├── jobs.py # job-sharded runner + result aggregation
│ ├── matheval.py # math benchmark graders (math_verify)
│ └── codeeval.py # humaneval / mbpp via evalplus
└── utils/data_prepare.py # SFT preprocessing
script/
├── preparation/ # download.py, label.py, prune.py, filter_split.py
├── train/SFT_TaH.py # YAML → wrapper → HF Trainer.train()
├── evaluation/eval.py # CLI wrapper over tah.evaluate.allocate_gpus_and_run_jobs
├── playground/inference_example.py
└── recipes/ # qwen3_{0.6,1.7}/sft_tah_step{1,2}.yaml + eval_tah.yaml
tests/ # _harness.py + per-component test_*.py + baselines/ (gitignored)
```

## Conventions

- All `tah.model.*` modules are designed to be importable in isolation with
small synthetic shapes — that's what `tests/conftest.py` exercises.
- `tah/model/tah_model.py` is the only place that mutates the wrapper's
internal state. New iter-loop behaviour goes there, not into the
iter_decider or loss classes.
- `iter_decider_kwargs.dtype` may be a `torch.dtype` — round-trip through the
``_config_to_serialisable`` / ``_config_from_serialisable`` helpers in
`tah/model/tah_model.py` (called from save_pretrained / from_pretrained).
- The wrapper assumes the base model's hidden size matches
`iter_decider_kwargs.hidden_states_size`. Recipes set this to 1024 for
Qwen3-0.6B and 2048 for Qwen3-1.7B; mismatches show up as a Linear
shape error.
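
The hidden-size mismatch can be surfaced earlier with an explicit check before the wrapper is built. This is a sketch under assumptions: `check_decider_hidden_size` is a hypothetical helper, and it treats `iter_decider_kwargs` as a plain mapping with a `hidden_states_size` key, as the recipes suggest.

```python
# Hypothetical pre-flight check; not part of the tah package.
from typing import Any, Mapping


def check_decider_hidden_size(
    base_hidden_size: int, iter_decider_kwargs: Mapping[str, Any]
) -> None:
    """Raise a readable error instead of a late Linear shape error."""
    configured = iter_decider_kwargs.get("hidden_states_size")
    if configured != base_hidden_size:
        raise ValueError(
            f"iter_decider_kwargs.hidden_states_size is {configured!r}, "
            f"but the base model's hidden size is {base_hidden_size}"
        )


# Qwen3-1.7B recipes use 2048; this call passes silently.
check_decider_hidden_size(2048, {"hidden_states_size": 2048})
```

Passing the 0.6B value (1024) together with a 1.7B base model would raise a `ValueError` that names both numbers.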
122 changes: 76 additions & 46 deletions README.md
@@ -23,7 +23,7 @@ Feel free to star the repo or cite the paper if you find it interesting.
@article{fu2025tah,
title={Think-at-Hard: Selective Latent Iterations to Improve Reasoning Language Models},
author={Tianyu Fu and Yichen You and Zekai Chen and Guohao Dai and Huazhong Yang and Yu Wang},
journal={arXiv preprint arXiv:2510.08577},
journal={arXiv preprint arXiv:2511.08577},
year={2025},
}
```
@@ -55,10 +55,17 @@ pip install -e ".[training,evaluation]"

For code generation evaluation, install [evalplus](https://github.com/evalplus/evalplus)

> **Note:** If you ``git pull`` and the top-level package layout changes
> (e.g. ``__init__.py`` is added or removed), re-run ``pip install -e .``.
> The editable install caches the layout in
> ``site-packages/__editable___tah_*_finder.py``, and a stale cache will
> silently drop the re-exports in ``tah/__init__.py``.

## Run an example for TaH

```bash
python script/playground/inference_example.py
python script/playground/inference_example.py # quick demo (~1 min)
python script/playground/inference_example.py --max-new-tokens 16384 # full reasoning chain
```

This script demonstrates TaH's selective latent iteration mechanism, with color-coded output showing the iteration count for each token.
@@ -82,22 +89,37 @@ Key parameters:
- `--model_path`: Path to the model
- `--dataset_name`: Dataset name (supports gsm8k, math500, aime24, etc. Detailed configs can be found in `tah/evaluate/eval_configs/dataset_configs.json`)
- `--backend`: Inference backend (`tah` for TaH)
- `--job_nums`: Number of parallel jobs
- `--job_nums`: Number of parallel jobs (one job pins `tp_size_per_job` GPUs)
- `--tp_size_per_job`: Tensor parallel size per job
- `--data_range N` / `--data_range start end`: subset slice — handy for smoke tests
- `--data_ids gsm8k_0,gsm8k_5`: run only specific problem ids
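
The two `--data_range` forms could be mapped to a slice roughly as follows. This is a sketch, not the parsing code in `script/evaluation/eval.py`: it assumes that a single `N` means "the first N problems" and that `start end` is a half-open range, neither of which is confirmed here, and `parse_data_range` is a hypothetical name.

```python
# Hypothetical sketch of --data_range handling; the real eval.py may differ.
import argparse
from typing import Sequence


def parse_data_range(values: Sequence[int]) -> slice:
    """Turn [N] or [start, end] into a slice (assumed semantics, see above)."""
    if len(values) == 1:
        return slice(0, values[0])  # assumption: N selects the first N items
    if len(values) == 2:
        start, end = values
        if start > end:
            raise ValueError(f"start ({start}) must not exceed end ({end})")
        return slice(start, end)  # assumption: end is exclusive
    raise ValueError("--data_range takes one or two integers")


parser = argparse.ArgumentParser()
parser.add_argument("--data_range", type=int, nargs="+", default=None)
args = parser.parse_args(["--data_range", "5"])

# Placeholder problem ids in the same style as the --data_ids example.
problems = [f"gsm8k_{i}" for i in range(20)]
subset = problems[parse_data_range(args.data_range)]
print(subset)
```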

### Evaluate standard baseline model
#### Single-GPU smoke
The default recipe targets 8 GPUs (`--job_nums 8`). To sanity-check the pipeline on
one GPU in a couple of minutes, slice the dataset and shrink `max_new_tokens`:
```bash
python script/evaluation/eval.py \
--eval_config ./script/recipes/qwen3_1.7/eval_base.yaml \
--model_path nics-efc/Standard-1.7B \
--dataset_name gsm8k \
--backend hf \
--job_nums 8 \
--tp_size_per_job 1
# clone the recipe and shrink generation length
sed 's/max_new_tokens: 4096/max_new_tokens: 512/' \
script/recipes/qwen3_1.7/eval_tah.yaml > /tmp/eval_tah_smoke.yaml

CUDA_VISIBLE_DEVICES=0 python script/evaluation/eval.py \
--eval_config /tmp/eval_tah_smoke.yaml \
--model_path nics-efc/TaH-plus-1.7B \
--dataset_name gsm8k --backend tah \
--job_nums 1 --tp_size_per_job 1 \
--data_range 5 \
--output_dir /tmp/tah_eval_smoke
```
The TaH backend is a token-by-token Python loop intended for research; for serving
throughput, use `--backend sglang` or the dedicated `minisgl-tah` server.

### Evaluate with a different backend

Similar to TaH evaluation, but using:
- `--backend hf` or `--backend sglang`
The same `script/evaluation/eval.py` accepts `--backend hf` (vanilla
`AutoModelForCausalLM.generate` — useful for non-TaH baselines) or
`--backend sglang` (sgl Engine for high-throughput serving). All three
backends share the same job-sharded driver under
`tah/evaluate/jobs.py:allocate_gpus_and_run_jobs`.

## Train your own TaH model

@@ -110,7 +132,6 @@ Training a TaH model consists of three stages:
Use a reference model to generate hard token labels for the training and validation data:

```bash
### step 0
# download the default subset of OpenR1-Math-220k
python script/preparation/download.py
# filter and split
@@ -135,7 +156,6 @@ python script/preparation/label.py \
For the TaH version, prune one layer from the base model to match the parameter count of the standard baseline (skip this step for TaH+ version):

```bash
### step 0
python script/preparation/prune.py \
--model Qwen/Qwen3-1.7B-Base \
--dataset ./data/processed_data/openr1-math/1_7/eval \
@@ -148,7 +168,6 @@ python script/preparation/prune.py \
The first stage uses fixed iteration labels for training:

```bash
### step 1
python -m accelerate.commands.launch \
--config_file ./script/recipes/accelerate_configs/zero2.yaml \
--num_processes 8 \
@@ -157,20 +176,21 @@ python -m accelerate.commands.launch \
```

Key configurations in Step1 (`sft_tah_step1.yaml`):
- `max_iter: 2`: Maximum number of iterations
- `iter_decider: "FixedLabelIterDecider"`: Use fixed labels to decide iterations
- `iter_label_generator: "FixedIterLabelGenerator"`: Generate labels from mismatch field in data
- `input_updater: "AdditiveUpdater"`: Use additive updater for input updates
- `adapter: "lora"`: Use LoRA adapter for deeper iteration
- `train_loss: "NextTokenPredLoss"`: Next token prediction loss
- `max_iter: 2` — maximum number of iterations.
- `iter_decider: "IterLabelDecider"` — continue iff the per-token oracle
``iter_count_labels`` (derived from ``mismatch``) say so. Used to teach
the LoRA adapter on tokens marked "hard" by the labeller.
- `adapter: "lora"` — only LoRA is supported in tah-release.
- `train_loss: "NextTokenPredLoss"` — standard causal-LM cross-entropy.

Single-implementation hooks (input/output updaters, iter labels, adapter) are inlined into the wrapper — only `iter_decider` and `train_loss` are config-selectable.
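
Config-selectable components of this kind are typically resolved through a plain name-to-class dict rather than a registry. The sketch below illustrates that pattern only: the two classes are empty stand-ins that share names with the real deciders, and `build_iter_decider` is a hypothetical helper, not a function from `tah/model/iter_decider.py`.

```python
# Illustrative dict-based dispatch; the classes are placeholders, not the real deciders.
from typing import Any, Dict, Type


class IterLabelDecider:
    """Stand-in for the step-1 decider driven by oracle labels."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class MLPIterDecider:
    """Stand-in for the learned decider used in step 2 and evaluation."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


ITER_DECIDER_BY_NAME: Dict[str, Type] = {
    cls.__name__: cls for cls in (IterLabelDecider, MLPIterDecider)
}


def build_iter_decider(name: str, **kwargs: Any):
    """Look up a decider class by its config name and instantiate it."""
    try:
        cls = ITER_DECIDER_BY_NAME[name]
    except KeyError:
        known = ", ".join(sorted(ITER_DECIDER_BY_NAME))
        raise ValueError(
            f"unknown iter_decider {name!r}; expected one of: {known}"
        ) from None
    return cls(**kwargs)


decider = build_iter_decider("MLPIterDecider", hidden_states_size=2048)
print(type(decider).__name__)
```

With this shape, a stale config value such as the old `"FixedLabelIterDecider"` fails with a message that lists the valid names.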

### Step2: Train Iteration Decider

The second stage trains the iteration decider:


```bash
### step 2
python -m accelerate.commands.launch \
--config_file ./script/recipes/accelerate_configs/zero2.yaml \
--num_processes 8 \
@@ -192,32 +212,42 @@ After two-stage training, the model can automatically decide when to perform lat

```
TaH/
├── tah/ # Core package
│ ├── model/ # Core model components
│ ├── train/ # Training components
│ ├── evaluate/ # Evaluation utilities
│ └── utils/ # General utilities
├── bash/ # Bash scripts for training and evaluation
├── script/ # Execution scripts
│ ├── analysis/ # Analysis scripts
│ ├── evaluation/ # Evaluation scripts
│ ├── preparation/ # Preparation for training
│ │ ├── label.py # Data labeling (generate mismatch labels)
│ │ └── prune.py # Model pruning
│ ├── playground/ # Some examples
│ └── recipes/ # Configuration files
│ ├── qwen3_0.6/ # Qwen3-0.6B-Base configs
│ ├── qwen3_1.7/ # Qwen3-1.7B-Base configs
│ └── accelerate_configs/ # Distributed training configs
└── pyproject.toml # Project configuration
├── tah/
│ ├── model/ # core model
│ │ ├── tah_model.py # TaHForCausalLM wrapper + inlined slot helpers
│ │ ├── iter_decider.py # IterLabelDecider, MLPIterDecider, _BY_NAME
│ │ ├── loss.py # NextTokenPredLoss, IterDeciderLoss, _BY_NAME
│ │ ├── causal_cache.py # TaHCache: per-(layer, iter) KV
│ │ ├── tah_config.py # @dataclass TaHConfig
│ │ └── utils.py # generation helper, IterCountColors
│ ├── train/ # HF Trainer subclass + collator + iter-aware callback
│ ├── evaluate/ # multi-backend eval driver
│ │ ├── datasets.py # benchmark loading + standardisation
│ │ ├── backends.py # sglang / hf / tah model + inference fn
│ │ ├── jobs.py # job-sharded runner + result aggregation
│ │ ├── matheval.py # math benchmark graders (math_verify)
│ │ └── codeeval.py # humaneval / mbpp via evalplus
│ └── utils/ # SFT preprocessing
├── script/
│ ├── preparation/ # download.py, label.py, prune.py, filter_split.py
│ ├── train/SFT_TaH.py # SFT entrypoint
│ ├── evaluation/eval.py # eval CLI entrypoint
│ ├── playground/ # inference demo
│ └── recipes/qwen3_{0.6,1.7}/ # training + eval YAML recipes
├── tests/ # per-component acc/speed tests
└── pyproject.toml
```

## Future Work
## Tests + benchmarks

```bash
pytest tests/ -q # 21 acc + roundtrip tests in ~30s on a B200
python tests/bench.py components # microbenchmarks for the wrapper's hot helpers
python tests/bench.py e2e # forward + 32-token generate on TaH-plus-1.7B
python tests/bench_compile.py # one-off torch.compile vs eager experiment
```

- [ ] Support more inference backends (e.g., SGLang)
- [ ] Optimize iteration decision strategies
- [ ] Integrate TaH with online distillation or RL
- [ ] Support training for larger models
See [`tests/README.md`](tests/README.md) for component-level baselines and CPU-mode (`TAH_TEST_DEVICE=cpu`) instructions.

## Related Projects

12 changes: 0 additions & 12 deletions bash/eval_base.sh

This file was deleted.

16 changes: 0 additions & 16 deletions bash/eval_oracle.sh

This file was deleted.
