From fc46d48484df6d55bd8b7c893088113c4b163e88 Mon Sep 17 00:00:00 2001 From: Tianyu Fu Date: Sat, 25 Apr 2026 22:40:25 -0400 Subject: [PATCH 1/5] [refactor] simplify public TaH codebase: drop registry, modular eval, test+bench harness Squashes 38 iterative commits on the simplify branch into one cohesive cleanup of public TaH (forked from thu-nics/TaH) targeting nics-efc/TaH-plus-1.7B only. Model wrapper (tah/model/): - Drop the by-name registry; inline the single-impl slots (input_updater, output_updater, iter_label_generator, adapter) directly into the wrapper. - Rename recurrent_transformer.py -> tah_model.py; expose iter-loop phases as named helpers (Phase 1-8); extract _resolve_tah_config, _build_loss / _build_iter_decider, _setup_lora, _set_lora_grad_flags helpers. - Vectorise scatter_back; flatten TaHCache dict-of-dicts; trim dead helpers in utils.py and the playground. - IterDeciderLoss: extract _record_threshold_accuracy and _scatter_per_token for readability without changing semantics. - MLPIterDecider: extract _select_hidden_features; thread dtype through _ClassifierBackbone / _ClassifierBlock. - Inline TaHConfig (de)serialiser; round-trip torch.dtype via type_to_dict_string / dict_string_to_type. from_pretrained now resolves HF Hub IDs via snapshot_download and places iter_decider on the model device. Eval driver (tah/evaluate/): - Split eval_unified into datasets.py / backends.py / jobs.py; keep eval_unified.py as a backwards-compat shim. - backends.setup_backend collapsed to a table dispatch. - jobs.run_single_job split into _job_output_dir / _build_prompts / _process_batch with a small unit test that stubs the backend (no GPU/model). - Trim utils.py + matheval.py + codeeval.py; drop tracker integration. Train + scripts: - HF Trainer subclass + collator + iter-aware callback kept; trim dead branches in tah/train and script/train/SFT_TaH.py. - script/preparation/label.py trimmed to a Qwen3-only labeller (drop MobileLLM/R1 templates and the verbose statistics report). - script/preparation/prune.py: drop env override + arrow fallback. - script/evaluation/eval.py: tighten CLI. - Recipes: drop inert fields, fix qwen3_0.6 step1 decider, drop dead variants. Tests + bench: - tests/_harness.py + per-component test_*.py: snapshot-based differential testing against /tmp/TaH-pub via PYTHONPATH-isolated subprocess; baselines in tests/baselines/ are gitignored and regenerate on first run. - tests/test_sft_smoke.py: trainer + collator end-to-end on synthetic data. - tests/bench.py: component + e2e microbenchmarks (B200 baselines documented in README); tests/bench_compile.py: torch.compile vs eager experiment. Docs: - README + CLAUDE.md refreshed for the cleaned layout, helper reference, bench harness and B200 baselines. --- CLAUDE.md | 147 ++ README.md | 102 +- bash/eval_base.sh | 12 - bash/eval_oracle.sh | 16 - bash/eval_route.sh | 20 - bash/sft_base.sh | 9 - pyproject.toml | 76 +- script/analysis/iter_pred_analysis_backup.py | 429 ----- .../evaluation/baseline_routellm_getscore.py | 213 --- .../compute_routellm_threshold_accuracy.py | 204 --- script/evaluation/eval.py | 78 +- script/playground/inference_example.py | 129 +- script/preparation/label.py | 1358 +++++----------- script/preparation/prune.py | 158 +- script/recipes/qwen3_0.6/eval_base.yaml | 12 - script/recipes/qwen3_0.6/eval_tah.yaml | 39 +- script/recipes/qwen3_0.6/sft_base.yaml | 55 - script/recipes/qwen3_0.6/sft_tah_step1.yaml | 35 +- script/recipes/qwen3_0.6/sft_tah_step2.yaml | 31 +- script/recipes/qwen3_1.7/eval_base.yaml | 12 - script/recipes/qwen3_1.7/eval_tah.yaml | 40 +- script/recipes/qwen3_1.7/eval_tah_oracle.yaml | 31 - script/recipes/qwen3_1.7/sft_base.yaml | 55 - script/recipes/qwen3_1.7/sft_tah_step1.yaml | 34 +- script/recipes/qwen3_1.7/sft_tah_step2.yaml | 37 +- script/train/SFT_TaH.py | 534 ++---- tah/__init__.py | 21 + tah/evaluate/__init__.py | 27 + tah/evaluate/backends.py | 204 +++ tah/evaluate/codeeval.py | 459 +++--- tah/evaluate/datasets.py | 140 ++ tah/evaluate/eval_unified.py | 1432 +---------------- tah/evaluate/jobs.py | 624 +++++++ tah/evaluate/matheval.py | 440 ++--- tah/evaluate/utils.py | 694 -------- tah/model/__init__.py | 35 + tah/model/adapter.py | 157 -- tah/model/causal_cache.py | 570 ++----- tah/model/input_updater.py | 86 - tah/model/iter_decider.py | 939 +++-------- tah/model/iter_label.py | 218 --- tah/model/loss.py | 531 ++---- tah/model/output_updater.py | 114 -- tah/model/recurrent_transformer.py | 1323 --------------- tah/model/registry.py | 134 -- tah/model/tah_config.py | 62 +- tah/model/tah_model.py | 875 ++++++++++ tah/model/tracker.py | 189 --- tah/model/utils.py | 655 +++----- tah/train/__init__.py | 15 +- tah/train/data_collator.py | 178 +- tah/train/trainer.py | 270 +--- tah/utils/__init__.py | 5 + tah/utils/data_prepare.py | 719 ++------- tah/utils/sampling.py | 81 - tests/__init__.py | 0 tests/_harness.py | 211 +++ tests/baselines/.gitignore | 4 + tests/bench.py | 300 ++++ tests/bench_compile.py | 129 ++ tests/conftest.py | 44 + tests/test_causal_cache.py | 150 ++ tests/test_input_updater.py | 92 ++ tests/test_iter_decider.py | 145 ++ tests/test_iter_label.py | 114 ++ tests/test_jobs_runner.py | 143 ++ tests/test_loss.py | 168 ++ tests/test_output_updater.py | 70 + tests/test_released_checkpoint.py | 121 ++ tests/test_save_load.py | 125 ++ tests/test_sft_smoke.py | 175 ++ tests/test_wrapper_forward.py | 157 ++ 72 files changed, 6592 insertions(+), 10619 deletions(-) create mode 100644 CLAUDE.md delete mode 100644 bash/eval_base.sh delete mode 100755 bash/eval_oracle.sh delete mode 100644 bash/eval_route.sh delete mode 100644 bash/sft_base.sh delete mode 100644 script/analysis/iter_pred_analysis_backup.py delete mode 100644 script/evaluation/baseline_routellm_getscore.py delete mode 100644 script/evaluation/compute_routellm_threshold_accuracy.py mode change 100755 => 100644 script/evaluation/eval.py mode change 100755 => 100644 script/preparation/label.py delete mode 100644 script/recipes/qwen3_0.6/eval_base.yaml delete mode 100755 script/recipes/qwen3_0.6/sft_base.yaml delete mode 100644 script/recipes/qwen3_1.7/eval_base.yaml delete mode 100644 script/recipes/qwen3_1.7/eval_tah_oracle.yaml delete mode 100755 script/recipes/qwen3_1.7/sft_base.yaml mode change 100755 => 100644 script/train/SFT_TaH.py create mode 100644 tah/__init__.py create mode 100644 tah/evaluate/__init__.py create mode 100644 tah/evaluate/backends.py create mode 100644 tah/evaluate/datasets.py create mode 100644 tah/evaluate/jobs.py delete mode 100644 tah/evaluate/utils.py create mode 100644 tah/model/__init__.py delete mode 100644 tah/model/adapter.py delete mode 100755 tah/model/input_updater.py mode change 100755 => 100644 tah/model/iter_decider.py delete mode 100644 tah/model/iter_label.py delete mode 100644 tah/model/output_updater.py delete mode 100755 tah/model/recurrent_transformer.py delete mode 100644 tah/model/registry.py create mode 100644 tah/model/tah_model.py delete mode 100644 tah/model/tracker.py create mode 100644 tah/utils/__init__.py mode change 100755 => 100644 tah/utils/data_prepare.py delete mode 100644 tah/utils/sampling.py create mode 100644 tests/__init__.py create mode 100644 tests/_harness.py create mode 100644 tests/baselines/.gitignore create mode 100644 tests/bench.py create mode 100644 tests/bench_compile.py create mode 100644 tests/conftest.py create mode 100644 tests/test_causal_cache.py create mode 100644 tests/test_input_updater.py create mode 100644 tests/test_iter_decider.py create mode 100644 tests/test_iter_label.py create mode 100644 tests/test_jobs_runner.py create mode 100644 tests/test_loss.py create mode 100644 tests/test_output_updater.py create mode 100644 tests/test_released_checkpoint.py create mode 100644 tests/test_save_load.py create mode 100644 tests/test_sft_smoke.py create mode 100644 tests/test_wrapper_forward.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..b74d913 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,147 @@ +# CLAUDE.md + +Guidance for Claude Code (claude.ai/code) when working in this directory. + +## What this is + +`tah-release` is the cleaned, single-target version of public TaH (forked from +[thu-nics/TaH](https://github.com/thu-nics/TaH)). Its only supported model is +the released [`nics-efc/TaH-plus-1.7B`](https://huggingface.co/nics-efc/TaH-plus-1.7B). +The user's *other* HR2R fork (a research branch with `Qwen3MLPIterDecider`, +`Qwen3MLPUpdater`, `CombinedLoss`, `iter_attention_mode='causal'`, +`weighted_hidden_method='stop_prob'`) is a separate codebase and is +**not** loadable here — those classes don't exist in this package. + +## Architecture (two-sentence version) + +`TaHForCausalLM` (`tah/model/tah_model.py`) wraps a HF causal LM (Qwen3) so +that, on each forward pass, every token first runs the base model +(iter_depth=0); tokens for which the iter decider votes "continue" then +re-run with LoRA enabled (iter_depth >= 1), accumulating their logits via a +residual additive update. Per-(layer, iteration) KV is stored in a single +`TaHCache` (`tah/model/causal_cache.py`) so future iterations can causally +see prior ones without disturbing iter-0. + +## Key invariants + +* **Single-impl interfaces are inlined.** input_updater / output_updater / + iter_label_generator / adapter / iter_attention_mode all have exactly one + implementation, called directly inside `TaHForCausalLM.forward`. Don't + reintroduce a registry for these — the simplification was deliberate. + +* **Multi-impl interfaces stay modular.** `iter_decider` (`IterLabelDecider` + for step-1 SFT, `MLPIterDecider` for step-2 + eval) and `loss` + (`NextTokenPredLoss`, `IterDeciderLoss`) keep their own modules with + `_BY_NAME` dicts for dispatch. No registries. + +* **The wrapper exposes a minimal contract.** Forward signature is + `(input_ids, attention_mask?, position_ids?, past_key_values?, labels?, + iter_count_labels?, use_cache?, new_sequence?)`. Public TaH had several + more args (`iter_count`, `output_attentions`, `output_hidden_states`) that + were either unused or unsupported; assertions or removals. + +* **Persistence layout.** A saved TaH checkpoint must contain: + - `tah_config.json` (config; type/dtype objects round-tripped via + `type_to_dict_string` / `dict_string_to_type`) + - `iter_decider.bin` (pickled `{class, init_args, state_dict}`) + - `lora/` (PEFT adapter dir) + - `model.safetensors` (base model with cleaned keys: no `.base_layer` + PEFT prefix, no `lora_*` weights — those live in `lora/`) + Downstream consumers (this repo's eval driver, `minisgl-tah` server's + `TaHQwen3ForCausalLM`) all rely on this layout. **Don't change without + also updating those consumers.** + +## Common Commands + +### Install +```bash +conda activate release +uv pip install -e ".[dev,training,evaluation]" +``` + +### Tests + benchmarks +```bash +pytest tests/ # 21 component + wrapper + roundtrip tests +pytest tests/test_.py -v # one file +TAH_TEST_DEVICE=cpu pytest tests/ # run on CPU (skip needs no flag) +python tests/bench.py components # per-helper microbench (B200 baseline in README) +python tests/bench.py e2e # forward + generate on TaH-plus-1.7B +``` +Snapshot baselines are captured by spawning a subprocess scoped to +`/tmp/TaH-pub` (public TaH); cleaned outputs are diffed against the +recorded snapshots in `tests/baselines/`. Snapshots are gitignored — +they regenerate on first run. + +### Training (3-stage) +```bash +# Step 0 +python script/preparation/label.py --num_gpu 8 \ + --dataset_path --test_model_list --output_path + +# Steps 1 + 2 +python -m accelerate.commands.launch \ + --config_file ./script/recipes/accelerate_configs/zero2.yaml \ + --num_processes 8 \ + ./script/train/SFT_TaH.py \ + --config ./script/recipes/qwen3_1.7/sft_tah_step{1,2}.yaml +``` + +### Evaluation (3 backends) +```bash +python script/evaluation/eval.py \ + --eval_config ./script/recipes/qwen3_1.7/eval_tah.yaml \ + --model_path nics-efc/TaH-plus-1.7B \ + --dataset_name gsm8k --backend {tah,hf,sglang} \ + --job_nums 8 --tp_size_per_job 1 +``` + +### Quick inference demo +```bash +python script/playground/inference_example.py +``` + +## Layout + +``` +tah/ +├── __init__.py # re-exports TaHForCausalLM, TaHConfig, TaHCache, … +├── model/ +│ ├── tah_model.py # TaHForCausalLM + inlined slot helpers +│ ├── iter_decider.py # IterLabelDecider, MLPIterDecider, ITER_DECIDER_BY_NAME +│ ├── loss.py # NextTokenPredLoss, IterDeciderLoss, LOSS_BY_NAME +│ ├── causal_cache.py # TaHCache: per-(layer, iter) KV with up-to-iter views +│ ├── tah_config.py # @dataclass TaHConfig +│ └── utils.py # generation helper + IterCountColors + freeze/seed/sampling helpers +├── train/ # HF Trainer subclass + collator + iter-aware callback +├── evaluate/ +│ ├── datasets.py # benchmark loading + standardisation +│ ├── backends.py # sglang / hf / tah model + inference fn +│ ├── jobs.py # job-sharded runner + result aggregation +│ ├── matheval.py # math benchmark graders (math_verify) +│ ├── codeeval.py # humaneval / mbpp via evalplus +│ └── eval_unified.py # backwards-compat shim re-exporting the above +└── utils/data_prepare.py # SFT preprocessing +script/ +├── preparation/ # download.py, label.py, prune.py, filter_split.py +├── train/SFT_TaH.py # YAML → wrapper → HF Trainer.train() +├── evaluation/eval.py # CLI wrapper over tah.evaluate.allocate_gpus_and_run_jobs +├── playground/inference_example.py +└── recipes/ # qwen3_{0.6,1.7}/sft_tah_step{1,2}.yaml + eval_tah.yaml +tests/ # _harness.py + per-component test_*.py + baselines/ (gitignored) +bash/ # sft_tah.sh, eval_tah.sh, pre_data.sh +``` + +## Conventions + +- All `tah.model.*` modules are designed to be importable in isolation with + small synthetic shapes — that's what `tests/conftest.py` exercises. +- `tah/model/tah_model.py` is the only place that mutates the wrapper's + internal state. New iter-loop behaviour goes there, not into the + iter_decider or loss classes. +- `iter_decider_kwargs.dtype` may be a `torch.dtype` — round-trip through the + ``_config_to_serialisable`` / ``_config_from_serialisable`` helpers in + `tah/model/tah_model.py` (called from save_pretrained / from_pretrained). +- The wrapper assumes the base model's hidden size matches + `iter_decider_kwargs.hidden_states_size`. Recipes set this to 1024 for + Qwen3-0.6B and 2048 for Qwen3-1.7B; mismatches show up as a Linear + shape error. diff --git a/README.md b/README.md index 540b68b..ac6d3cf 100644 --- a/README.md +++ b/README.md @@ -85,19 +85,13 @@ Key parameters: - `--job_nums`: Number of parallel jobs - `--tp_size_per_job`: Tensor parallel size per job -### Evaluate standard baseline model -```bash -python script/evaluation/eval.py \ - --eval_config ./script/recipes/qwen3_1.7/eval_base.yaml \ - --model_path nics-efc/Standard-1.7B \ - --dataset_name gsm8k \ - --backend hf \ - --job_nums 8 \ - --tp_size_per_job 1 -``` +### Evaluate with a different backend -Similar to TaH evaluation, but using: -- `--backend hf` or `--backend sglang` +The same `script/evaluation/eval.py` accepts `--backend hf` (vanilla +`AutoModelForCausalLM.generate` — useful for non-TaH baselines) or +`--backend sglang` (sgl Engine for high-throughput serving). All three +backends share the same job-sharded driver under +`tah/evaluate/jobs.py:allocate_gpus_and_run_jobs`. ## Train your own TaH model @@ -157,12 +151,17 @@ python -m accelerate.commands.launch \ ``` Key configurations in Step1 (`sft_tah_step1.yaml`): -- `max_iter: 2`: Maximum number of iterations -- `iter_decider: "FixedLabelIterDecider"`: Use fixed labels to decide iterations -- `iter_label_generator: "FixedIterLabelGenerator"`: Generate labels from mismatch field in data -- `input_updater: "AdditiveUpdater"`: Use additive updater for input updates -- `adapter: "lora"`: Use LoRA adapter for deeper iteration -- `train_loss: "NextTokenPredLoss"`: Next token prediction loss +- `max_iter: 2` — maximum number of iterations. +- `iter_decider: "IterLabelDecider"` — continue iff the per-token oracle + ``iter_count_labels`` (derived from ``mismatch``) say so. Used to teach + the LoRA adapter on tokens marked "hard" by the labeller. +- `adapter: "lora"` — only LoRA is supported in tah-release. +- `train_loss: "NextTokenPredLoss"` — standard causal-LM cross-entropy. + +Note: the input updater (top-k softmax over logits → embedding mix), the +output updater (residual additive accumulation), the iter-label generator +(dense max-merge of dataset labels), and the adapter setup are all inlined +into the wrapper, so there's no separate config field for them anymore. ### Step2: Train Iteration Decider @@ -192,29 +191,58 @@ After two-stage training, the model can automatically decide when to perform lat ``` TaH/ -├── tah/ # Core package -│ ├── model/ # Core model components -│ ├── train/ # Training components -│ ├── evaluate/ # Evaluation utilities -│ └── utils/ # General utilities -├── bash/ # Bash scripts for training and evaluation -├── script/ # Execution scripts -│ ├── analysis/ # Analysis scripts -│ ├── evaluation/ # Evaluation scripts -│ ├── preparation/ # Preparation for training -│ │ ├── label.py # Data labeling (generate mismatch labels) -│ │ └── prune.py # Model pruning -│ ├── playground/ # Some examples -│ └── recipes/ # Configuration files -│ ├── qwen3_0.6/ # Qwen3-0.6B-Base configs -│ ├── qwen3_1.7/ # Qwen3-1.7B-Base configs -│ └── accelerate_configs/ # Distributed training configs -└── pyproject.toml # Project configuration +├── tah/ +│ ├── model/ # core model +│ │ ├── tah_model.py # TaHForCausalLM wrapper + inlined slot helpers +│ │ ├── iter_decider.py # IterLabelDecider, MLPIterDecider, _BY_NAME +│ │ ├── loss.py # NextTokenPredLoss, IterDeciderLoss, _BY_NAME +│ │ ├── causal_cache.py # TaHCache: per-(layer, iter) KV +│ │ ├── tah_config.py # @dataclass TaHConfig +│ │ └── utils.py # generation helper, IterCountColors +│ ├── train/ # HF Trainer subclass + collator + iter-aware callback +│ ├── evaluate/ # multi-backend eval driver +│ │ ├── datasets.py # benchmark loading + standardisation +│ │ ├── backends.py # sglang / hf / tah model + inference fn +│ │ ├── jobs.py # job-sharded runner + result aggregation +│ │ ├── matheval.py # math benchmark graders (math_verify) +│ │ ├── codeeval.py # humaneval / mbpp via evalplus +│ │ └── eval_unified.py # backwards-compat shim +│ └── utils/ # SFT preprocessing +├── script/ +│ ├── preparation/ # download.py, label.py, prune.py, filter_split.py +│ ├── train/SFT_TaH.py # SFT entrypoint +│ ├── evaluation/eval.py # eval CLI entrypoint +│ ├── playground/ # inference demo +│ └── recipes/qwen3_{0.6,1.7}/ # training + eval YAML recipes +├── tests/ # per-component acc/speed tests +└── pyproject.toml ``` +## Tests + benchmarks + +```bash +pytest tests/ -q # 21 acc + roundtrip tests in ~30s on a B200 +python tests/bench.py components # microbenchmarks for the wrapper's hot helpers +python tests/bench.py e2e # forward + 32-token generate on TaH-plus-1.7B +python tests/bench_compile.py # one-off torch.compile vs eager experiment +``` + +Component baselines (single B200, torch 2.11+cu128, bf16): + +| helper | ms | +|---|---| +| topk_softmax_input_update | 0.48 | +| additive_logits_update | 0.03 | +| gather_active | 0.19 | +| scatter_back | 0.12 | +| MLPIterDecider.forward | 0.86 | +| NextTokenPredLoss.final | 0.23 | +| IterDeciderLoss.intra | 0.55 | +| **TaHForCausalLM.forward** (TaH-plus-1.7B, T=15) | **18.0** | +| **TaHForCasualLM_generate(32)** | **691** (~21.6 ms / token) | + ## Future Work -- [ ] Support more inference backends (e.g., SGLang) - [ ] Optimize iteration decision strategies - [ ] Integrate TaH with online distillation or RL - [ ] Support training for larger models diff --git a/bash/eval_base.sh b/bash/eval_base.sh deleted file mode 100644 index 5537183..0000000 --- a/bash/eval_base.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -export HF_ENDPOINT="https://hf-mirror.com" -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - -python script/evaluation/eval.py \ - --eval_config ./script/recipes/qwen3_1.7/eval_base.yaml \ - --model_path nics-efc/Standard-1.7B \ - --dataset_name gsm8k \ - --backend hf \ - --job_nums 8 \ - --tp_size_per_job 1 \ No newline at end of file diff --git a/bash/eval_oracle.sh b/bash/eval_oracle.sh deleted file mode 100755 index 4b5e576..0000000 --- a/bash/eval_oracle.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -export HF_ENDPOINT="https://hf-mirror.com" -export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - -python script/evaluation/eval.py \ - --eval_config script/recipes/qwen3_1.7/eval_tah_oracle.yaml \ - --model_path nics-efc/TaH-plus-1.7B \ - --output_dir output/evaluation/ \ - --dataset_name math500 \ - --backend tah \ - --job_nums 4 \ - --tp_size_per_job 2 \ - --logger_level WARNING \ - --data_range 10 \ No newline at end of file diff --git a/bash/eval_route.sh b/bash/eval_route.sh deleted file mode 100644 index 7573752..0000000 --- a/bash/eval_route.sh +++ /dev/null @@ -1,20 +0,0 @@ - - -export $(grep -v '^#' .env | xargs) - -python script/evaluation/baseline_routellm_getscore.py \ - --dataset-name olympiadbench \ - --out output/evaluation/routellm \ - --router mf \ - -DATASET_NAME=olympiadbench -STRONG_SIZE=4.0 -WEAK_SIZE=0.6 -TARGET_AVG=1.7 - -python ./script/evaluation/compute_routellm_threshold_accuracy.py \ - --scores output/evaluation/routellm/mf/${DATASET_NAME}.csv \ - --strong-detailed /path/to/strong/detailed_results.csv \ - --weak-detailed /path/to/weak/detailed_results.csv \ - --strong-size ${STRONG_SIZE} --weak-size ${WEAK_SIZE} --target-avg ${TARGET_AVG} --rounding round\ - --save-csv output/evaluation/routellm/thresholded_${STRONG_SIZE}_${WEAK_SIZE}_${TARGET_AVG}_${DATASET_NAME}_detail.csv \ No newline at end of file diff --git a/bash/sft_base.sh b/bash/sft_base.sh deleted file mode 100644 index 7068064..0000000 --- a/bash/sft_base.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - -python -m accelerate.commands.launch \ - --config_file ./script/recipes/accelerate_configs/zero2.yaml \ - --num_processes 8 \ - ./script/train/SFT_TaH.py \ - --config ./script/recipes/qwen3_1.7/sft_base.yaml \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0ed6ecb..2e04599 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = [ "setuptools>=61.0", "wheel",] +requires = ["setuptools>=61.0", "wheel"] build-backend = "setuptools.build_meta" [project] @@ -8,56 +8,54 @@ version = "0.1.0" description = "Think-at-Hard: Selective Latent Iterations to Improve Reasoning Language Models" readme = "README.md" requires-python = ">=3.10" -classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12",] -dependencies = [ "transformers==4.52.4", "torch==2.6.0", "matplotlib", "torchvision==0.21.0", "accelerate", "datasets", "peft==0.15.2",] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "transformers==4.52.4", + "torch==2.6.0", + "torchvision==0.21.0", + "accelerate", + "datasets", + "peft==0.15.2", +] [project.license] text = "MIT" [project.optional-dependencies] -dev = [ "pytest>=6.0", "pytest-cov", "pylatexenc",] -training = [ "wandb", "deepspeed==0.17.1", "liger_kernel==0.6.0", "flash-attn==2.7.3",] -evaluation = [ "latex2sympy2==1.9.1", "pylatexenc==2.10", "sympy==1.13.1", "sglang>=0.4.6",] +dev = ["pytest>=6.0", "pytest-cov"] +training = ["wandb", "deepspeed==0.17.1", "flash-attn==2.7.3"] +evaluation = [ + "latex2sympy2==1.9.1", + "pylatexenc==2.10", + "sympy==1.13.1", + "math_verify", + "termcolor", + "evalplus", + "sglang>=0.4.6", +] [tool.black] -line-length = 88 -target-version = [ "py310",] -include = "\\.pyi?$" -extend-exclude = "/(\n # directories\n \\.eggs\n | \\.git\n | \\.hg\n | \\.mypy_cache\n | \\.tox\n | \\.venv\n | build\n | dist\n)/\n" +line-length = 100 +target-version = ["py310"] [tool.isort] profile = "black" -line_length = 88 -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true -ensure_newline_before_comments = true - -[tool.mypy] -python_version = "3.10" -warn_return_any = true -warn_unused_configs = true -disallow_untyped_defs = true -disallow_incomplete_defs = true -check_untyped_defs = true -disallow_untyped_decorators = true -no_implicit_optional = true -warn_redundant_casts = true -warn_unused_ignores = true -warn_no_return = true -warn_unreachable = true -strict_equality = true - -[tool.setuptools.package-data] -tah = [ "*.py",] +line_length = 100 [tool.pytest.ini_options] -testpaths = [ "test",] -python_files = [ "test_*.py", "*_test.py",] -python_classes = [ "Test*",] -python_functions = [ "test_*",] +testpaths = ["tests"] +python_files = ["test_*.py", "*_test.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] addopts = "-v --tb=short" [tool.setuptools.packages.find] -include = [ "tah*",] +include = ["tah*"] diff --git a/script/analysis/iter_pred_analysis_backup.py b/script/analysis/iter_pred_analysis_backup.py deleted file mode 100644 index 8b353a4..0000000 --- a/script/analysis/iter_pred_analysis_backup.py +++ /dev/null @@ -1,429 +0,0 @@ -#!/usr/bin/env python3 -"""End-to-end analyses for TaH labeled datasets. - -This script performs the following tasks: - -1) Compute P(iter1_decision=1 | iter1_pred=x) from tah_decision.labeled_correct.csv - - Applies a minimum support threshold (default 0.1% of rows) - - Prints the top-K tokens by conditional probability - -2) Use only tah_decision.labeled_correct.csv for all subsequent analyses - - Define iter2_pred as existing only when iter1_decision==1; in that case - iter2_pred is equal to final_pred. For iter1_decision==0, iter2_pred does - not exist (ignored in analyses that need iter2). - -3) Compute P(iter1_decision=1 | final_pred=x) from the corrected CSV - - Uses the same minimum support threshold and prints the top-K tokens - -4) For selected iter1_pred seeds (auto-chosen from top-3 unless overridden), - prints their most common iter2 (i.e., final_pred on rows with iter1_decision==1) - values with counts and percentages, using only the corrected CSV. - -All default file paths target the 1.7B analysis directory. -""" - -import argparse -import csv -import os -from collections import Counter, defaultdict -from math import ceil - - -def format_token(token: str) -> str: - if token is None: - return "" - return token.replace("\n", "\\n").replace("\t", "\\t") - - -def compute_conditional_probability( - csv_path: str, - group_col: str, - decision_col: str = "iter1_decision", - min_support_ratio: float = 0.001, - top_k: int = 50, - row_filter=None, -): - """Compute P(decision_col==1 | group_col=x) with min-support filter. - - Returns: (total_rows, base_rate, min_support_count, leaderboard) - leaderboard: list of tuples (prob, count, positives, token) - """ - counts = defaultdict(int) - positives = defaultdict(int) - - with open(csv_path, "r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - if group_col not in reader.fieldnames or decision_col not in reader.fieldnames: - raise ValueError( - f"Required columns missing in {csv_path}: {group_col} or {decision_col}" - ) - for row in reader: - if row_filter is not None and not row_filter(row): - continue - key = row.get(group_col) - val = row.get(decision_col) - try: - decision_is_one = int(val) == 1 - except Exception: - # Skip malformed decision values - continue - counts[key] += 1 - if decision_is_one: - positives[key] += 1 - - total_rows = sum(counts.values()) - base_rate = (sum(positives.values()) / total_rows) if total_rows else 0.0 - min_support = ceil(min_support_ratio * total_rows) - - leaderboard = [] - for token, n in counts.items(): - if n < min_support: - continue - k = positives[token] - p = k / n if n else 0.0 - leaderboard.append((p, n, k, token)) - - leaderboard.sort(key=lambda r: (r[0], r[1], r[3]), reverse=True) - return total_rows, base_rate, min_support, leaderboard[:top_k] - - -def iter_rows_with_iter2_from_corrected(corrected_csv: str): - """Yield rows from corrected CSV, attaching an implicit iter2_pred only when - iter1_decision == 1 (then iter2_pred == final_pred). Rows with iter1_decision == 0 - are yielded with iter2_pred = None. - """ - with open(corrected_csv, "r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - for row in reader: - try: - dec = int(row.get("iter1_decision")) - except Exception: - continue - row = dict(row) - row["iter2_pred"] = row.get("final_pred") if dec == 1 else None - yield row - - -def most_common_iter2_for_iter1( - combined_csv: str, - iter1_values: list, - iter1_col: str = "iter1_pred", - iter2_col: str = "iter2_pred", - top_n: int = 20, -): - """For each iter1 in iter1_values, list most common iter2 values with counts. - - Returns dict: iter1_value -> list[(iter2, count, pct)] - """ - counters = {val: Counter() for val in iter1_values} - totals = Counter() - - with open(combined_csv, "r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - for row in reader: - i1 = row.get(iter1_col) - i2 = row.get(iter2_col) - if i1 in counters and i2 is not None: - counters[i1][i2] += 1 - totals[i1] += 1 - - result = {} - for val in iter1_values: - total = totals[val] - ranked = [] - for tok, c in counters[val].most_common(top_n): - pct = (c / total) if total else 0.0 - ranked.append((tok, c, pct)) - result[val] = (total, ranked) - return result - - -def print_leaderboard(title: str, total_rows: int, base_rate: float, min_support: int, leaderboard: list): - print(title) - print(f"Total rows={total_rows}, base_rate={base_rate:.6f}, min_support={min_support}") - print("rank\tprob\tcount\tpositives\ttoken") - for i, (prob, count, pos, token) in enumerate(leaderboard, 1): - print(f"{i}\t{prob:.6f}\t{count}\t{pos}\t{format_token(token)}") - print() - - -def main(): - parser = argparse.ArgumentParser(description="Iter prediction analyses and CSV combiner") - parser.add_argument("--root", default="/share/futianyu/cloud/repo/TaH/local/data/analysis/1.7B", help="Base directory for input/output CSVs") - parser.add_argument("--min_support_ratio", type=float, default=0.00423, help="Minimum support ratio for leaderboards (default 0.1%)") - parser.add_argument("--top_k", type=int, default=50, help="Top-K rows to display in leaderboards") - parser.add_argument("--seeds", nargs="*", default=["\\(", "maybe", "But"], help="iter1_pred seed tokens to analyze for most common iter2_pred (ignored if auto seeds enabled)") - parser.add_argument("--auto_seeds_top_k", type=int, default=2, help="If >0, automatically pick top-K tokens from iter1 leaderboard as seeds") - parser.add_argument("--sankey_top_iter1", type=int, default=2, help="Top-N iter1 seeds to visualize in Sankey (default 2)") - parser.add_argument("--sankey_top_iter2", type=int, default=3, help="Top-N iter2 destinations per seed for Sankey (default 3)") - parser.add_argument("--sankey_html_out", default=None, help="Path to write Sankey HTML (default: /iter_flow_sankey.html)") - args = parser.parse_args() - - # Paths - path_correct = os.path.join(args.root, "tah_decision.labeled_correct.csv") - - # 1) P(iter1_decision=1 | iter1_pred=x) [is_response=True] - tot_r, base_r, min_sup_r, board_r = compute_conditional_probability( - csv_path=path_correct, - group_col="iter1_pred", - decision_col="iter1_decision", - min_support_ratio=args.min_support_ratio, - top_k=args.top_k, - row_filter=lambda r: r.get("is_response") in ("True", True), - ) - print_leaderboard( - title="P(iter1_decision=1 | iter1_pred=x) [is_response=True]", - total_rows=tot_r, - base_rate=base_r, - min_support=min_sup_r, - leaderboard=board_r, - ) - - # 2) No external join; derive iter2_pred from corrected only (where decision==1) - # Show a short verification of counts with iter2 present - iter2_present = 0 - total_rows_tmp = 0 - for row in iter_rows_with_iter2_from_corrected(path_correct): - total_rows_tmp += 1 - if row.get("iter2_pred") is not None: - iter2_present += 1 - print({"rows_in_corrected": total_rows_tmp, "rows_with_iter2_pred": iter2_present}) - print() - - # 3) P(iter1_decision=1 | final_pred (iter2)=x) using corrected only [is_response=True] - # Since iter2_pred exists only when decision==1, we equivalently group by final_pred - # but still compute P(decision==1 | final_pred=x) - tot2, base2, min_sup2, board2 = compute_conditional_probability( - csv_path=path_correct, - group_col="final_pred", - decision_col="iter1_decision", - min_support_ratio=args.min_support_ratio, - top_k=args.top_k, - row_filter=lambda r: r.get("is_response") in ("True", True), - ) - print_leaderboard( - title="P(iter1_decision=1 | final_pred=x) [is_response=True]", - total_rows=tot2, - base_rate=base2, - min_support=min_sup2, - leaderboard=board2, - ) - - # 4) Most common iter2_pred for selected iter1_pred values - # Determine seeds automatically from the top-K iter1 leaderboard if enabled - if args.auto_seeds_top_k and args.auto_seeds_top_k > 0: - auto_seeds = [t for (_, _, _, t) in board_r[: args.auto_seeds_top_k]] - seeds_to_use = auto_seeds - print(f"Selected seeds from top{args.auto_seeds_top_k} iter1 leaderboard: {[format_token(s) for s in seeds_to_use]}") - else: - seeds_to_use = args.seeds - print(f"Using provided seeds: {[format_token(s) for s in seeds_to_use]}") - - print("Most common iter2_pred given iter1_pred seeds [is_response=True & iter1_decision=1]:") - # Compute most common "iter2" as final_pred but only on rows where - # is_response==True and iter1_decision==1, restricted to the selected seeds. - seed_counters = {s: Counter() for s in seeds_to_use} - seed_totals = Counter() - with open(path_correct, "r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - for row in reader: - if row.get("is_response") not in ("True", True): - continue - try: - dec = int(row.get("iter1_decision")) - except Exception: - continue - if dec != 1: - continue - i1 = row.get("iter1_pred") - if i1 not in seed_counters: - continue - i2 = row.get("final_pred") - seed_counters[i1][i2] += 1 - seed_totals[i1] += 1 - - for seed in seeds_to_use: - total = seed_totals[seed] - print(f"iter1_pred={format_token(seed)} total={total}") - for tok, c in seed_counters[seed].most_common(3): - pct = (c / total) if total else 0.0 - print(f" {c}\t{pct:.4f}\t{format_token(tok)}") - print() - - # 5) Two-column Sankey: Pred@Iter1 (passed) -> Pred@Iter2 (+ Others) - sankey_seeds = [t for (_, _, _, t) in board_r[: args.sankey_top_iter1]] - # Aggregate counts for per-destination for passing rows - seed_dest = {s: Counter() for s in sankey_seeds} - with open(path_correct, "r", encoding="utf-8", newline="") as f: - reader = csv.DictReader(f) - for row in reader: - if row.get("is_response") not in ("True", True): - continue - i1 = row.get("iter1_pred") - if i1 not in seed_dest: - continue - try: - dec = int(row.get("iter1_decision")) - except Exception: - continue - if dec == 1: - i2 = row.get("final_pred") - seed_dest[i1][i2] += 1 - - # Two-column Sankey: Pred@Iter1 (passed) -> Pred@Iter2 (+ Others) - # - Column 1: top-2 seeds (passed only) - # - Column 2: union of specified destinations plus an 'Others' bucket - # - Colors: red for 'But', blue for 'So', grey for other tokens, white for 'Others' - - # Decide union of destinations explicitly, defaulting to computed union - explicit_union = ["Wait", "But", "Therefore", "So"] - # Filter to those that actually appear; if none, fall back to computed dest_union - dest_union_2col = [d for d in explicit_union if any(seed_dest[s].get(d, 0) > 0 for s in sankey_seeds)] - if not dest_union_2col: - dest_union = sorted({d for s in sankey_seeds for d in seed_dest.get(s, {}).keys()}) - dest_union_2col = dest_union[:] - others_label = "Others" - - # Label formatting for 2-col Sankey: quote all tokens except Others - def label_for_2col_token(token: str) -> str: - if token == others_label: - return others_label - return f'"{token}"' - - # Pre-compute totals for counts on nodes - total_passed_per_seed = {} - dest_total_counts = {d: 0 for d in dest_union_2col} - others_total_count = 0 - for s in sankey_seeds: - dc = seed_dest.get(s, Counter()) - total_passed = sum(dc.values()) - total_passed_per_seed[s] = total_passed - allocated = 0 - for d in dest_union_2col: - c = dc.get(d, 0) - dest_total_counts[d] += c - allocated += c - others_total_count += max(total_passed - allocated, 0) - - # Sort Column 2 by total amount (descending) - dest_sorted = sorted(dest_union_2col, key=lambda d: dest_total_counts.get(d, 0), reverse=True) - - # Build node labels without counts - col1_nodes_2 = [label_for_2col_token(s) for s in sankey_seeds] - col2_nodes_2 = [label_for_2col_token(d) for d in dest_sorted] + [label_for_2col_token(others_label)] - labels2 = col1_nodes_2 + col2_nodes_2 - - # Color mapping function per token content - def color_for_token(token: str) -> str: - if token == "But": - return "rgba(214, 39, 40, 0.8)" # red - if token == "So": - return "rgba(31, 119, 180, 0.8)" # blue - if token == others_label: - return "rgba(255, 255, 255, 1.0)" # white for Others - return "rgba(127, 127, 127, 0.8)" # grey - - node_colors2 = [] - # Column 1 colors (map from seed token) - for s in sankey_seeds: - node_colors2.append(color_for_token(s)) - # Column 2 colors (map from destination token) - for d in dest_sorted: - node_colors2.append(color_for_token(d)) - # Add Others color - node_colors2.append(color_for_token(others_label)) - - # Build links seeds->destinations using passed counts only - sources2, targets2, values2, link_labels2, link_colors2 = [], [], [], [], [] - def idx_c1(i): - return i - def idx_c2(i): - return len(col1_nodes_2) + i - - for si, s in enumerate(sankey_seeds): - dest_counter = seed_dest.get(s, Counter()) - allocated = 0 - for di, d in enumerate(dest_sorted): - v = dest_counter.get(d, 0) - if v > 0: - sources2.append(idx_c1(si)) - targets2.append(idx_c2(di)) - values2.append(v) - link_labels2.append("") # No labels on links - # Color flows to But/So as red/blue; others semi-transparent grey - if d == "But": - link_colors2.append("rgba(214, 39, 40, 0.6)") - elif d == "So": - link_colors2.append("rgba(31, 119, 180, 0.6)") - else: - link_colors2.append("rgba(0,0,0,0.2)") - allocated += v - # Others - total_passed = sum(dest_counter.values()) - v_other = total_passed - allocated - if v_other > 0: - sources2.append(idx_c1(si)) - targets2.append(idx_c2(len(dest_sorted))) # index of Others node - values2.append(v_other) - link_labels2.append("") # No labels on links - link_colors2.append("rgba(0,0,0,0.1)") - - out_html2 = os.path.join(args.root, "iter_flow_sankey_2col.html") - out_pdf2 = os.path.join(args.root, "iter_flow_sankey_2col.pdf") - - try: - import plotly.graph_objects as go - fig2 = go.Figure( - data=[ - go.Sankey( - arrangement="snap", - node=dict( - label=labels2, - pad=15, - thickness=20, - color=node_colors2, - line=dict(color="black", width=0.5), - ), - link=dict( - source=sources2, - target=targets2, - value=values2, - label=link_labels2, - color=link_colors2, - hovertemplate="%{label}", - ), - ) - ] - ) - - width=600 - height=375 - fig2.update_layout( - font_size=18, - height=height, # 3.5 inches * 300 DPI - width=width, # 6 inches * 300 DPI - ) - - # Save HTML - fig2.write_html(out_html2, include_plotlyjs="cdn") - print({"sankey_2col_saved": out_html2, "nodes": len(labels2), "links": len(values2)}) - - # Export PDF at 300 DPI with size 6x3.5 inches - try: - fig2.write_image(out_pdf2, format="pdf", width=width, height=height, scale=1) - print({"sankey_2col_pdf_saved": out_pdf2, "dpi": 300, "size_inches": [6, 3.5]}) - except Exception as e_img: - print({"sankey_2col_pdf_error": str(e_img), "hint": "pip install -U kaleido for static image export"}) - - except Exception as e: - out_csv2 = os.path.join(args.root, "iter_flow_sankey_2col_links.csv") - with open(out_csv2, "w", encoding="utf-8", newline="") as fo: - w = csv.writer(fo) - w.writerow(["source_label", "target_label", "value"]) - for s_idx, t_idx, v in zip(sources2, targets2, values2): - w.writerow([labels2[s_idx], labels2[t_idx], v]) - print({"sankey_2col_links_csv": out_csv2, "error": str(e)}) - - -if __name__ == "__main__": - main() diff --git a/script/evaluation/baseline_routellm_getscore.py b/script/evaluation/baseline_routellm_getscore.py deleted file mode 100644 index ef74551..0000000 --- a/script/evaluation/baseline_routellm_getscore.py +++ /dev/null @@ -1,213 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -RouteLLM Batch Scoring Script -- Read dataset (CSV/JSON/JSONL) -- Use RouteLLM router (default: mf) to calculate "strong win rate" score for each question -- Write CSV output in original order: id, score - -Dependencies: - pip install "routellm[serve,eval]" pandas tqdm - -Usage examples: - python baseline_routellm_getscore.py \ - --data /path/to/dataset.jsonl \ - --out /path/to/out.csv - -Different column names: - python baseline_routellm_getscore.py \ - --data /path/to/dataset.csv --id-col qid --text-col prompt --out scores.csv - -Specify router and config: - python baseline_routellm_getscore.py \ - --data ds.csv --out scores.csv \ - --router mf \ - --config /path/to/config.example.yaml - -Load dataset from eval_unified (accuracy mode, corresponding to dataset_configs.json): - python baseline_routellm_getscore.py \ - --dataset-name gsm8k \ - --out scores.csv - # Multiple datasets combined: - python baseline_routellm_getscore.py \ - --dataset-name "gsm8k,math500" \ - --out scores.csv - -Note: Uses mf router by default. The mf router requires OPENAI_API_KEY for embedding calculation. -""" - -import argparse -import json -import sys -from pathlib import Path -from typing import List, Tuple - -import pandas as pd -from tqdm import tqdm - -# RouteLLM imports -from routellm.controller import Controller - -# Unified evaluation dataset loading (accuracy mode) -try: - from tah.evaluate.eval_unified import load_datasets_with_config # Returns (combined_data, field_mapping) -except Exception: - load_datasets_with_config = None # Allow --data branch when tah package is not installed - -def read_dataset(path: Path, id_col: str, text_col: str) -> pd.DataFrame: - suffix = path.suffix.lower() - if suffix == ".csv": - df = pd.read_csv(path) - elif suffix == ".jsonl": - rows = [] - with path.open("r", encoding="utf-8") as f: - for line in f: - if not line.strip(): - continue - obj = json.loads(line) - rows.append(obj) - df = pd.DataFrame(rows) - elif suffix == ".json": - with path.open("r", encoding="utf-8") as f: - data = json.load(f) - if isinstance(data, dict): - raise ValueError("JSON top level is an object, expected an array.") - df = pd.DataFrame(data) - else: - raise ValueError(f"Unsupported file type: {suffix} (only .csv/.jsonl/.json are supported)") - - # Auto-detect column names (if not explicitly provided and default columns don't exist) - if id_col not in df.columns: - cand = [c for c in ["id", "qid", "question_id", "sample_id"] if c in df.columns] - if cand: - id_col = cand[0] - else: - raise ValueError(f"ID column not found (tried default '{id_col}', also no common candidates in {df.columns.tolist()})") - if text_col not in df.columns: - cand = [c for c in ["question", "prompt", "input", "text"] if c in df.columns] - if cand: - text_col = cand[0] - else: - raise ValueError(f"Question column not found (tried default '{text_col}', also no common candidates in {df.columns.tolist()})") - - # Keep only these two columns and preserve order - df = df[[id_col, text_col]].copy() - df.columns = ["id", "question"] - return df - - -def build_controller(router_name: str, config_path: str = None, - strong_model: str = None, weak_model: str = None) -> Controller: - """ - Build RouteLLM Controller. - - routers: Pass one or more router names (here we only use one) - - Optional: Pass strong/weak model names (doesn't affect scoring, mainly for actual routing scenarios) - - Optional: Pass config file (overrides default configuration) - """ - kwargs = dict(routers=[router_name]) - if config_path: - kwargs["config"] = config_path - kwargs["strong_model"] = strong_model - kwargs["weak_model"] = weak_model - client = Controller(**kwargs) - return client - - -def score_one(controller: Controller, router_name: str, question: str) -> float: - """ - Calculate "strong win rate" score for a single question. - Note: - - RouteLLM documentation states that each router implements `calculate_strong_win_rate(prompt)->float` - - Controller holds router instances internally. Attribute names may differ across versions, so we support two access methods. - """ - # Method A: Prefer public/semi-public router table - if hasattr(controller, "routers") and isinstance(controller.routers, dict): - router = controller.routers.get(router_name) - if router and hasattr(router, "calculate_strong_win_rate"): - return float(router.calculate_strong_win_rate(question)) - - # Method B: Some versions use _routers private attribute - if hasattr(controller, "_routers") and isinstance(controller._routers, dict): - router = controller._routers.get(router_name) - if router and hasattr(router, "calculate_strong_win_rate"): - return float(router.calculate_strong_win_rate(question)) - - raise RuntimeError( - "Failed to get router instance or its calculate_strong_win_rate method from Controller. " - "Please upgrade routellm version or ensure router name is correct when building Controller." - ) - - -def main(): - parser = argparse.ArgumentParser(description="RouteLLM Batch Scoring") - group = parser.add_mutually_exclusive_group(required=True) - group.add_argument("--data", type=str, help="Dataset path (.csv/.jsonl/.json)") - group.add_argument("--dataset-name", type=str, help="eval_unified dataset name (comma-separated for multiple)") - parser.add_argument("--out", required=True, type=str, help="Output CSV path") - parser.add_argument("--id-col", default="id", type=str, help="ID column name in dataset when using --data (default: id)") - parser.add_argument("--text-col", default="question", type=str, help="Text column name in dataset when using --data (default: question)") - parser.add_argument("--router", default="mf", type=str, help="Router name (default: mf)") - parser.add_argument("--config", default=None, type=str, help="RouteLLM config YAML path (optional)") - parser.add_argument("--strong-model", default=None, type=str, help="Strong model name (optional)") - parser.add_argument("--weak-model", default=None, type=str, help="Weak model name (optional)") - parser.add_argument("--batch-size", default=64, type=int, help="Batch size (this script scores one by one, batch size only for display pacing)") - args = parser.parse_args() - - out_path = Path(args.out) - out_path = out_path/ f"{args.router}/{args.dataset_name}.csv" - out_path.parent.mkdir(parents=True, exist_ok=True) - - # Read data - if getattr(args, "dataset_name", None): - if load_datasets_with_config is None: - raise RuntimeError("Failed to import eval_unified dataset loading function. Please ensure tah package is in PYTHONPATH or use --data method.") - # Align with eval_unified: support comma-separated or single dataset names - dataset_names = [name.strip() for name in args.dataset_name.split(',') if name.strip()] - combined_data, field_mapping = load_datasets_with_config(dataset_names) - # Use standard fields (eval_unified already unified to id/question/answer/...) - df = pd.DataFrame(combined_data) - if not {"id", "question"}.issubset(df.columns): - raise ValueError("Standard columns 'id' and 'question' not found after loading from eval_unified.") - df = df[["id", "question"]].copy() - else: - if not args.data: - raise ValueError("Must provide either --data or --dataset-name.") - data_path = Path(args.data) - df = read_dataset(data_path, id_col=args.id_col, text_col=args.text_col) - - # Build Controller (no actual chat generation, only use router scores) - controller = build_controller( - router_name=args.router, - config_path=args.config, - strong_model=args.strong_model, - weak_model=args.weak_model, - ) - - # Calculate scores one by one (in original order) - scores: List[Tuple] = [] - pbar = tqdm(total=len(df), desc=f"Scoring with RouteLLM[{args.router}]") - for i, row in df.iterrows(): - qid = row["id"] - qtext = row["question"] - try: - score = score_one(controller, args.router, qtext) - except Exception as e: - # Return NaN on failure; print error to STDERR - print(f"[WARN] id={qid} scoring failed: {e}", file=sys.stderr) - score = float("nan") - scores.append((qid, score)) - pbar.update(1) - pbar.close() - - # Write result CSV - out_df = pd.DataFrame(scores, columns=["id", "score"]) - # Preserve order: pandas maintains insertion order; for safety, re-align with original df order: - order = {k: idx for idx, k in enumerate(df["id"].tolist())} - out_df["__order"] = out_df["id"].map(order) - out_df = out_df.sort_values("__order").drop(columns="__order") - out_df.to_csv(out_path, index=False, encoding="utf-8") - print(f"✅ Write complete: {out_path}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/script/evaluation/compute_routellm_threshold_accuracy.py b/script/evaluation/compute_routellm_threshold_accuracy.py deleted file mode 100644 index 87ee626..0000000 --- a/script/evaluation/compute_routellm_threshold_accuracy.py +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import argparse -import math -import sys -from pathlib import Path -from typing import Dict, Tuple - -import pandas as pd - - -def load_problem_stats(detailed_csv: Path) -> Tuple[Dict[str, float], Dict[str, bool], Dict[str, str]]: - """Load detailed_results.csv and compute per-problem stats: - - mean_correct_map: problem_id -> mean(is_correct) across samples (float in [0,1]) - - any_correct_map: problem_id -> any(is_correct) across samples (bool) - - pred_map: problem_id -> representative predicted_answer (prefer a correct sample, else first non-empty) - """ - df = pd.read_csv(detailed_csv) - # print(df) - # Normalize id types - if "problem_id" in df.columns: - df["problem_id"] = df["problem_id"].astype(str).str.strip() - required_cols = {"problem_id", "is_correct"} - if not required_cols.issubset(df.columns): - raise ValueError(f"{detailed_csv} 缺少必要列: {required_cols},实际列: {df.columns.tolist()}") - - # Determine predicted_answer column if present - pred_col = "predicted_answer" if "predicted_answer" in df.columns else None - - # mean correctness per problem - mean_correct = df.groupby("problem_id")["is_correct"].mean().reset_index() - mean_correct_map = dict(zip(mean_correct["problem_id"], mean_correct["is_correct"].astype(float))) - - # any correctness per problem - any_correct = df.groupby("problem_id")["is_correct"].any().reset_index() - any_correct_map = dict(zip(any_correct["problem_id"], any_correct["is_correct"].astype(bool))) - - # Representative predicted_answer per problem - pred_map: Dict[str, str] = {} - if pred_col is None: - # No predicted column; leave empty strings - for pid in df["problem_id"].unique(): - pred_map[pid] = "" - return mean_correct_map, any_correct_map, pred_map - - # Prefer first correct sample's predicted_answer; else first non-empty - for pid, grp in df.groupby("problem_id", sort=False): - rep = "" - # Prefer correct - corr_rows = grp[grp["is_correct"] == True] - if not corr_rows.empty: - val = corr_rows.iloc[0][pred_col] - rep = "" if pd.isna(val) else str(val) - else: - # First non-empty - non_empty = grp[grp[pred_col].notna() & (grp[pred_col].astype(str).str.len() > 0)] - if not non_empty.empty: - rep = str(non_empty.iloc[0][pred_col]) - else: - rep = "" - pred_map[pid] = rep - - return mean_correct_map, any_correct_map, pred_map - - -def main(): - parser = argparse.ArgumentParser(description="Compute Routellm thresholded accuracy by aligning IDs and selecting models by score threshold.") - parser.add_argument("--scores", required=True, type=str, help="RouteLLM scores CSV (columns: id, score)") - parser.add_argument("--strong-detailed", required=True, type=str, help="Strong model detailed_results.csv") - parser.add_argument("--weak-detailed", required=True, type=str, help="Weak model detailed_results.csv") - parser.add_argument("--strong-size", default=4.0, type=float, help="Strong model params in B (default 4.0)") - parser.add_argument("--weak-size", default=0.6, type=float, help="Weak model params in B (default 0.6)") - parser.add_argument("--target-avg", default=1.7, type=float, help="Target average params in B (default 1.7)") - parser.add_argument("--rounding", default="round", choices=["round", "floor", "ceil"], help="How to convert proportion to count (default round)") - parser.add_argument("--save-csv", default=None, type=str, help="可选:保存详细路由与正误结果到该 CSV 文件") - args = parser.parse_args() - - scores_path = Path(args.scores) - strong_path = Path(args.strong_detailed) - weak_path = Path(args.weak_detailed) - - # Load scores - scores_df = pd.read_csv(scores_path) - if not {"id", "score"}.issubset(scores_df.columns): - raise ValueError(f"{scores_path} 需要包含列 ['id','score'],实际列: {scores_df.columns.tolist()}") - # Normalize ids - scores_df["id"] = scores_df["id"].astype(str).str.strip() - - # Compute strong proportion p s.t. p*strong + (1-p)*weak = target - denom = (args.strong_size - args.weak_size) - if denom <= 0: - raise ValueError("strong_size 必须大于 weak_size") - p_strong = (args.target_avg - args.weak_size) / denom - p_strong = max(0.0, min(1.0, p_strong)) - - # Decide top-K by score to assign to strong - n = len(scores_df) - if args.rounding == "round": - k = int(round(p_strong * n)) - elif args.rounding == "floor": - k = int(math.floor(p_strong * n)) - else: - k = int(math.ceil(p_strong * n)) - k = max(0, min(n, k)) - - # Sort descending by score; tie-breaker: stable by original order - scores_sorted = scores_df.sort_values(["score", "id"], ascending=[False, True]).reset_index(drop=True) - # Determine threshold value for reporting (score at position k-1) - threshold = float("nan") - if k > 0: - threshold = float(scores_sorted.iloc[k - 1]["score"]) if k - 1 < len(scores_sorted) else float("nan") - - # Build model assignment - assign_series = pd.Series(["strong"] * k + ["weak"] * (n - k)) - assign_df = scores_sorted.copy() - assign_df["assigned_model"] = assign_series.values - - # Load per-problem stats - strong_mean_map, strong_any_map, strong_pred_map = load_problem_stats(strong_path) - weak_mean_map, weak_any_map, weak_pred_map = load_problem_stats(weak_path) - - # Report missing ids coverage - ids_set = set(scores_df["id"].tolist()) - miss_strong = [qid for qid in ids_set if qid not in strong_mean_map] - # print(('olympiadbench_1606' in ids_set), ('olympiadbench_1606' in strong_mean_map)) - miss_weak = [qid for qid in ids_set if qid not in weak_mean_map] - if miss_strong: - print(f"[WARN] {len(miss_strong)} ids not found in strong detailed_results (treated as incorrect)") - if miss_weak: - print(f"[WARN] {len(miss_weak)} ids not found in weak detailed_results (treated as incorrect)") - - # Join mean correctness by assignment - def pick_correct_mean(row): - qid = row["id"] - if row["assigned_model"] == "strong": - return float(strong_mean_map.get(qid, 0.0)) - return float(weak_mean_map.get(qid, 0.0)) - - assign_df["final_correct_mean"] = assign_df.apply(pick_correct_mean, axis=1) - - # Attach model-wise representative predicted answers and correctness - assign_df["strong_predicted"] = assign_df["id"].map(lambda x: strong_pred_map.get(x, "")) - assign_df["strong_correct_mean"] = assign_df["id"].map(lambda x: float(strong_mean_map.get(x, 0.0))) - assign_df["strong_is_correct_any"] = assign_df["id"].map(lambda x: bool(strong_any_map.get(x, False))) - assign_df["weak_predicted"] = assign_df["id"].map(lambda x: weak_pred_map.get(x, "")) - assign_df["weak_correct_mean"] = assign_df["id"].map(lambda x: float(weak_mean_map.get(x, 0.0))) - assign_df["weak_is_correct_any"] = assign_df["id"].map(lambda x: bool(weak_any_map.get(x, False))) - - # Final routed predicted - # print(assign_df["strong_predicted"]) - def pick_pred(row): - # if row['assigned_model'] == 'strong': - # print(row["strong_predicted"]) - return row["strong_predicted"] if row["assigned_model"] == "strong" else row["weak_predicted"] - - assign_df["final_predicted"] = assign_df.apply(pick_pred, axis=1) - - # Compute accuracy as average of per-problem mean correctness - total = len(assign_df) - sum_mean = float(assign_df["final_correct_mean"].sum()) - acc = sum_mean / total if total > 0 else 0.0 - - # Report - print("Routing with target average params:") - print(f" strong_size={args.strong_size}B, weak_size={args.weak_size}B, target_avg={args.target_avg}B") - print(f" proportion_strong={p_strong:.6f} -> select_top_k={k}/{n}") - print(f" score_threshold≈{threshold}") - print("") - print("Final accuracy (mean across samples per problem):") - print(f" avg-correctness-sum={sum_mean:.2f} / total={total} -> accuracy={acc:.4f}") - - # Also print a small breakdown - strong_correct_mean = float(assign_df.loc[assign_df["assigned_model"] == "strong", "final_correct_mean"].sum()) - weak_correct_mean = float(assign_df.loc[assign_df["assigned_model"] == "weak", "final_correct_mean"].sum()) - print("") - print("Breakdown:") - print(f" strong: {k} selected, mean-correct-sum={strong_correct_mean:.2f}") - print(f" weak: {n - k} selected, mean-correct-sum={weak_correct_mean:.2f}") - - # Optional save CSV with detailed assignment and outputs - if args.save_csv: - out_cols = [ - "id", "score", "assigned_model", - "strong_predicted", "strong_correct_mean", "strong_is_correct_any", - "weak_predicted", "weak_correct_mean", "weak_is_correct_any", - "final_predicted", "final_correct_mean", - ] - save_df = assign_df[out_cols].copy() - out_path = Path(args.save_csv) - out_path.parent.mkdir(parents=True, exist_ok=True) - save_df.to_csv(out_path, index=False, encoding="utf-8") - print("") - print(f"Saved detailed routing CSV to: {out_path}") - - -if __name__ == "__main__": - try: - main() - except Exception as e: - print(f"[ERROR] {e}", file=sys.stderr) - sys.exit(1) - - diff --git a/script/evaluation/eval.py b/script/evaluation/eval.py old mode 100755 new mode 100644 index 179d530..37f1c8f --- a/script/evaluation/eval.py +++ b/script/evaluation/eval.py @@ -1,44 +1,54 @@ +"""CLI entrypoint for the multi-backend, multi-job eval driver. + +Wraps :func:`tah.evaluate.allocate_gpus_and_run_jobs`. Run e.g.:: + + python script/evaluation/eval.py \\ + --eval_config script/recipes/qwen3_1.7/eval_tah.yaml \\ + --model_path nics-efc/TaH-plus-1.7B \\ + --dataset_name gsm8k --backend tah \\ + --job_nums 8 --tp_size_per_job 1 +""" +from __future__ import annotations + import argparse -from tah.evaluate.eval_unified import allocate_gpus_and_run_jobs -def main(args): - """Main coordinator function""" - from transformers.utils import logging as hf_logging +from transformers.utils import logging as hf_logging + +from tah.evaluate import allocate_gpus_and_run_jobs + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="TaH multi-backend evaluation driver") + p.add_argument("--eval_config", required=True, help="Path to YAML eval recipe") + p.add_argument("--model_path", required=True, help="Path or HF id of the model to eval") + p.add_argument("--backend", choices=("sglang", "hf", "tah"), default="hf", + help="Inference backend (default: hf)") + p.add_argument("--dataset_name", required=True, + help='Dataset name(s); comma-separated for multi (e.g. "aime24,math500"). ' + "Must appear in eval_configs/dataset_configs.json.") + p.add_argument("--job_nums", type=int, default=1, help="Number of parallel jobs to fan out") + p.add_argument("--tp_size_per_job", type=int, default=1, help="GPUs (tensor-parallel size) per job") + p.add_argument("--output_dir", default=None, + help="Output directory; defaults to /eval_results") + p.add_argument("--data_range", type=int, nargs="+", default=None, + help="Subset slice — [end] or [start, end]") + p.add_argument("--data_ids", default=None, + help='Comma-separated specific problem IDs (e.g. "gsm8k_0,gsm8k_5"); overrides --data_range') + p.add_argument("--del_job_dir", type=bool, default=True, + help="Delete per-job directories after combining results") + p.add_argument("--logger_level", default="WARNING", + help="Logger level (DEBUG / INFO / WARNING / ERROR / CRITICAL)") + p.add_argument("--random_seed", type=int, default=42, help="Per-job random seed") + return p + + +def main(args: argparse.Namespace) -> None: level = getattr(hf_logging, args.logger_level.upper(), hf_logging.WARNING) hf_logging.set_verbosity(level) hf_logging.enable_default_handler() hf_logging.enable_propagation() - allocate_gpus_and_run_jobs(args) + if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run LLM inference and evaluation with multiple backends") - parser.add_argument("--eval_config", type=str, required=True, help="Path to YAML configuration file") - parser.add_argument("--backend", type=str, choices=['sglang', 'hf', 'tah'], default='hf', - help="Inference backend to use: 'sglang', 'hf', or 'tah' (default: hf)") - - # Add job-based processing arguments - parser.add_argument('--job_nums', type=int, default=1, - help='Total number of jobs to split the dataset into') - parser.add_argument('--tp_size_per_job', type=int, default=1, - help='Number of GPUs (tensor parallel size) per job') - parser.add_argument('--model_path', type=str, default=None, - help='Path to the model') - parser.add_argument('--output_dir', type=str, default=None, - help='Path to the output directory, default is model_path') - parser.add_argument('--dataset_name', type=str, default=None, - help='Name of the dataset to use (supports multiple datasets separated by commas, e.g., "aime24,math500")') - parser.add_argument('--data_range', type=int, nargs='+', default=None, - help='Data range: either [end] or [start, end]') - parser.add_argument('--data_ids', type=str, default=None, - help='Comma-separated indices to evaluate, e.g., "0,5,6,15". If provided, overrides --data_range') - parser.add_argument('--del_job_dir', type=bool, default=True, - help='Delete job directory after evaluation') - parser.add_argument('--logger_level', type=str, default='WARNING', - help='Logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL)') - parser.add_argument('--random_seed', type=int, default=42, - help='Random seed for evaluation') - args = parser.parse_args() - - main(args) + main(_build_parser().parse_args()) diff --git a/script/playground/inference_example.py b/script/playground/inference_example.py index 6932a77..3d2bbc0 100644 --- a/script/playground/inference_example.py +++ b/script/playground/inference_example.py @@ -1,141 +1,76 @@ +"""End-to-end demo: load TaH-plus-1.7B from HF Hub and run sampling generation +with per-token iter-count colouring. + +Run: + python script/playground/inference_example.py +""" import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -import os +from transformers import AutoTokenizer -from tah.model.recurrent_transformer import TaHForCausalLM -from tah.model.tracker import TaHTracker +from tah.model.tah_model import TaHForCausalLM from tah.model.utils import IterCountColors, TaHForCasualLM_generate -# Fix random seed for reproducibility -torch.manual_seed(42) def main(): - """ - Initializations - """ - - save_model_name = "nics-efc/TaH-plus-1.7B" + model_name = "nics-efc/TaH-plus-1.7B" device_map = "cuda:0" - tokenizer = AutoTokenizer.from_pretrained(save_model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - override_config = None - tah_model = TaHForCausalLM.from_pretrained( - save_model_name, + model_name, torch_dtype=torch.bfloat16, device_map=device_map, attn_implementation="sdpa", - tah_config=override_config, ) + print(f"Device: {tah_model.device}, Dtype: {tah_model.dtype}") - device = tah_model.device - dtype = tah_model.dtype - print(f"Device: {device}, Dtype: {dtype}") - - tah_model = tah_model.to(dtype=dtype) - - # Attach tracker - tracker = TaHTracker(top_k=10) - tracker.attach(tah_model) - - """ - Input and run - """ - - # prepare the model input prompts = [ - "Six points $A, B, C, D, E$ and $F$ lie in a straight line in that order. Suppose that $G$ is a point not on the line and that $AC = 26$, $BD = 22$, $CE = 31$, $DF = 33$, $AF = 73$, $CG = 40$, and $DG = 30$. Find the area of $\\triangle BGE$." + "Six points $A, B, C, D, E$ and $F$ lie in a straight line in that order. " + "Suppose that $G$ is a point not on the line and that $AC = 26$, $BD = 22$, " + "$CE = 31$, $DF = 33$, $AF = 73$, $CG = 40$, and $DG = 30$. " + "Find the area of $\\triangle BGE$.", ] - - # Process each prompt through chat template - texts = [] - for prompt in prompts: - messages = [{"role": "user", "content": prompt}] - text = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - enable_thinking=True, # Switches between thinking and non-thinking modes. Default is True. + texts = [ + tokenizer.apply_chat_template( + [{"role": "user", "content": p}], + tokenize=False, add_generation_prompt=True, enable_thinking=True, ) - texts.append(text) - + for p in prompts + ] model_inputs = tokenizer( - texts, return_tensors="pt", padding=True, padding_side="left" - ).to(device=device) - batch_size = model_inputs.input_ids.shape[0] - - print("Initial input:") - for i in range(batch_size): - print(f"\nSample {i+1}:") - print(f"Prompt: {prompts[i][:100]}{'...' if len(prompts[i]) > 100 else ''}") - print(f"Input IDs shape: {model_inputs.input_ids.shape}") + texts, return_tensors="pt", padding=True, padding_side="left", + ).to(device=tah_model.device) + print("\nInitial input:") + for i, p in enumerate(prompts): + print(f"Sample {i+1}: {p[:100]}{'…' if len(p) > 100 else ''}") + print(f"Input IDs shape: {tuple(model_inputs.input_ids.shape)}") print(IterCountColors.get_legend()) - # Use the generation function with sampling output_tokens, final_texts = TaHForCasualLM_generate( tah_model=tah_model, tokenizer=tokenizer, - model_inputs=model_inputs, - iter_count=None, # Use automatic iteration from iter_decider + model_inputs=dict(model_inputs), max_new_tokens=16384, do_sample=True, temperature=0.6, top_p=0.95, top_k=20, min_p=0.0, - verbose=True + verbose=True, ) - # analyze the token count for batch - batch_size = model_inputs.input_ids.shape[0] - max_input_length = model_inputs.input_ids.shape[1] - print("\n" + "=" * 50) - print("TOKEN COUNT ANALYSIS") - print("=" * 50) - print(f"Batch size: {batch_size}") - print(f"Max input length (with padding): {max_input_length}") - - # Calculate actual input lengths (excluding padding) - actual_input_lengths = [] - for i in range(batch_size): - actual_length = ( - (model_inputs.input_ids[i] != tokenizer.pad_token_id).sum().item() - ) - actual_input_lengths.append(actual_length) - print(f"Sample {i+1} actual input length: {actual_length}") - - # For generated tokens, we now get a list of lists for each batch item - generated_counts = [len(seq) for seq in output_tokens] - total_generated = sum(generated_counts) - print(f"Generated tokens per sample: {generated_counts}") - print(f"Total generated tokens: {total_generated}") - print("=" * 50) - - # Print final generated texts for each sample in the batch - print("\nFINAL GENERATED TEXTS:") + print("FINAL GENERATED TEXTS") print("=" * 50) for i, text in enumerate(final_texts): - print(f"\nSample {i+1} output:") + print(f"\nSample {i+1} ({len(output_tokens[i])} tokens):") print("-" * 30) print(text) - # Display recorded information from tracker - record_pd = tracker.to_pandas() - print(record_pd) - - # Save tracker records to CSV - if not os.path.exists("output/analysis"): - os.makedirs("output/analysis") - record_pd.to_csv("output/analysis/tracker_records.csv", index=False) - - tracker.detach() - - print(f"\nBatch inference completed successfully for {batch_size} samples!") if __name__ == "__main__": main() diff --git a/script/preparation/label.py b/script/preparation/label.py old mode 100755 new mode 100644 index 54a7909..72a9ac7 --- a/script/preparation/label.py +++ b/script/preparation/label.py @@ -1,1015 +1,533 @@ -""" -Step1 use SLM to prefill the LLM responses, finding all non-identical SLM next-token predictions. -Multi-GPU version to read JSONL or JSON format with conversations structure. +"""Step-0 of the SFT pipeline: produce per-token "mismatch" labels. -Inputs: -- A JSONL file (.jsonl extension) with conversations format (one JSON object per line). - - Each line contains: {"conversations": [{"from": "human", "value": "..."}, {"from": "assistant", "value": "..."}], "system": "..."} -- Or a JSON file (.json extension) with conversations format (JSON array). - - Contains an array of objects: [{"conversations": [{"from": "human", "value": "..."}, {"from": "assistant", "value": "..."}], "system": "..."}, ...] +For every assistant token in a labelled-conversation dataset, we use a small +language model (the "test model") to predict the next token from the prefix +and compare it to the ground-truth next token. Tokens where the prediction +disagrees are marked ``mismatch=1`` — those are the "hard" tokens the TaH +adapter learns to refine via extra iterations. -Outputs: -- Processed dataset with data grouped by data_id, containing real_text, real_token, mask, and mismatch information +Inputs: + --dataset_path .jsonl|.json + Each record: ``{"conversations": [{"from": "human"|"assistant", "value": ...}], "system": ...}`` + Or, alternatively, ``{"problem"|"question": ..., "output"|"solution"|"answer": ...}``. + +Outputs (under ``--output_path``): + - HuggingFace ``Dataset`` with columns: ``data_id``, ``real_text``, + ``real_token``, ``mask`` (1 for assistant tokens), ``mismatch`` + (1 where SLM next-token != GT next-token), and optional ``entropy`` / + ``cross_entropy`` columns. + - ``args.json`` capturing the CLI flags this run was invoked with. + +Public TaH supported MobileLLM and DeepSeek-R1 chat templates in addition +to Qwen3. The cleaned version only supports Qwen3 — the canonical SFT +recipes use ``Qwen/Qwen3-1.7B`` as the labeller. Trim the others if you +need them by reintroducing the per-model think-token IDs in +:func:`_categorize_masks` and the per-model template in :func:`_format_prompt`. """ +from __future__ import annotations +import argparse +import gc import json +import multiprocessing as mp import os -import argparse +import shutil import signal import sys -from tqdm import tqdm -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +from typing import List, Optional, Tuple + +import numpy as np +import pandas as pd import torch import torch.nn.functional as F +from datasets import Dataset, concatenate_datasets from torch.nn.utils.rnn import pad_sequence -import pandas as pd -import numpy as np -import multiprocessing as mp -from datasets import Dataset, concatenate_datasets, DatasetDict +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer -from tah.utils.sampling import sample_token +from tah.model.utils import sample_next_token -# Global variable to track running processes -running_processes = [] -SYSTEM_PROMPT = """ -You are a helpful assistant. To answer the user's question, you first think about the reasoning process and then provide the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . -""" -def signal_handler(signum, frame): - """Handle interrupt signals gracefully""" - print(f"\nReceived signal {signum}, cleaning up processes...") - global running_processes - - for p in running_processes: +# ──────────────────────────────────────────────────────────────────────────── +# Process management (graceful Ctrl-C, multi-process orchestration) +# ──────────────────────────────────────────────────────────────────────────── + + +_running_processes: List[mp.Process] = [] +QWEN3_THINK_TOKEN_ID = 151667 # token in Qwen3 tokenizer + + +def _signal_handler(signum, frame) -> None: + """SIGINT/SIGTERM: terminate child workers; force-kill if they hang.""" + print(f"\nReceived signal {signum}, cleaning up processes…") + for p in _running_processes: + if not p.is_alive(): + continue + print(f"Terminating process {p.pid}…") + p.terminate() + p.join(timeout=60) if p.is_alive(): - print(f"Terminating process {p.pid}...") - p.terminate() - p.join(timeout=60) # Give 60 seconds for graceful termination (increased from 5 seconds) - if p.is_alive(): - print(f"Force killing process {p.pid}...") - p.kill() - p.join() - - print("Cleanup completed, exiting...") + print(f"Force killing process {p.pid}…") + p.kill() + p.join() sys.exit(0) -def load_model(model_name, device_id): - """Load a model on specific GPU with basic error handling""" - try: - model_config = AutoConfig.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained( - model_name, - config=model_config, - device_map=f"cuda:{device_id}", - torch_dtype=torch.bfloat16 - ).eval() - print(f"Model {model_name} loaded successfully on GPU {device_id}!") - return model - except Exception as e: - print(f"Error loading model on GPU {device_id}: {e}") - return None +# ──────────────────────────────────────────────────────────────────────────── +# Dataset I/O +# ──────────────────────────────────────────────────────────────────────────── -def load_jsonl_json_dataset(file_path, index_range=None, random_num=None): - """Load dataset from JSONL or JSON file based on file extension""" - data = [] - - # Determine file format based on extension - file_extension = os.path.splitext(file_path)[1].lower() - - if file_extension == '.jsonl': - # Load JSONL format (one JSON object per line) - with open(file_path, 'r', encoding='utf-8') as f: - for line in f: - if line.strip(): - data.append(json.loads(line.strip())) - elif file_extension == '.json': - # Load JSON format (single JSON array) - with open(file_path, 'r', encoding='utf-8') as f: - loaded_data = json.load(f) - # If it's a list, use it directly; if it's a dict, wrap it in a list - if isinstance(loaded_data, list): - data = loaded_data - else: - data = [loaded_data] + +def _load_dataset(path: str, index_range: Optional[Tuple[int, int]], random_num: Optional[int]) -> List[dict]: + """Load a ``.jsonl`` or ``.json`` file into a list of dicts. + + ``index_range`` slices ``[start, end)``; ``random_num`` then randomly + subsamples (seeded) when smaller than the slice. + """ + ext = os.path.splitext(path)[1].lower() + if ext == ".jsonl": + with open(path, "r", encoding="utf-8") as f: + data = [json.loads(line) for line in f if line.strip()] + elif ext == ".json": + with open(path, "r", encoding="utf-8") as f: + loaded = json.load(f) + data = loaded if isinstance(loaded, list) else [loaded] else: - # Default to JSONL format for unknown extensions - print(f"Warning: Unknown file extension '{file_extension}'. Trying to read as JSONL format.") - with open(file_path, 'r', encoding='utf-8') as f: - for line in f: - if line.strip(): - data.append(json.loads(line.strip())) - - print(f"Loaded {len(data)} samples from {file_extension if file_extension else 'unknown'} format file: {file_path}") - + print(f"Warning: unknown extension {ext!r}; trying JSONL") + with open(path, "r", encoding="utf-8") as f: + data = [json.loads(line) for line in f if line.strip()] + + print(f"Loaded {len(data)} samples from {path}") if index_range: - start_idx, end_idx = index_range - data = data[start_idx:end_idx] - print(f"Selected range [{start_idx}:{end_idx}], resulting in {len(data)} samples") - - # Apply random sampling if random_num is specified - if random_num is not None and random_num > 0 and random_num < len(data): + s, e = index_range + data = data[s:e] + print(f"Sliced [{s}:{e}] → {len(data)} samples") + if random_num and 0 < random_num < len(data): import random - random.seed(42) # Set seed for reproducibility + random.seed(42) data = random.sample(data, random_num) - print(f"Randomly sampled {random_num} samples from {len(data)} available samples") - elif random_num is not None and random_num >= len(data): - print(f"random_num ({random_num}) is >= dataset size ({len(data)}), using all samples") - + print(f"Random-sampled {random_num} of available") return data -def split_dataset(dataset, num_splits): - """Split dataset into num_splits parts""" - chunk_size = len(dataset) // num_splits - remainder = len(dataset) % num_splits - - splits = [] - start_idx = 0 - - for i in range(num_splits): - # Add one extra item to first 'remainder' splits - current_chunk_size = chunk_size + (1 if i < remainder else 0) - end_idx = start_idx + current_chunk_size - - splits.append((start_idx, end_idx)) - start_idx = end_idx - - return splits - - -def parse_conversations(conversations): - """Parse conversations to extract input_text, model_reasoning, and model_response""" - input_text = None - assistant_response = None - - for conv in conversations: - if conv["from"] == "human" or conv["from"] == "user": - input_text = conv["value"] - elif conv["from"] == "assistant": - assistant_response = conv["value"] - - if not input_text or not assistant_response: +def _split_indices(n: int, k: int) -> List[Tuple[int, int]]: + """Split ``[0, n)`` into ``k`` near-equal contiguous ranges.""" + base, rem = divmod(n, k) + out, s = [], 0 + for i in range(k): + size = base + (1 if i < rem else 0) + out.append((s, s + size)) + s += size + return out + + +# ──────────────────────────────────────────────────────────────────────────── +# Prompt formatting (Qwen3-only) +# ──────────────────────────────────────────────────────────────────────────── + + +def _parse_conversations(conversations: List[dict]) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """Pull (user, reasoning, response) out of a "conversations" list. + + A response wrapped in ``...`` is split into reasoning and + final response; otherwise ``(reasoning, response) = (None, None)`` and + the caller should skip the sample. + """ + user, asst = None, None + for c in conversations: + if c["from"] in ("human", "user"): + user = c["value"] + elif c["from"] == "assistant": + asst = c["value"] + if not user or not asst: return None, None, None - - # Split assistant response into reasoning and response parts - if "" in assistant_response and "" in assistant_response: - # Extract thinking content - think_start = assistant_response.find("") - think_end = assistant_response.find("") + len("") - - model_reasoning = assistant_response[think_start + len(""):assistant_response.find("")].strip() - model_response = assistant_response[think_end:].strip() - else: - # No thinking tags, treat entire response as final response - model_reasoning = None - model_response = None - - return input_text, model_reasoning, model_response - - -def apply_qwen_r1_chat_template(messages, add_generation_prompt=False): - """Apply the Qwen R1 chat template to the messages""" - prompt = "<|begin▁of▁sentence|>" - ns = { - "is_first": False, - "is_tool": False, - "is_output_first": True, - "system_prompt": "", - } - - # extract system prompt - for message in messages: - if message["role"] == "system": - ns["system_prompt"] = message["content"] - - prompt += ns["system_prompt"] - - for message in messages: - if message["role"] == "user": - ns["is_tool"] = False - prompt += "<|User|>" + message["content"] - - elif message["role"] == "assistant" and message["content"] is not None: - content = message["content"] - prompt += "<|Assistant|>" + content + "<|end▁of▁sentence|>" - - if add_generation_prompt: - prompt += "<|Assistant|>\n" - - return prompt - -def replace_mobilellm_think(messages): - """Replace the think tag with the think tag in the messages""" - for message in messages: - if message["role"] == "assistant" and message["content"] is not None: - message["content"] = message["content"].replace("", "<|think|>").replace("", "<|/think|>") - return messages - -def get_formatted_prompt_1(sample, tokenizer, model_name): - """Format prompt from conversations structure""" - question = sample.get("problem", "") or sample.get("question", "") - answer = sample.get("output", "") or sample.get("solution", "") or sample.get("generation", "") or sample.get("answer", "") - - messages = [ - {"role": "user", "content": question}, - {"role": "assistant", "content": answer}, - ] - - if "qwen3" in model_name.lower(): - prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False, enable_thinking=False) - elif "mobilellm" in model_name.lower(): - messages = replace_mobilellm_think(messages) - prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False, use_system_prompt=False) - else: - prompt = apply_qwen_r1_chat_template(messages, add_generation_prompt=False) - - return prompt - -def get_formatted_prompt(sample, tokenizer, model_name): - """Format prompt from conversations structure""" - conversations = sample.get("conversations", []) - system_prompt = sample.get("system", "") - - # Parse conversations - input_text, model_reasoning, model_response = parse_conversations(conversations) - - if not input_text or model_response is None: - print(f"Invalid conversation format, skipping") + if "" in asst and "" in asst: + s = asst.find("") + len("") + e = asst.find("") + return user, asst[s:e].strip(), asst[e + len(""):].strip() + return None, None, None + + +def _format_prompt(sample: dict, tokenizer) -> Optional[str]: + """Build the chat-templated prompt that we'll prefill on the SLM. + + Two record formats are accepted: + * ``{"conversations": [...], "system": ...}`` (tracker-style) + * ``{"problem"|"question": ..., "output"|"solution"|"answer"|...: ...}`` + """ + if "conversations" in sample: + user, reasoning, response = _parse_conversations(sample["conversations"]) + if not user or response is None: + return None + content = f"{reasoning}\n\n\n{response}" if reasoning else response + msgs = [] + if sample.get("system"): + msgs.append({"role": "system", "content": sample["system"]}) + msgs += [{"role": "user", "content": user}, {"role": "assistant", "content": content}] + return tokenizer.apply_chat_template( + msgs, tokenize=False, add_generation_prompt=False, enable_thinking=True, + ) + + question = sample.get("problem") or sample.get("question") or "" + answer = ( + sample.get("output") or sample.get("solution") + or sample.get("generation") or sample.get("answer") or "" + ) + if not question or not answer: return None + return tokenizer.apply_chat_template( + [{"role": "user", "content": question}, {"role": "assistant", "content": answer}], + tokenize=False, add_generation_prompt=False, enable_thinking=False, + ) - # Build messages - messages = [ - {"role": "user", "content": input_text}, - {"role": "assistant", "content": None}, - ] - - # Add system prompt if present - if system_prompt: - messages.insert(0, {"role": "system", "content": system_prompt}) - - # Format assistant response based on model type - if "qwen3" in model_name.lower(): - if model_reasoning: - messages[-1]["content"] = f"{model_reasoning}\n\n\n{model_response}" - else: - messages[-1]["content"] = model_response - prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False, enable_thinking=True) - else: - if model_reasoning: - messages[-1]["content"] = f"\n{model_reasoning}\n\n\n{model_response}" - else: - messages[-1]["content"] = model_response - prompt = apply_qwen_r1_chat_template(messages, add_generation_prompt=False) - - return prompt - - -def categorize_masks(input_ids, tokenizer, model_name): - """Categorize tokens into mask: system and query are 0, others are 1""" - - masks = [] - current_mask = 0 # Default to 0 for system and query - id_qwen3_think = 151667 - id_qwen3_assistant = 77091 - id_dpsk_think = 151648 - id_mobile_llm = 128002 - if "qwen3" in model_name.lower(): - id_think = id_qwen3_think - elif "mobilellm" in model_name.lower(): - id_think = id_mobile_llm - else: - id_think = id_dpsk_think - - for i, token_id in enumerate(input_ids[0]): - token_id = token_id.item() - - if token_id == id_think: - # count_for_think += 1: # Switch to 1 only on third occurrence - current_mask = 1 - - masks.append(current_mask) - + +def _categorize_masks(input_ids: torch.Tensor) -> List[int]: + """1 for tokens at and after the assistant's token; 0 before. + + The mismatch labels are only meaningful for assistant content (mask=1). + """ + masks: List[int] = [] + current = 0 + for tok in input_ids[0].tolist(): + if tok == QWEN3_THINK_TOKEN_ID: + current = 1 + masks.append(current) return masks -def calculate_mismatch(predictions, real_tokens, data_ids): - """Calculate mismatch between predictions[k] and real_tokens[k+1] for each sample""" +# ──────────────────────────────────────────────────────────────────────────── +# Mismatch scoring +# ──────────────────────────────────────────────────────────────────────────── + + +def _calculate_mismatch(predictions: torch.Tensor, real_tokens: torch.Tensor, data_ids: torch.Tensor) -> torch.Tensor: + """``mismatch[i] = 1`` iff ``predictions[i] != real_tokens[i+1]`` and + position ``i`` is not the last token of its sample. + + Sample boundaries are derived from where ``data_ids`` changes. The very + last token of each sample has no "next token" to predict, so we leave its + mismatch at 0 by masking it out before the comparison. + """ device = predictions.device - - # create mismatch tensor with the same size as input, initialized to 0 mismatch = torch.zeros_like(predictions, dtype=torch.int32, device=device) - - # find the end position of each sample (the position where the data_id changes) - # to handle boundary cases, add a different value to the end of data_ids - padded_data_ids = torch.cat([data_ids, torch.tensor([data_ids[-1] + 1], device=device)]) - - # find the position where data_id changes - change_mask = padded_data_ids[1:] != padded_data_ids[:-1] - sample_end_indices = torch.where(change_mask)[0] - - # create mask, mark all positions except the last position of each sample - valid_mask = torch.ones(len(predictions), dtype=torch.bool, device=device) - valid_mask[sample_end_indices] = False - - # for valid positions, compare predictions[k] and real_tokens[k+1] - # only compare non-last positions - if valid_mask.any(): - valid_indices = torch.where(valid_mask)[0] - pred_tokens = predictions[valid_indices] - next_real_tokens = real_tokens[valid_indices + 1] - - # calculate mismatch: 1 for mismatch, 0 for match - mismatch_values = (pred_tokens != next_real_tokens).int() - mismatch[valid_indices] = mismatch_values - + padded = torch.cat([data_ids, torch.tensor([data_ids[-1] + 1], device=device)]) + sample_ends = torch.where(padded[1:] != padded[:-1])[0] + valid = torch.ones(len(predictions), dtype=torch.bool, device=device) + valid[sample_ends] = False + if valid.any(): + idx = torch.where(valid)[0] + mismatch[idx] = (predictions[idx] != real_tokens[idx + 1]).int() return mismatch.cpu() -def process_single_gpu(args, device_id, data_range, model_name): - """Process dataset on a single GPU""" +# ──────────────────────────────────────────────────────────────────────────── +# Per-GPU worker +# ──────────────────────────────────────────────────────────────────────────── + + +def _process_single_gpu(args, device_id: int, data_range: Tuple[int, int], model_name: str) -> None: + """Run the SLM forward pass on one GPU's slice of the dataset.""" start_idx, end_idx = data_range - model_path = model_name.split("/")[-1] - - print(f"GPU {device_id}: Processing data range {start_idx}-{end_idx} for model {model_name}") - - # Load dataset subset - dataset = load_jsonl_json_dataset(args.dataset_path, (start_idx, end_idx), args.random_num) - - # Load tokenizer + print(f"GPU {device_id}: range {start_idx}-{end_idx} of {model_name}") + + dataset = _load_dataset(args.dataset_path, (start_idx, end_idx), args.random_num) tokenizer = AutoTokenizer.from_pretrained(model_name) - # Ensure pad token for batching if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" - - # Load model on specific GPU - model = load_model(model_name, device_id) - if model is None: - return None - - # Store results - predictions_list = [] - real_tokens_list = [] - token_ids_list = [] - data_ids_list = [] - masks_list = [] - entropy_list = [] if args.save_entropy else None - ce_list = [] if args.save_ce else None - - # Process each sample - pbar = tqdm(total=len(dataset), desc=f"GPU {device_id} - {model_path}", position=device_id) + + try: + model = AutoModelForCausalLM.from_pretrained( + model_name, device_map=f"cuda:{device_id}", torch_dtype=torch.bfloat16, + ).eval() + except Exception as e: + print(f"GPU {device_id}: model load failed: {e}") + return + + # Per-token outputs aggregated across the GPU's samples. + preds, reals, tok_ids, data_ids, masks = [], [], [], [], [] + entropies: Optional[List[torch.Tensor]] = [] if args.save_entropy else None + ces: Optional[List[torch.Tensor]] = [] if args.save_ce else None + + bs = max(1, int(args.batch_size)) + pbar = tqdm(total=len(dataset), desc=f"GPU {device_id}", position=device_id) with torch.no_grad(): - bs = max(1, int(getattr(args, "batch_size", 1))) - num_samples = len(dataset) - processed = 0 - for batch_start in range(0, num_samples, bs): - batch_end = min(batch_start + bs, num_samples) - batch_items = [] # tuples: (ids_1d_cpu, length, global_id) - prompts_meta = [] # store tensors length and global id for later - for local_offset, sample in enumerate(dataset[batch_start:batch_end]): - global_data_id = start_idx + (batch_start + local_offset) - # Build prompt - if sample.get("conversations") is not None: - prompt = get_formatted_prompt(sample, tokenizer, model_name) - else: - prompt = get_formatted_prompt_1(sample, tokenizer, model_name) + for batch_start in range(0, len(dataset), bs): + batch_end = min(batch_start + bs, len(dataset)) + batch_ids: List[torch.Tensor] = [] + batch_meta: List[Tuple[int, int]] = [] # (length, global_data_id) + + for offset, sample in enumerate(dataset[batch_start:batch_end]): + global_id = start_idx + batch_start + offset + prompt = _format_prompt(sample, tokenizer) if prompt is None: continue ids = tokenizer(prompt, return_tensors="pt").input_ids[0] if ids.shape[-1] > args.max_input_length: if args.is_cutoff: - ids = ids[:args.max_input_length] + ids = ids[: args.max_input_length] else: continue - batch_items.append(ids) - prompts_meta.append((ids.shape[-1], global_data_id)) - if not batch_items: - # even if empty due to skips, advance pbar by original batch window size + batch_ids.append(ids) + batch_meta.append((ids.shape[-1], global_id)) + + if not batch_ids: pbar.update(batch_end - batch_start) continue - lengths = [x.shape[0] for x in batch_items] - padded = pad_sequence(batch_items, batch_first=True, padding_value=tokenizer.pad_token_id) - attention_mask = torch.arange(padded.shape[1]).unsqueeze(0) < torch.tensor(lengths).unsqueeze(1) + + lengths = [t.shape[0] for t in batch_ids] + padded = pad_sequence(batch_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + attn = torch.arange(padded.shape[1]).unsqueeze(0) < torch.tensor(lengths).unsqueeze(1) input_ids = padded.to(model.device) - attention_mask = attention_mask.to(model.device) - # Forward - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): - outputs = model(input_ids=input_ids, attention_mask=attention_mask) - logits = outputs.logits.to(torch.float32) # [B, Lmax, V] - # Per-sample handling - for b_idx, (seq_len, global_data_id) in enumerate(prompts_meta): + attention_mask = attn.to(model.device) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = model(input_ids=input_ids, attention_mask=attention_mask).logits.to(torch.float32) + + for b_idx, (seq_len, global_id) in enumerate(batch_meta): seq_logits = logits[b_idx, :seq_len, :] - if args.save_entropy or args.save_ce: - lp = F.log_softmax(seq_logits, dim=-1) + lp = F.log_softmax(seq_logits, dim=-1) if (args.save_entropy or args.save_ce) else None if args.save_entropy: - probs = lp.exp() - entropy = -(probs * lp).sum(dim=-1).cpu() # [seq_len] - entropy_list.append(entropy) - pred = sample_token(seq_logits, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k).cpu() - token_id = torch.arange(0, seq_len, 1).cpu() - data_id_tensor = torch.full((seq_len,), global_data_id, dtype=token_id.dtype).cpu() - real_token = input_ids[b_idx, :seq_len].detach().cpu() - masks = categorize_masks(real_token.unsqueeze(0), tokenizer, model_name) - masks_tensor = torch.tensor(masks, dtype=torch.int32).cpu() + entropies.append((-lp.exp() * lp).sum(dim=-1).cpu()) + pred = sample_next_token( + seq_logits, temperature=args.temperature, top_p=args.top_p, + top_k=max(args.top_k, 0), do_sample=args.temperature > 0, + ).cpu() + + token_id = torch.arange(seq_len, dtype=torch.long) + data_id = torch.full((seq_len,), global_id, dtype=torch.long) + real = input_ids[b_idx, :seq_len].detach().cpu() + mask_ = torch.tensor(_categorize_masks(real.unsqueeze(0)), dtype=torch.int32) + if args.save_ce: - device = seq_logits.device - data_id_dev = data_id_tensor.to(device) - padded_ids = torch.cat([data_id_dev, torch.tensor([data_id_dev[-1] + 1], device=device)]) - change_mask = padded_ids[1:] != padded_ids[:-1] - sample_end_indices = torch.where(change_mask)[0] - valid_mask = torch.ones(seq_len, dtype=torch.bool, device=device) - valid_mask[sample_end_indices] = False - ce = torch.zeros(seq_len, dtype=torch.float32, device=device) - if valid_mask.any(): - valid_indices = torch.where(valid_mask)[0] - targets = input_ids[b_idx, :seq_len].to(device)[valid_indices + 1] - ce_values = -lp[valid_indices, targets] - ce[valid_indices] = ce_values - ce_list.append(ce.cpu()) - predictions_list.append(pred) - real_tokens_list.append(real_token) - token_ids_list.append(token_id) - data_ids_list.append(data_id_tensor) - masks_list.append(masks_tensor) - processed += (batch_end - batch_start) + # Per-token CE w.r.t. the next ground-truth token, masked + # out at sample boundaries (last position has no target). + valid = torch.ones(seq_len, dtype=torch.bool, device=lp.device) + valid[-1] = False + targets = real.to(lp.device) + ce = torch.zeros(seq_len, dtype=torch.float32, device=lp.device) + if valid.any(): + idx = torch.where(valid)[0] + ce[idx] = -lp[idx, targets[idx + 1]] + ces.append(ce.cpu()) + + preds.append(pred) + reals.append(real) + tok_ids.append(token_id) + data_ids.append(data_id) + masks.append(mask_) + pbar.update(batch_end - batch_start) if (batch_start // bs) % 10 == 0: torch.cuda.empty_cache() pbar.close() - if not predictions_list: - print(f"GPU {device_id}: No valid samples processed") - return None - - # Concatenate results - predictions = torch.cat(predictions_list, dim=0) - real_tokens = torch.cat(real_tokens_list, dim=0) - token_ids = torch.cat(token_ids_list, dim=0) - data_ids = torch.cat(data_ids_list, dim=0) - masks = torch.cat(masks_list, dim=0) - # Optional tensors - if args.save_entropy: - entropies = torch.cat(entropy_list, dim=0) - - # Calculate mismatch - print(f"GPU {device_id}: Calculating mismatch...") - mismatch = calculate_mismatch(predictions, real_tokens, data_ids) - - # Convert tensors to python lists for Dataset compatibility - results_dict = { - "predictions": predictions.tolist(), - "small_token": token_ids.tolist(), - "data_id": data_ids.tolist(), - "mask": masks.tolist(), - "real_token": real_tokens.tolist(), + if not preds: + print(f"GPU {device_id}: no valid samples") + return + + cat_preds = torch.cat(preds) + cat_reals = torch.cat(reals) + cat_data_ids = torch.cat(data_ids) + print(f"GPU {device_id}: computing mismatch…") + mismatch = _calculate_mismatch(cat_preds, cat_reals, cat_data_ids) + + out: dict = { + "predictions": cat_preds.tolist(), + "small_token": torch.cat(tok_ids).tolist(), + "data_id": cat_data_ids.tolist(), + "mask": torch.cat(masks).tolist(), + "real_token": cat_reals.tolist(), "mismatch": mismatch.tolist(), } if args.save_entropy: - results_dict["entropy"] = entropies.tolist() + out["entropy"] = torch.cat(entropies).tolist() if args.save_ce: - ce_tensor = torch.cat(ce_list, dim=0) - results_dict["cross_entropy"] = ce_tensor.tolist() - - # Create Dataset from dict - dataset = Dataset.from_dict(results_dict) - - # Save as Dataset - output_file = os.path.join(args.output_path, f"results_gpu_{device_id}_{model_path}") - dataset.save_to_disk(output_file) - - # Clear variables - del model - del tokenizer - if 'predictions_list' in locals(): del predictions_list - if 'real_tokens_list' in locals(): del real_tokens_list - if 'token_ids_list' in locals(): del token_ids_list - if 'data_ids_list' in locals(): del data_ids_list - if 'masks_list' in locals(): del masks_list - if 'entropy_list' in locals(): del entropy_list - if 'ce_list' in locals(): del ce_list - - # Clear GPU cache and synchronize + out["cross_entropy"] = torch.cat(ces).tolist() + + out_path = os.path.join(args.output_path, f"results_gpu_{device_id}_{model_name.split('/')[-1]}") + Dataset.from_dict(out).save_to_disk(out_path) + print(f"GPU {device_id}: saved → {out_path}") + + del model, tokenizer if torch.cuda.is_available(): torch.cuda.synchronize(device=device_id) torch.cuda.empty_cache() torch.cuda.ipc_collect() - - # Force garbage collection - import gc gc.collect() - - print(f"GPU {device_id}: Saved dataset to {output_file}") - return dataset -def analyze_detailed_statistics(df, tokenizer): - """Perform detailed statistical analysis on the dataset""" - analysis_results = {} - - # Basic statistics - total_tokens = len(df) - total_samples = df['data_id'].nunique() - total_mismatch_tokens = sum(df['mismatch']) - - analysis_results['basic'] = { - 'total_tokens': int(total_tokens), - 'total_samples': int(total_samples), - 'total_mismatch_tokens': int(total_mismatch_tokens), - 'mismatch_ratio': float((total_mismatch_tokens / total_tokens * 100)) if total_tokens > 0 else 0.0 - } - - # Mask-based analysis (0=system/user, 1=assistant) - mask_0_tokens = len(df[df['mask'] == 0]) - mask_1_tokens = len(df[df['mask'] == 1]) - mask_0_mismatch = sum(df[df['mask'] == 0]['mismatch']) - mask_1_mismatch = sum(df[df['mask'] == 1]['mismatch']) - - analysis_results['mask_analysis'] = { - 'system_user_tokens': int(mask_0_tokens), - 'assistant_tokens': int(mask_1_tokens), - 'system_user_mismatch': int(mask_0_mismatch), - 'assistant_mismatch': int(mask_1_mismatch), - 'system_user_mismatch_ratio': float((mask_0_mismatch / mask_0_tokens * 100)) if mask_0_tokens > 0 else 0.0, - 'assistant_mismatch_ratio': float((mask_1_mismatch / mask_1_tokens * 100)) if mask_1_tokens > 0 else 0.0 - } - - # Per-sample analysis - sample_stats = [] - grouped = df.groupby('data_id') - token_lengths = [] - assistant_token_lengths = [] - mismatch_ratios = [] - - for data_id, group in grouped: - sample_total_tokens = len(group) - sample_assistant_tokens = len(group[group['mask'] == 1]) - sample_mismatch_tokens = sum(group['mismatch']) - sample_assistant_mismatch = sum(group[group['mask'] == 1]['mismatch']) - - token_lengths.append(sample_total_tokens) - assistant_token_lengths.append(sample_assistant_tokens) - - sample_mismatch_ratio = (sample_mismatch_tokens / sample_total_tokens * 100) if sample_total_tokens > 0 else 0 - mismatch_ratios.append(sample_mismatch_ratio) - - sample_stats.append({ - 'data_id': int(data_id), - 'total_tokens': int(sample_total_tokens), - 'assistant_tokens': int(sample_assistant_tokens), - 'mismatch_tokens': int(sample_mismatch_tokens), - 'assistant_mismatch_tokens': int(sample_assistant_mismatch), - 'mismatch_ratio': float(sample_mismatch_ratio), - 'assistant_mismatch_ratio': float((sample_assistant_mismatch / sample_assistant_tokens * 100)) if sample_assistant_tokens > 0 else 0.0 - }) - - # Token length statistics - analysis_results['length_analysis'] = { - 'avg_tokens_per_sample': float(np.mean(token_lengths)), - 'median_tokens_per_sample': float(np.median(token_lengths)), - 'min_tokens_per_sample': int(np.min(token_lengths)), - 'max_tokens_per_sample': int(np.max(token_lengths)), - 'std_tokens_per_sample': float(np.std(token_lengths)), - 'avg_assistant_tokens': float(np.mean(assistant_token_lengths)), - 'median_assistant_tokens': float(np.median(assistant_token_lengths)), - 'min_assistant_tokens': int(np.min(assistant_token_lengths)), - 'max_assistant_tokens': int(np.max(assistant_token_lengths)) - } - - # Mismatch ratio distribution - analysis_results['mismatch_distribution'] = { - 'avg_mismatch_ratio': float(np.mean(mismatch_ratios)), - 'median_mismatch_ratio': float(np.median(mismatch_ratios)), - 'min_mismatch_ratio': float(np.min(mismatch_ratios)), - 'max_mismatch_ratio': float(np.max(mismatch_ratios)), - 'std_mismatch_ratio': float(np.std(mismatch_ratios)), - 'samples_with_no_mismatch': int(sum(1 for ratio in mismatch_ratios if ratio == 0)), - 'samples_with_high_mismatch': int(sum(1 for ratio in mismatch_ratios if ratio > 50)) - } - - # Token frequency analysis for real tokens - real_token_counts = df['real_token'].value_counts() - most_common_tokens = real_token_counts.head(20).to_dict() - - # Decode most common tokens - decoded_common_tokens = {} - for token_id, count in most_common_tokens.items(): - try: - decoded_token = tokenizer.decode([int(token_id)]) - decoded_common_tokens[f"{int(token_id)} ({repr(decoded_token)})"] = int(count) - except: - decoded_common_tokens[str(int(token_id))] = int(count) - - analysis_results['token_frequency'] = { - 'most_common_tokens': decoded_common_tokens, - 'unique_tokens': int(len(real_token_counts)), - 'total_token_occurrences': int(real_token_counts.sum()) - } - - return analysis_results, sample_stats - -def process_and_convert_dataset(merged_dataset, model_name, output_path): - """Convert merged dataset to final processed format""" - print("Converting dataset to final format...") - - # Load tokenizer for text decoding +# ──────────────────────────────────────────────────────────────────────────── +# Per-GPU result merge + final dataset +# ──────────────────────────────────────────────────────────────────────────── + + +def _process_and_convert_dataset(merged: Dataset, model_name: str, output_path: str) -> Dataset: + """Group per-token results by ``data_id`` into one row per sample, decode + text, and write the SFT-ready dataset to ``output_path``. + + Also prints a brief stats summary (basic counts, mask split, mismatch + ratio). Public TaH wrote a multi-page text/JSON/CSV analysis report; + rich enough to be its own postprocess script — kept the inline summary, + dropped the report files. + """ tokenizer = AutoTokenizer.from_pretrained(model_name) - - # Convert to pandas for easier grouping - df = merged_dataset.to_pandas() - - # Perform detailed analysis on raw data - print("Performing detailed statistical analysis...") - analysis_results, sample_stats = analyze_detailed_statistics(df, tokenizer) - - # 按data_id分组 - grouped = df.groupby('data_id') - print(f"Found {len(grouped)} unique data_ids.") - - # Initialize counters for statistics - total_tokens = 0 - total_mismatch_tokens = 0 - - final_data_list = [] - print("Processing groups...") - for data_id, group in tqdm(grouped): - # Convert the real_token list to text - real_tokens = group['real_token'].tolist() - real_text = tokenizer.decode(real_tokens) - - # Get mismatch indices - mismatch_indices = group['mismatch'].tolist() - - # Update statistics - total_tokens += len(real_tokens) - total_mismatch_tokens += sum(1 for x in mismatch_indices if x == 1) - - processed_item = { - 'data_id': data_id, - 'real_text': real_text, - 'real_token': real_tokens, - 'mask': group['mask'].tolist(), - 'mismatch': mismatch_indices, + df = merged.to_pandas() + grouped = df.groupby("data_id") + print(f"Found {len(grouped)} unique data_ids") + + final = [] + for data_id, group in tqdm(grouped, desc="Grouping by data_id"): + real = group["real_token"].tolist() + item = { + "data_id": data_id, + "real_text": tokenizer.decode(real), + "real_token": real, + "mask": group["mask"].tolist(), + "mismatch": group["mismatch"].tolist(), } - if 'entropy' in group.columns: - processed_item['entropy'] = group['entropy'].tolist() - if 'cross_entropy' in group.columns: - processed_item['cross_entropy'] = group['cross_entropy'].tolist() - final_data_list.append(processed_item) - - # Print statistics - print("\n" + "="*80) - print("DETAILED STATISTICS SUMMARY") - print("="*80) - - basic = analysis_results['basic'] - print(f"Basic Statistics:") - print(f" Total samples: {basic['total_samples']:,}") - print(f" Total tokens: {basic['total_tokens']:,}") - print(f" Total mismatch tokens: {basic['total_mismatch_tokens']:,}") - print(f" Overall mismatch ratio: {basic['mismatch_ratio']:.2f}%") - - mask = analysis_results['mask_analysis'] - print(f"\nMask-based Analysis:") - print(f" System/User tokens (mask=0): {mask['system_user_tokens']:,}") - print(f" Assistant tokens (mask=1): {mask['assistant_tokens']:,}") - print(f" System/User mismatch: {mask['system_user_mismatch']:,} ({mask['system_user_mismatch_ratio']:.2f}%)") - print(f" Assistant mismatch: {mask['assistant_mismatch']:,} ({mask['assistant_mismatch_ratio']:.2f}%)") - - length = analysis_results['length_analysis'] - print(f"\nToken Length Analysis:") - print(f" Avg tokens per sample: {length['avg_tokens_per_sample']:.1f}") - print(f" Median tokens per sample: {length['median_tokens_per_sample']:.1f}") - print(f" Token range: {length['min_tokens_per_sample']:.0f} - {length['max_tokens_per_sample']:.0f}") - print(f" Avg assistant tokens: {length['avg_assistant_tokens']:.1f}") - - mismatch_dist = analysis_results['mismatch_distribution'] - print(f"\nMismatch Distribution:") - print(f" Avg mismatch ratio per sample: {mismatch_dist['avg_mismatch_ratio']:.2f}%") - print(f" Median mismatch ratio: {mismatch_dist['median_mismatch_ratio']:.2f}%") - print(f" Samples with no mismatch: {mismatch_dist['samples_with_no_mismatch']}") - print(f" Samples with >50% mismatch: {mismatch_dist['samples_with_high_mismatch']}") - - print("="*80) - - processed_dataset = Dataset.from_pandas(pd.DataFrame(final_data_list)) - - print(f"Processed dataset info:") - print(processed_dataset) - - # Save processed dataset directly to output_path - final_output_path = output_path - processed_dataset.save_to_disk(final_output_path) - print(f"Processed dataset saved to {final_output_path}") - - # Save detailed analysis to files - analysis_dir = final_output_path - - # Save detailed statistics - detailed_stats_file = os.path.join(analysis_dir, "detailed_analysis.json") - with open(detailed_stats_file, 'w', encoding='utf-8') as f: - json.dump(analysis_results, f, indent=2, ensure_ascii=False) - - # Save per-sample statistics - sample_stats_file = os.path.join(analysis_dir, "per_sample_statistics.csv") - sample_df = pd.DataFrame(sample_stats) - sample_df.to_csv(sample_stats_file, index=False) - - # Save comprehensive text report - report_file = os.path.join(analysis_dir, "analysis_report.txt") - with open(report_file, 'w', encoding='utf-8') as f: - f.write("COMPREHENSIVE DATA ANALYSIS REPORT\n") - f.write("="*50 + "\n\n") - - f.write("1. BASIC STATISTICS\n") - f.write("-"*20 + "\n") - f.write(f"Total samples: {basic['total_samples']:,}\n") - f.write(f"Total tokens: {basic['total_tokens']:,}\n") - f.write(f"Total mismatch tokens: {basic['total_mismatch_tokens']:,}\n") - f.write(f"Overall mismatch ratio: {basic['mismatch_ratio']:.4f}%\n\n") - - f.write("2. MASK-BASED ANALYSIS\n") - f.write("-"*20 + "\n") - f.write(f"System/User tokens (mask=0): {mask['system_user_tokens']:,}\n") - f.write(f"Assistant tokens (mask=1): {mask['assistant_tokens']:,}\n") - f.write(f"System/User mismatch: {mask['system_user_mismatch']:,} ({mask['system_user_mismatch_ratio']:.4f}%)\n") - f.write(f"Assistant mismatch: {mask['assistant_mismatch']:,} ({mask['assistant_mismatch_ratio']:.4f}%)\n\n") - - f.write("3. TOKEN LENGTH ANALYSIS\n") - f.write("-"*25 + "\n") - f.write(f"Average tokens per sample: {length['avg_tokens_per_sample']:.2f}\n") - f.write(f"Median tokens per sample: {length['median_tokens_per_sample']:.2f}\n") - f.write(f"Min tokens per sample: {length['min_tokens_per_sample']:.0f}\n") - f.write(f"Max tokens per sample: {length['max_tokens_per_sample']:.0f}\n") - f.write(f"Std deviation: {length['std_tokens_per_sample']:.2f}\n") - f.write(f"Average assistant tokens: {length['avg_assistant_tokens']:.2f}\n") - f.write(f"Median assistant tokens: {length['median_assistant_tokens']:.2f}\n\n") - - f.write("4. MISMATCH DISTRIBUTION\n") - f.write("-"*25 + "\n") - f.write(f"Average mismatch ratio per sample: {mismatch_dist['avg_mismatch_ratio']:.4f}%\n") - f.write(f"Median mismatch ratio: {mismatch_dist['median_mismatch_ratio']:.4f}%\n") - f.write(f"Min mismatch ratio: {mismatch_dist['min_mismatch_ratio']:.4f}%\n") - f.write(f"Max mismatch ratio: {mismatch_dist['max_mismatch_ratio']:.4f}%\n") - f.write(f"Std deviation: {mismatch_dist['std_mismatch_ratio']:.4f}%\n") - f.write(f"Samples with no mismatch: {mismatch_dist['samples_with_no_mismatch']}\n") - f.write(f"Samples with >50% mismatch: {mismatch_dist['samples_with_high_mismatch']}\n\n") - - f.write("5. TOKEN FREQUENCY ANALYSIS\n") - f.write("-"*30 + "\n") - token_freq = analysis_results['token_frequency'] - f.write(f"Unique tokens: {token_freq['unique_tokens']:,}\n") - f.write(f"Total token occurrences: {token_freq['total_token_occurrences']:,}\n") - f.write("Most common tokens:\n") - for token, count in list(token_freq['most_common_tokens'].items())[:10]: - f.write(f" {token}: {count:,}\n") - - print(f"Detailed analysis saved to:") - print(f" - JSON format: {detailed_stats_file}") - print(f" - Per-sample CSV: {sample_stats_file}") - print(f" - Text report: {report_file}") - - return processed_dataset - -def merge_gpu_results(args, model_name): - """Merge results from all GPUs and convert to final format""" - model_path = model_name.split("/")[-1] - all_datasets = [] - - # Load results from all GPUs + if "entropy" in group.columns: + item["entropy"] = group["entropy"].tolist() + if "cross_entropy" in group.columns: + item["cross_entropy"] = group["cross_entropy"].tolist() + final.append(item) + + total_tokens = len(df) + total_mismatch = int(df["mismatch"].sum()) + asst_tokens = int((df["mask"] == 1).sum()) + asst_mismatch = int(df.loc[df["mask"] == 1, "mismatch"].sum()) + print("\n=== Label statistics ===") + print(f" samples: {df['data_id'].nunique():,}") + print(f" tokens: {total_tokens:,} (mismatch {total_mismatch:,}, {100*total_mismatch/max(total_tokens,1):.2f}%)") + print(f" assistant tokens: {asst_tokens:,} (mismatch {asst_mismatch:,}, {100*asst_mismatch/max(asst_tokens,1):.2f}%)") + print("========================") + + processed = Dataset.from_pandas(pd.DataFrame(final)) + processed.save_to_disk(output_path) + print(f"Processed dataset → {output_path}") + return processed + + +def _merge_gpu_results(args, model_name: str) -> Optional[Dataset]: + """Load and concatenate per-GPU results, then collapse into the final SFT dataset.""" + short = model_name.split("/")[-1] + parts = [] for gpu_id in range(args.num_gpu): - result_dir = os.path.join(args.output_path, f"results_gpu_{gpu_id}_{model_path}") - if os.path.exists(result_dir): - dataset = Dataset.load_from_disk(result_dir) - all_datasets.append(dataset) - print(f"Loaded dataset from GPU {gpu_id}") - - if not all_datasets: + d = os.path.join(args.output_path, f"results_gpu_{gpu_id}_{short}") + if os.path.exists(d): + parts.append(Dataset.load_from_disk(d)) + print(f"Loaded GPU {gpu_id} dataset") + if not parts: print("No GPU datasets found to merge") return None - - # Concatenate all datasets - merged_dataset = concatenate_datasets(all_datasets) - - # Process and convert to final format - processed_dataset = process_and_convert_dataset(merged_dataset, model_name, args.output_path) - - # Clean up individual GPU files + merged = concatenate_datasets(parts) + final = _process_and_convert_dataset(merged, model_name, args.output_path) for gpu_id in range(args.num_gpu): - result_dir = os.path.join(args.output_path, f"results_gpu_{gpu_id}_{model_path}") - if os.path.exists(result_dir): - import shutil - shutil.rmtree(result_dir) - - return processed_dataset - - -def process_dataset_multi_gpu(args): - """Process the JSONL dataset with multiple GPUs""" - global running_processes - - # Create output directory - if not os.path.exists(args.output_path): - os.makedirs(args.output_path) - - # Load full dataset to get length and split - print(f"Loading dataset from {args.dataset_path}") - full_dataset = load_jsonl_json_dataset(args.dataset_path, args.index_range, args.random_num) - print(f"Dataset length: {len(full_dataset)}") - - # Split dataset into multiple parts based on num_gpu - data_splits = split_dataset(full_dataset, args.num_gpu) - print(f"Dataset split into {args.num_gpu} parts: {data_splits}") - - # Process each model + d = os.path.join(args.output_path, f"results_gpu_{gpu_id}_{short}") + if os.path.exists(d): + shutil.rmtree(d) + return final + + +# ──────────────────────────────────────────────────────────────────────────── +# Multi-GPU orchestration +# ──────────────────────────────────────────────────────────────────────────── + + +def _terminate_with_timeout(p: mp.Process, soft: int = 30, hard: int = 60) -> None: + """Best-effort cleanup of one worker process (terminate → kill if still alive).""" + if not p.is_alive(): + return + p.terminate() + p.join(timeout=soft) + if p.is_alive(): + print(f"Force killing process {p.pid}…") + p.kill() + p.join(timeout=hard) + + +def _process_dataset_multi_gpu(args) -> None: + """Spawn one worker per GPU per model, wait, then merge results.""" + global _running_processes + os.makedirs(args.output_path, exist_ok=True) + + full = _load_dataset(args.dataset_path, args.index_range, args.random_num) + print(f"Total samples: {len(full)}") + splits = _split_indices(len(full), args.num_gpu) + print(f"Per-GPU splits: {splits}") + for model_name in args.test_model_list: - model_path = model_name.split("/")[-1] - - # # Skip if processed results already exist - # if os.path.exists(os.path.join(args.output_path, f"processed_data_{model_path}")): - # print(f"Processed results for {model_name} already exist, skipping.") - # continue - - print(f"Processing model: {model_name}") - - # Create processes for each GPU - processes = [] - - for gpu_id in range(args.num_gpu): - p = mp.Process( - target=process_single_gpu, - args=(args, gpu_id, data_splits[gpu_id], model_name) - ) - processes.append(p) + print(f"\n=== Processing model: {model_name} ===") + processes = [ + mp.Process(target=_process_single_gpu, args=(args, gpu_id, splits[gpu_id], model_name)) + for gpu_id in range(args.num_gpu) + ] + _running_processes = processes + for p in processes: p.start() - - # Update global process list for signal handling - running_processes = processes - - # Wait for all processes to complete with timeout - timeout = 24 * 60 * 60 # 24 hours timeout per process (increased from 1 hour) + timeout = 24 * 60 * 60 for i, p in enumerate(processes): try: p.join(timeout=timeout) if p.is_alive(): - print(f"GPU {i}: Process timeout after {timeout} seconds, terminating...") - p.terminate() - p.join(timeout=30) # Give 30 seconds for graceful termination (increased from 10 seconds) - if p.is_alive(): - print(f"GPU {i}: Force killing process...") - p.kill() - p.join() + print(f"GPU {i}: process timed out, terminating…") + _terminate_with_timeout(p) elif p.exitcode != 0: - print(f"GPU {i}: Process exited with code {p.exitcode}") + print(f"GPU {i}: exited with code {p.exitcode}") else: - print(f"GPU {i}: Process completed successfully") + print(f"GPU {i}: completed") except Exception as e: - print(f"GPU {i}: Error during process join: {e}") - if p.is_alive(): - p.terminate() - p.join(timeout=30) # Give 30 seconds for graceful termination (increased from 10 seconds) - if p.is_alive(): - p.kill() - p.join() - - # Ensure all processes are cleaned up + print(f"GPU {i}: error during join: {e}") + _terminate_with_timeout(p) for p in processes: - if p.is_alive(): - p.terminate() - p.join(timeout=60) # Give 60 seconds for graceful termination (increased from 5 seconds) - if p.is_alive(): - p.kill() - p.join() - - print(f"All GPU processes completed for {model_name}") - - # Clear global process list - running_processes = [] - - # Merge results from all GPUs and convert to final format - processed_results = merge_gpu_results(args, model_name) - if processed_results is None: - print(f"Failed to process results for {model_name}") - continue - - print("Multi-GPU processing completed!") - - -def main(): - global running_processes - - # Register signal handlers for graceful cleanup - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - parser = argparse.ArgumentParser( - description="Run multi-GPU model inference on JSONL or JSON datasets" - ) - parser.add_argument( - "--dataset_path", type=str, required=True, help="Path to the dataset file (supports .jsonl and .json formats)" - ) - parser.add_argument( - "--num_gpu", type=int, default=4, help="Number of GPUs to use" - ) - parser.add_argument( - "--test_model_list", - nargs="+", - type=str, - required=True, - help="List of test models to run", - ) - parser.add_argument( - "--output_path", type=str, required=True, help="Directory to save output files" - ) - parser.add_argument( - "--max_input_length", - type=int, - default=32768, - help="Maximum length of input tokens", - ) - parser.add_argument( - "--is_cutoff", - type=bool, - default=False, - help="whether of not cut off the dataset", - ) - parser.add_argument( - "--index_range", - nargs=2, - type=int, - default=None, - help="Range of dataset samples to process [start_idx, end_idx]", - ) - parser.add_argument( - "--random_num", - type=int, - default=None, - help="Number of samples to randomly sample from the dataset", - ) - parser.add_argument( - "--top_k", - type=int, - default=-1, - help="Number of top predictions to include in the output", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.0, - help="Temperature to apply to logits", - ) - parser.add_argument( - "--top_p", - type=float, - default=1.0, - help="Top-p probability threshold for nucleus sampling", - ) - parser.add_argument( - "--batch_size", - type=int, - default=1, - help="Batch size for prefill/inference per GPU process", - ) - parser.add_argument( - "--save_entropy", - action="store_true", - help="Save per-token entropy of model logits", - ) - parser.add_argument( - "--save_ce", - action="store_true", - help="Save per-token cross-entropy w.r.t. next-token labels", - ) + _terminate_with_timeout(p) + _running_processes = [] + + if _merge_gpu_results(args, model_name) is None: + print(f"Failed to merge results for {model_name}") + + print("Multi-GPU labelling complete.") + + +# ──────────────────────────────────────────────────────────────────────────── +# CLI entry +# ──────────────────────────────────────────────────────────────────────────── + + +def main() -> None: + signal.signal(signal.SIGINT, _signal_handler) + signal.signal(signal.SIGTERM, _signal_handler) + + parser = argparse.ArgumentParser(description="Multi-GPU SLM prefill → mismatch labelling") + parser.add_argument("--dataset_path", required=True, help="JSONL or JSON conversations dataset") + parser.add_argument("--num_gpu", type=int, default=4) + parser.add_argument("--test_model_list", nargs="+", required=True, help="HF model id(s) for the SLM labeller") + parser.add_argument("--output_path", required=True) + parser.add_argument("--max_input_length", type=int, default=32768) + parser.add_argument("--is_cutoff", type=bool, default=False, help="If True, truncate over-long samples; else drop them") + parser.add_argument("--index_range", nargs=2, type=int, default=None, help="[start, end) slice of the dataset") + parser.add_argument("--random_num", type=int, default=None, help="Sub-sample N of available; seed=42") + parser.add_argument("--top_k", type=int, default=-1) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--save_entropy", action="store_true", help="Save per-token softmax entropy") + parser.add_argument("--save_ce", action="store_true", help="Save per-token cross-entropy w.r.t. next-token labels") args = parser.parse_args() - # Set multiprocessing start method - mp.set_start_method('spawn', force=True) - + mp.set_start_method("spawn", force=True) try: - # Process dataset with multiple GPUs - process_dataset_multi_gpu(args) - - # Save args as json - with open(os.path.join(args.output_path, "args.json"), "w") as f: - json.dump(args.__dict__, f) - - print("All processing completed!") - - except Exception as e: - print(f"Error during processing: {e}") - # Clean up any remaining processes - for p in running_processes: - if p.is_alive(): - p.terminate() - p.join(timeout=60) # Give 60 seconds for graceful termination (increased from 5 seconds) - if p.is_alive(): - p.kill() - p.join() - raise e + _process_dataset_multi_gpu(args) + with open(os.path.join(args.output_path, "args.json"), "w", encoding="utf-8") as f: + json.dump(args.__dict__, f, indent=2) + print("All processing complete.") + except Exception: + for p in _running_processes: + _terminate_with_timeout(p) + raise finally: - # Final cleanup - running_processes = [] + _running_processes.clear() if __name__ == "__main__": - main() + main() diff --git a/script/preparation/prune.py b/script/preparation/prune.py index 7e80908..6bd233f 100644 --- a/script/preparation/prune.py +++ b/script/preparation/prune.py @@ -1,44 +1,58 @@ +"""Prune one or more layers from a Qwen3-style decoder-only model. + +Usage: + + python script/preparation/prune.py \\ + --model Qwen/Qwen3-1.7B-Base \\ + --dataset data/processed_data/openr1_math/1_7/eval \\ + --output model/qwen3_1.7_base_pruned \\ + --num_prune 1 + +For each layer, temporarily remove it from the model and measure perplexity +on a small calibration set; the layers with the smallest ΔPPL are dropped +permanently and the resulting model is saved with ``model.save_pretrained``. +""" +from __future__ import annotations + import argparse +import math import os -os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' -os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7' import sys -import math -from typing import List, Tuple, Optional +from typing import List, Optional, Tuple import torch +from datasets import load_dataset, load_from_disk from torch import nn from transformers import AutoModelForCausalLM, AutoTokenizer -from datasets import load_dataset, load_from_disk def _finite_or_inf(x: float) -> float: - """Return x if finite, else +inf to avoid NaN/Inf in comparisons.""" + """Return ``x`` if finite, else ``+inf`` so it sorts to the back of "best" lists.""" return x if math.isfinite(x) else float("inf") def _get_device(device_arg: str) -> torch.device: - """Resolve device from arg.""" if device_arg: return torch.device(device_arg) return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -def _resolve_layers_module(model: nn.Module) -> Tuple[nn.Module, List[nn.Module]]: - """Find the decoder layers ModuleList for common HF causal LMs. +def _resolve_layers_module(model: nn.Module) -> Tuple[nn.Module, str, List[nn.Module]]: + """Locate the decoder ``ModuleList`` on common HF causal-LM layouts. - Returns a tuple of (parent_module, layers_list). + Returns ``(parent_module, attr_name, layers_list)`` where + ``parent_module.`` is the writable ModuleList we'll + swap out when pruning. """ - # Common path for modern decoder-only models - for parent_attr in ["model", "transformer"]: + for parent_attr in ("model", "transformer"): parent = getattr(model, parent_attr, None) if parent is None: continue - for layers_attr in ["layers", "h"]: + for layers_attr in ("layers", "h"): layers = getattr(parent, layers_attr, None) if isinstance(layers, nn.ModuleList) and len(layers) > 0: - return parent, list(layers) - raise RuntimeError("unable to locate model layers") + return parent, layers_attr, list(layers) + raise RuntimeError("unable to locate decoder layers on the model") def _prepare_tokenizer(model_name_or_path: str): @@ -47,70 +61,25 @@ def _prepare_tokenizer(model_name_or_path: str): tokenizer.pad_token = tokenizer.eos_token return tokenizer -def _read_calibration_from_hf(dataset_arg: str, text_field: str, max_samples: int) -> List[str]: - if load_dataset is None or load_from_disk is None: - raise RuntimeError("datasets not found") - - texts: List[str] = [] - # Try local disk first - ds = None +def _read_calibration_from_hf(dataset_arg: str, text_field: str, max_samples: int) -> List[str]: + """Read up to ``max_samples`` non-empty strings from ``dataset_arg``'s + ``text_field`` column. Accepts either a local saved-to-disk dataset or an HF + dataset id.""" if os.path.isdir(dataset_arg): - try: - ds = load_from_disk(dataset_arg) - except Exception as e: - # Fallback: directly read Arrow file to extract text when features schema is incompatible - try: - import pyarrow as pa # type: ignore - import pyarrow.ipc as pa_ipc # type: ignore - # locate data-*.arrow under the directory or its immediate subdirs - arrow_files = [f for f in os.listdir(dataset_arg) if f.startswith("data-") and f.endswith(".arrow")] - search_dir = dataset_arg - if not arrow_files: - subdirs = [os.path.join(dataset_arg, d) for d in os.listdir(dataset_arg) if os.path.isdir(os.path.join(dataset_arg, d))] - for sd in subdirs: - cand = [f for f in os.listdir(sd) if f.startswith("data-") and f.endswith(".arrow")] - if cand: - search_dir = sd - arrow_files = cand - break - if not arrow_files: - raise RuntimeError("unable to find data-*.arrow file for fallback reading") - file_path = os.path.join(search_dir, sorted(arrow_files)[0]) - with pa.memory_map(file_path, "r") as source: - reader = pa_ipc.RecordBatchFileReader(source) - table = reader.read_all() - if text_field not in table.column_names: - raise KeyError(f"column {text_field} not found in Arrow table: available columns {table.column_names}") - col_py = table[text_field].to_pylist() - for value in col_py: - if isinstance(value, str) and value.strip(): - texts.append(value.strip()) - if len(texts) >= max_samples: - break - if not texts: - raise RuntimeError(f"dataset has no available text in column {text_field}") - return texts - except Exception as e2: - raise RuntimeError(f"local dataset loading failed, and Arrow fallback failed: {e2}") from e + ds = load_from_disk(dataset_arg) else: - # Remote or canonical HF dataset id ds = load_dataset(dataset_arg) - - # Now ds should be a Dataset-like - if ds is None: - raise RuntimeError("unable to load dataset") - - # Extract text_field - col = ds[text_field] - for value in col: + if hasattr(ds, "column_names") and text_field not in ds.column_names: + raise KeyError(f"column {text_field!r} not found; available: {ds.column_names}") + texts = [] + for value in ds[text_field]: if isinstance(value, str) and value.strip(): texts.append(value.strip()) if len(texts) >= max_samples: break - if not texts: - raise RuntimeError(f"dataset has no available text in column {text_field}") + raise RuntimeError(f"no usable text found in column {text_field!r}") return texts @@ -146,45 +115,28 @@ def _compute_loss_with_temp_removed_layer( max_length: int, remove_index: int, ) -> float: - """Temporarily remove one layer, evaluate average loss, and restore.""" - parent, layers_list = _resolve_layers_module(model) - used_attr: Optional[str] = None - if hasattr(parent, "layers") and isinstance(getattr(parent, "layers"), nn.ModuleList): - used_attr = "layers" - elif hasattr(parent, "h") and isinstance(getattr(parent, "h"), nn.ModuleList): - used_attr = "h" - else: - raise RuntimeError("unable to locate layer container attribute for temporary removal") - - original_modules = list(layers_list) - original_num_layers: Optional[int] = getattr(getattr(model, "config", object()), "num_hidden_layers", None) - + """Temporarily remove one layer, evaluate avg loss, restore the layer.""" + parent, attr, original = _resolve_layers_module(model) + original_num = getattr(getattr(model, "config", object()), "num_hidden_layers", None) try: - keep = [m for i, m in enumerate(original_modules) if i != remove_index] - setattr(parent, used_attr, nn.ModuleList(keep)) - if original_num_layers is not None: + keep = [m for i, m in enumerate(original) if i != remove_index] + setattr(parent, attr, nn.ModuleList(keep)) + if original_num is not None: model.config.num_hidden_layers = len(keep) return _compute_reference_loss(model, tokenizer, texts, device, batch_size, max_length) finally: - setattr(parent, used_attr, nn.ModuleList(original_modules)) - if original_num_layers is not None: - model.config.num_hidden_layers = original_num_layers + setattr(parent, attr, nn.ModuleList(original)) + if original_num is not None: + model.config.num_hidden_layers = original_num def _prune_layers_inplace(model: nn.Module, remove_indices: List[int]) -> None: - parent, layers_list = _resolve_layers_module(model) - keep = [m for i, m in enumerate(layers_list) if i not in set(remove_indices)] - - # Assign back as ModuleList to the parent - if hasattr(parent, "layers") and isinstance(getattr(parent, "layers"), nn.ModuleList): - parent.layers = nn.ModuleList(keep) # type: ignore[attr-defined] - elif hasattr(parent, "h") and isinstance(getattr(parent, "h"), nn.ModuleList): - parent.h = nn.ModuleList(keep) # type: ignore[attr-defined] - else: - raise RuntimeError("unable to write back pruned layer list") - - # Update config if present - if hasattr(model, "config") and hasattr(model.config, "num_hidden_layers"): + """Drop the listed layer indices and update ``num_hidden_layers``.""" + parent, attr, layers = _resolve_layers_module(model) + drop = set(remove_indices) + keep = [m for i, m in enumerate(layers) if i not in drop] + setattr(parent, attr, nn.ModuleList(keep)) + if hasattr(getattr(model, "config", None), "num_hidden_layers"): model.config.num_hidden_layers = len(keep) @@ -235,7 +187,7 @@ def main() -> int: print(f"baseline PPL: {ref_ppl:.4f}") # Try removing each single layer and compute loss/PPL - parent, layers_list = _resolve_layers_module(model) + _, _, layers_list = _resolve_layers_module(model) num_layers = len(layers_list) per_layer_loss: List[float] = [] per_layer_ppl: List[float] = [] diff --git a/script/recipes/qwen3_0.6/eval_base.yaml b/script/recipes/qwen3_0.6/eval_base.yaml deleted file mode 100644 index bcbc012..0000000 --- a/script/recipes/qwen3_0.6/eval_base.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# model config -dtype: bfloat16 -mem_fraction_static: 0.90 - -# inference config -temperature: 0.6 -top_p: 0.95 -top_k: 20 -min_p: 0.0 -max_new_tokens: 4096 -batch_size: 32 # total batch size -repeat_size: 1 # number of samples per question \ No newline at end of file diff --git a/script/recipes/qwen3_0.6/eval_tah.yaml b/script/recipes/qwen3_0.6/eval_tah.yaml index 310bfc7..79b3f9d 100644 --- a/script/recipes/qwen3_0.6/eval_tah.yaml +++ b/script/recipes/qwen3_0.6/eval_tah.yaml @@ -1,33 +1,34 @@ -# model config dtype: bfloat16 -mem_fraction_static: 0.80 - -# inference config +mem_fraction_static: 0.8 temperature: 0.6 top_p: 0.95 top_k: 20 min_p: 0.0 max_new_tokens: 4096 -batch_size: 16 # total batch size -repeat_size: 1 # number of samples per question - -# tracker config -use_tracker: true -tracker_kwargs: - top_k: 5 - -# tah config +batch_size: 16 +repeat_size: 1 max_iter: 2 -iter_decider: "MLPIterDecider" +iter_decider: MLPIterDecider iter_decider_kwargs: topk: 100 hidden_states_size: 1024 - hidden_states_layer_nums: [0,7,13,20,26] - hidden_dims: [256, 256, 256, 256, 256, 256] + hidden_states_layer_nums: + - 0 + - 7 + - 13 + - 20 + - 26 + hidden_dims: + - 256 + - 256 + - 256 + - 256 + - 256 + - 256 expansion_factor: 4 dropout_rate: 0.1 - normalize_input: False + normalize_input: false threshold: 0.9 max_iter: 2 -eval_iter_decider: "iter_decider" -eval_iter_decider_kwargs: {} \ No newline at end of file +eval_iter_decider: iter_decider +eval_iter_decider_kwargs: {} diff --git a/script/recipes/qwen3_0.6/sft_base.yaml b/script/recipes/qwen3_0.6/sft_base.yaml deleted file mode 100755 index bed843d..0000000 --- a/script/recipes/qwen3_0.6/sft_base.yaml +++ /dev/null @@ -1,55 +0,0 @@ -### Model Configuration ### -model: - name: Qwen/Qwen3-0.6B-Base - torch_dtype: "bfloat16" - device_map: "auto" - trust_remote_code: true - attn_implementation: "sdpa" - embedding_key: "model.embed_tokens" - max_iter: 1 - input_updater: "TrivialUpdater" - input_updater_kwargs: {} - iter_decider: "TrivialIterDecider" - iter_decider_kwargs: {} - output_updater: "NoneUpdater" - output_updater_kwargs: {} - adapter: "none" - adapter_kwargs: {} - train_loss: "NextTokenPredLoss" - eval_loss: "NextTokenPredLoss" - -### data ### -data: - train_data_path: data/processed_data/openr1_math/0_6/train - eval_data_path: data/processed_data/openr1_math/0_6/eval - output_dir: "output/openr1_math/0_6/" - max_length: 8192 - -### training ### -training: - num_train_epochs: 5 - per_device_train_batch_size: 1 - gradient_accumulation_steps: 16 - gradient_checkpointing: false - learning_rate: 4.0e-5 - weight_decay: 0.01 - warmup_ratio: 0.03 - max_grad_norm: 0.2 - lr_scheduler_type: "cosine_with_min_lr" - lr_scheduler_kwargs: - min_lr_rate: 0.1 - logging_steps: 1 - save_strategy: "epoch" - save_only_model: true - save_total_limit: 50 - bf16: true - # evaluation - eval_strategy: "steps" - eval_steps: 40 - eval_on_start: true - per_device_eval_batch_size: 1 - # wandb - report_to: "wandb" # Options: "none", "wandb" - wandb_project: "TaH" - wandb_name: "openr1_0.6base_standard" - # wandb_entity: "" diff --git a/script/recipes/qwen3_0.6/sft_tah_step1.yaml b/script/recipes/qwen3_0.6/sft_tah_step1.yaml index bd9818a..aa058ab 100755 --- a/script/recipes/qwen3_0.6/sft_tah_step1.yaml +++ b/script/recipes/qwen3_0.6/sft_tah_step1.yaml @@ -1,34 +1,31 @@ ### model ### model: name: Qwen/Qwen3-0.6B-Base # For TaH+ version - # name: ./model/qwen3_0.6_base_pruned # For TaH version - torch_dtype: "bfloat16" + # name: ./model/qwen3_0.6_base_pruned # For TaH version (drops one layer) + torch_dtype: "bfloat16" device_map: "auto" trust_remote_code: true attn_implementation: "sdpa" embedding_key: "model.embed_tokens" max_iter: 2 - iter_label_generator: "FixedIterLabelGenerator" - iter_label_generator_kwargs: {} - input_updater: "TrivialUpdater" - input_updater_kwargs: - topk: 100 - iter_decider: "FixedLabelIterDecider" + + # Step 1: oracle-supervised iteration; the LoRA adapter learns to refine + # tokens that the labeller marked as "hard" (mismatch=1). + iter_decider: "IterLabelDecider" iter_decider_kwargs: - label_type: "mismatch" max_iter: 2 - eval_iter_decider: "iter_decider" - eval_iter_decider_kwargs: - max_iter: 2 - output_updater: "AdditiveLogitsUpdater" - output_updater_kwargs: {} + + input_updater_kwargs: + topk: 100 + adapter: "lora" - adapter_kwargs: + adapter_kwargs: r: 16 lora_alpha: 32 lora_dropout: 0.1 target_modules: "all-linear" bias: "none" + train_loss: "NextTokenPredLoss" eval_loss: "NextTokenPredLoss" @@ -38,11 +35,10 @@ data: eval_data_path: data/processed_data/openr1_math/0_6/eval output_dir: "output/openr1_math/0_6/" max_length: 8192 - iter_count_strategy: "mismatch" ### training ### training: - num_train_epochs: 5 + num_train_epochs: 5 per_device_train_batch_size: 1 gradient_accumulation_steps: 16 gradient_checkpointing: false @@ -58,13 +54,10 @@ training: save_only_model: true save_total_limit: 50 bf16: true - # evaluation config eval_strategy: "steps" eval_steps: 40 eval_on_start: true per_device_eval_batch_size: 1 - # wandb config - report_to: "wandb" # Options: "none", "wandb" + report_to: "none" # "none" or "wandb" wandb_project: "TaH" wandb_name: "openr1_0.6base_step1" - # wandb_entity: "" \ No newline at end of file diff --git a/script/recipes/qwen3_0.6/sft_tah_step2.yaml b/script/recipes/qwen3_0.6/sft_tah_step2.yaml index eab007a..b4cdcd3 100755 --- a/script/recipes/qwen3_0.6/sft_tah_step2.yaml +++ b/script/recipes/qwen3_0.6/sft_tah_step2.yaml @@ -1,26 +1,20 @@ ### model ### model: name: Qwen/Qwen3-0.6B-Base # For TaH+ version - # name: ./model/qwen3_0.6_base_pruned # For TaH version - tah_model_path: /path/to/step1-model/epoch5 - torch_dtype: "bfloat16" + # name: ./model/qwen3_0.6_base_pruned # For TaH version (drops one layer) + tah_model_path: /path/to/step1-model/epoch5 # ← required: step-1 checkpoint + torch_dtype: "bfloat16" device_map: "auto" trust_remote_code: true attn_implementation: "sdpa" embedding_key: "model.embed_tokens" max_iter: 2 - use_iter_embedding: false - input_updater: "TrivialUpdater" - input_updater_kwargs: - topk: 100 - iter_label_generator: "FixedIterLabelGenerator" - iter_label_generator_kwargs: - max_iter: 2 + iter_decider: "MLPIterDecider" iter_decider_kwargs: topk: 100 hidden_states_size: 1024 - hidden_states_layer_nums: [0,7,13,20,26] + hidden_states_layer_nums: [0, 7, 13, 20, 26] hidden_dims: [256, 256, 256, 256, 256, 256] expansion_factor: 4 dropout_rate: 0.1 @@ -29,18 +23,19 @@ model: eval_iter_decider: "iter_decider" eval_iter_decider_kwargs: {} - output_updater: "AdditiveLogitsUpdater" # Options: "NoneUpdater", "AdditiveLogitsUpdater" - output_updater_kwargs: {} + input_updater_kwargs: + topk: 100 - adapter: "lora" # Options: "lora", "cascade", "none", "multilora" + adapter: "lora" adapter_kwargs: r: 16 lora_alpha: 32 lora_dropout: 0.1 target_modules: "all-linear" base_grad: false - adapter_grad: false + adapter_grad: false # adapter is frozen; only the iter decider trains bias: "none" + train_loss: "IterDeciderLoss" train_loss_kwargs: pos_weight: 4.8 @@ -55,7 +50,6 @@ data: eval_data_path: data/processed_data/openr1_math/0_6/eval output_dir: "output/openr1_math/0_6/" max_length: 8192 - iter_count_strategy: "mismatch" ### training ### training: @@ -77,13 +71,10 @@ training: save_only_model: true save_total_limit: 10 bf16: true - # evaluation config eval_strategy: "steps" eval_steps: 40 eval_on_start: true per_device_eval_batch_size: 1 - # wandb config - report_to: "wandb" # Options: "none", "wandb" + report_to: "wandb" # "none" or "wandb" wandb_project: "TaH" wandb_name: "openr1_0.6base_step2" - # wandb_entity: "" \ No newline at end of file diff --git a/script/recipes/qwen3_1.7/eval_base.yaml b/script/recipes/qwen3_1.7/eval_base.yaml deleted file mode 100644 index 90d831b..0000000 --- a/script/recipes/qwen3_1.7/eval_base.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# model config -dtype: bfloat16 -mem_fraction_static: 0.90 - -# inference config -temperature: 0.6 -top_p: 0.95 -top_k: 20 -min_p: 0.0 -max_new_tokens: 8192 -batch_size: 64 # total batch size -repeat_size: 1 # number of samples per question \ No newline at end of file diff --git a/script/recipes/qwen3_1.7/eval_tah.yaml b/script/recipes/qwen3_1.7/eval_tah.yaml index 5e2ee77..367b50f 100644 --- a/script/recipes/qwen3_1.7/eval_tah.yaml +++ b/script/recipes/qwen3_1.7/eval_tah.yaml @@ -1,33 +1,33 @@ -# model config dtype: bfloat16 -mem_fraction_static: 0.80 - -# inference config +mem_fraction_static: 0.8 temperature: 0.6 top_p: 0.95 top_k: 20 min_p: 0.0 max_new_tokens: 4096 -batch_size: 16 # total batch size -repeat_size: 1 # number of samples per question - -# tracker config -use_tracker: true -tracker_kwargs: - top_k: 5 - -# tah config -max_iter: 2 # threshold 0.9 -iter_decider: "MLPIterDecider" +batch_size: 16 +repeat_size: 1 +max_iter: 2 +iter_decider: MLPIterDecider iter_decider_kwargs: topk: 100 hidden_states_size: 2048 - hidden_states_layer_nums: [2,10,18,26] - hidden_dims: [512, 512, 512, 512, 512, 512] + hidden_states_layer_nums: + - 2 + - 10 + - 18 + - 26 + hidden_dims: + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 expansion_factor: 4 dropout_rate: 0.1 - normalize_input: False + normalize_input: false threshold: 0.9 max_iter: 2 -eval_iter_decider: "iter_decider" -eval_iter_decider_kwargs: {} \ No newline at end of file +eval_iter_decider: iter_decider +eval_iter_decider_kwargs: {} diff --git a/script/recipes/qwen3_1.7/eval_tah_oracle.yaml b/script/recipes/qwen3_1.7/eval_tah_oracle.yaml deleted file mode 100644 index 0c2240a..0000000 --- a/script/recipes/qwen3_1.7/eval_tah_oracle.yaml +++ /dev/null @@ -1,31 +0,0 @@ -# model config -dtype: bfloat16 -mem_fraction_static: 0.80 - -# inference config -temperature: 0.6 -top_p: 0.95 -top_k: 20 -min_p: 0.0 -max_new_tokens: 8192 -batch_size: 1 # total batch size -repeat_size: 1 # number of samples per question - -# tracker config -use_tracker: true -tracker_kwargs: - top_k: 5 - -# tah config -max_iter: 2 -eval_iter_decider: "OracleDynamicIterDecider" -eval_iter_decider_kwargs: - ref_model_path: "/path/to/DeepSeek-R1-Distill-Qwen-32B" - max_iter: 2 - device: "cuda" - dtype: "float16" - backend: "sglang" # hf sglang - # false_negative_rate: 0 - # false_positive_rate: 0 -eval_loss: "NextTokenPredLoss" -eval_loss_kwargs: {} \ No newline at end of file diff --git a/script/recipes/qwen3_1.7/sft_base.yaml b/script/recipes/qwen3_1.7/sft_base.yaml deleted file mode 100755 index 4df4467..0000000 --- a/script/recipes/qwen3_1.7/sft_base.yaml +++ /dev/null @@ -1,55 +0,0 @@ -### Model Configuration ### -model: - name: Qwen/Qwen3-1.7B-Base - torch_dtype: "bfloat16" - device_map: "auto" - trust_remote_code: true - attn_implementation: "sdpa" - embedding_key: "model.embed_tokens" - max_iter: 1 - input_updater: "TrivialUpdater" - input_updater_kwargs: {} - iter_decider: "TrivialIterDecider" - iter_decider_kwargs: {} - output_updater: "NoneUpdater" - output_updater_kwargs: {} - adapter: "none" - adapter_kwargs: {} - train_loss: "NextTokenPredLoss" - eval_loss: "NextTokenPredLoss" - -### data ### -data: - train_data_path: data/processed_data/openr1/1_7/train - eval_data_path: data/processed_data/openr1/1_7/eval - output_dir: "output/openr1/" - max_length: 10000 - -### training ### -training: - num_train_epochs: 5 - per_device_train_batch_size: 1 - gradient_accumulation_steps: 16 - gradient_checkpointing: false - learning_rate: 4.0e-5 - weight_decay: 0.01 - warmup_ratio: 0.03 - max_grad_norm: 0.2 - lr_scheduler_type: "cosine_with_min_lr" - lr_scheduler_kwargs: - min_lr_rate: 0.1 - logging_steps: 1 - save_strategy: "epoch" - save_only_model: true - save_total_limit: 50 - bf16: true - # evaluation - eval_strategy: "steps" - eval_steps: 10 - eval_on_start: false - per_device_eval_batch_size: 1 - # wandb - report_to: "wandb" # Options: "none", "wandb" - wandb_project: "TaH" - wandb_name: "openr1_1.7base_standard_1015" - # wandb_entity: "" diff --git a/script/recipes/qwen3_1.7/sft_tah_step1.yaml b/script/recipes/qwen3_1.7/sft_tah_step1.yaml index 2232523..63af69a 100755 --- a/script/recipes/qwen3_1.7/sft_tah_step1.yaml +++ b/script/recipes/qwen3_1.7/sft_tah_step1.yaml @@ -1,33 +1,33 @@ ### model ### model: name: Qwen/Qwen3-1.7B-Base # For TaH+ version - # name: ./model/qwen3_1.7_base_pruned # For TaH version - torch_dtype: "bfloat16" + # name: ./model/qwen3_1.7_base_pruned # For TaH version (drops one layer) + torch_dtype: "bfloat16" device_map: "auto" trust_remote_code: true attn_implementation: "sdpa" embedding_key: "model.embed_tokens" max_iter: 2 - iter_label_generator: "FixedIterLabelGenerator" - iter_label_generator_kwargs: {} - input_updater: "TrivialUpdater" - input_updater_kwargs: - topk: 100 + + # Step 1: oracle-supervised iteration; the LoRA adapter learns to refine + # tokens that the labeller marked as "hard" (mismatch=1). iter_decider: "IterLabelDecider" iter_decider_kwargs: max_iter: 2 - eval_iter_decider: "iter_decider" - eval_iter_decider_kwargs: - max_iter: 2 - output_updater: "AdditiveLogitsUpdater" - output_updater_kwargs: {} + + # Top-k softmax over previous logits feeds the next iteration's input + # embedding (only `topk` is read by the wrapper; rest is ignored). + input_updater_kwargs: + topk: 100 + adapter: "lora" - adapter_kwargs: + adapter_kwargs: r: 32 lora_alpha: 64 lora_dropout: 0.1 target_modules: "all-linear" bias: "none" + train_loss: "NextTokenPredLoss" eval_loss: "NextTokenPredLoss" @@ -37,11 +37,10 @@ data: eval_data_path: data/processed_data/openr1_math/1_7/eval output_dir: "output/openr1_math/" max_length: 10000 - iter_count_strategy: "mismatch" ### training ### training: - num_train_epochs: 5 + num_train_epochs: 5 per_device_train_batch_size: 1 gradient_accumulation_steps: 16 gradient_checkpointing: false @@ -57,13 +56,10 @@ training: save_only_model: true save_total_limit: 50 bf16: true - # evaluation config eval_strategy: "epoch" eval_steps: 40 eval_on_start: true per_device_eval_batch_size: 1 - # wandb config - report_to: "none" # Options: "none", "wandb" + report_to: "none" # "none" or "wandb" wandb_project: "TaH" wandb_name: "openr1_1.7base_step1" - # wandb_entity: "" \ No newline at end of file diff --git a/script/recipes/qwen3_1.7/sft_tah_step2.yaml b/script/recipes/qwen3_1.7/sft_tah_step2.yaml index 5655e17..1764bc4 100755 --- a/script/recipes/qwen3_1.7/sft_tah_step2.yaml +++ b/script/recipes/qwen3_1.7/sft_tah_step2.yaml @@ -1,50 +1,47 @@ ### model ### model: name: Qwen/Qwen3-1.7B-Base # For TaH+ version - # name: ./model/qwen3_1.7_base_pruned # For TaH version - tah_model_path: /path/to/step1-model/epoch4 - torch_dtype: "bfloat16" + # name: ./model/qwen3_1.7_base_pruned # For TaH version (drops one layer) + tah_model_path: /path/to/step1-model/epoch4 # ← required: step-1 checkpoint + torch_dtype: "bfloat16" device_map: "auto" trust_remote_code: true attn_implementation: "sdpa" embedding_key: "model.embed_tokens" max_iter: 2 - use_iter_embedding: false - input_updater: "TrivialUpdater" - input_updater_kwargs: - topk: 100 - iter_label_generator: "FixedIterLabelGenerator" - iter_label_generator_kwargs: - max_iter: 2 + + # Step 2: train the learned MLP iter decider on top of the frozen base+LoRA. iter_decider: "MLPIterDecider" iter_decider_kwargs: topk: 100 hidden_states_size: 2048 - hidden_states_layer_nums: [2,10,18,26] + hidden_states_layer_nums: [2, 10, 18, 26] hidden_dims: [512, 512, 512, 512, 512, 512] expansion_factor: 4 dropout_rate: 0.1 normalize_input: false threshold: 0.8 + # eval-time decider is the same one we just trained. eval_iter_decider: "iter_decider" eval_iter_decider_kwargs: {} - output_updater: "AdditiveLogitsUpdater" # Options: "NoneUpdater", "AdditiveLogitsUpdater" - output_updater_kwargs: {} + input_updater_kwargs: + topk: 100 - adapter: "lora" # Options: "lora", "cascade", "none", "multilora" + adapter: "lora" adapter_kwargs: r: 32 lora_alpha: 64 lora_dropout: 0.1 target_modules: "all-linear" base_grad: false - adapter_grad: false + adapter_grad: false # adapter is frozen; only the iter decider trains bias: "none" + train_loss: "IterDeciderLoss" train_loss_kwargs: - pos_weight: 5.4 - skip_last_iter: true + pos_weight: 5.4 # positive-class upweight for BCE + skip_last_iter: true # don't supervise the always-stop final iter max_iter: 2 eval_loss: "NextTokenPredLoss" eval_loss_kwargs: {} @@ -55,7 +52,6 @@ data: eval_data_path: data/processed_data/openr1_math/1_7/eval output_dir: "output/openr1_math/1_7/" max_length: 8192 - iter_count_strategy: "mismatch" ### training ### training: @@ -77,13 +73,10 @@ training: save_only_model: true save_total_limit: 10 bf16: true - # evaluation config eval_strategy: "steps" eval_steps: 40 eval_on_start: true per_device_eval_batch_size: 1 - # wandb config - report_to: "wandb" # Options: "none", "wandb" + report_to: "wandb" # "none" or "wandb" wandb_project: "TaH" wandb_name: "openr1_1.7base_step2" - # wandb_entity: "" \ No newline at end of file diff --git a/script/train/SFT_TaH.py b/script/train/SFT_TaH.py old mode 100755 new mode 100644 index 9266a33..6b5785d --- a/script/train/SFT_TaH.py +++ b/script/train/SFT_TaH.py @@ -1,405 +1,211 @@ +"""TaH SFT entrypoint. + +Loads a YAML recipe (see ``script/recipes/qwen3_1.7/sft_tah_step{1,2}.yaml``), +constructs the wrapper, runs HuggingFace Trainer. + +Step 1 — train the LoRA adapter against oracle iteration labels + (iter_decider=IterLabelDecider, train_loss=NextTokenPredLoss). + +Step 2 — train the iter decider on top of the frozen base+adapter + (iter_decider=MLPIterDecider, train_loss=IterDeciderLoss, + freeze_component=[model.simple_base_model]). + +Run via ``accelerate launch`` so DeepSpeed / DDP wrappers are in place: + + python -m accelerate.commands.launch \ + --config_file ./script/recipes/accelerate_configs/zero2.yaml \ + --num_processes 8 \ + ./script/train/SFT_TaH.py \ + --config ./script/recipes/qwen3_1.7/sft_tah_step1.yaml +""" +from __future__ import annotations + +import argparse import os +from dataclasses import fields +from datetime import datetime +from typing import Dict + import torch import yaml -import argparse -from datetime import datetime -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - TrainingArguments, -) from accelerate import Accelerator from accelerate.utils import broadcast_object_list -from typing import Dict +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments -from tah.model.recurrent_transformer import TaHForCausalLM -from tah.model.tah_config import TaHConfig from tah.model.iter_decider import load_iter_decider -from tah.train import CustomTaHTrainer, CustomTaHDataCollator, LoggerCallback +from tah.model.tah_config import TaHConfig +from tah.model.tah_model import TaHForCausalLM +from tah.model.utils import compute_trainable_param_size_gb, freeze_components, set_all_seeds +from tah.train import CustomTaHDataCollator, CustomTaHTrainer, LoggerCallback from tah.utils.data_prepare import preprocess_dataset -try: - from liger_kernel.transformers import AutoLigerKernelForCausalLM -except ImportError: - AutoLigerKernelForCausalLM = None +set_all_seeds(420) -from tah.model.utils import set_all_seeds, freeze_components, compute_trainable_param_size_gb -from dataclasses import fields -set_all_seeds(420) +_DTYPE_BY_STR = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32} -def load_config(config_path: str) -> Dict: - """Load configuration from YAML file.""" - print(f"Loading configuration from: {config_path}") - with open(config_path, 'r', encoding='utf-8') as f: +def _load_yaml(path: str) -> Dict: + print(f"Loading configuration from: {path}") + with open(path, "r", encoding="utf-8") as f: return yaml.safe_load(f) -def load_model_and_tokenizer(training_config: Dict, model_config: Dict, accelerator: Accelerator): - """Load model and tokenizer based on configuration.""" - accelerator.print("Loading model and tokenizer...") - - # Convert torch dtype string to actual torch dtype - dtype_mapping = { - 'bfloat16': torch.bfloat16, - 'float16': torch.float16, - 'float32': torch.float32 - } - - torch_dtype = dtype_mapping.get(model_config['torch_dtype'], torch.bfloat16) - - # use accelerator's device setting - device_map = None if accelerator.num_processes >= 1 else model_config.get('device_map', 'auto') - - # Check if we should load a pretrained TaH model - if 'tah_model_path' in model_config: - accelerator.print(f"Loading pretrained TaH model from: {model_config['tah_model_path']}") - - # Create TaH config for overriding if specified in model_config - tah_config = None - tah_config_fields = [field.name for field in fields(TaHConfig)] - - if any(key in model_config for key in tah_config_fields): - # Only set fields that are present in the YAML config - overide_config_dict = {} - for field in tah_config_fields: - if field in model_config: - overide_config_dict[field] = model_config[field] - - tah_config = TaHConfig(**overide_config_dict) - - accelerator.print("Using TaH config from YAML to override saved config:") - accelerator.print(f"TaH config override: {overide_config_dict}") - else: - accelerator.print("Using saved TaH config from pretrained model") - - # Load pretrained TaH model with optional config override - model = TaHForCausalLM.from_pretrained( - model_config['tah_model_path'], - tah_config=tah_config - ).to(dtype=torch_dtype) - # Load tokenizer from the original base model path or from the tah model path - tokenizer_path = model_config.get('name', model_config['tah_model_path']) - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, - trust_remote_code=model_config.get('trust_remote_code', True), - padding_side="right" - ) +def _build_model_and_tokenizer(model_config: Dict, accelerator: Accelerator): + """Either resume an existing TaH checkpoint or build a fresh wrapper.""" + accelerator.print("Loading model and tokenizer…") + torch_dtype = _DTYPE_BY_STR.get(model_config["torch_dtype"], torch.bfloat16) + # accelerate handles device placement; never pass device_map="auto" here. + device_map = None - accelerator.print("Successfully loaded pretrained TaH model") - accelerator.print(f"Model architecture: {model}") - + # Tokenizer location: use ``name`` if given, else fall back to the TaH ckpt. + tok_path = model_config.get("name", model_config.get("tah_model_path")) + tokenizer = AutoTokenizer.from_pretrained( + tok_path, trust_remote_code=model_config.get("trust_remote_code", True), padding_side="right", + ) + tokenizer.pad_token = tokenizer.eos_token + + if "tah_model_path" in model_config: + accelerator.print(f"Resuming from TaH checkpoint: {model_config['tah_model_path']}") + # Override only the fields explicitly set in the new YAML. + valid = {f.name for f in fields(TaHConfig)} + override = TaHConfig(**{k: v for k, v in model_config.items() if k in valid}) + model = TaHForCausalLM.from_pretrained( + model_config["tah_model_path"], tah_config=override, + ).to(dtype=torch_dtype) else: - # Original logic for loading base model and creating TaH model - tokenizer = AutoTokenizer.from_pretrained( - model_config['name'], - trust_remote_code=model_config['trust_remote_code'], - padding_side="right" + # Construct fresh from a base model + recipe-specified components. + valid = {f.name for f in fields(TaHConfig)} + cfg = TaHConfig(**{k: v for k, v in model_config.items() if k in valid}) + base = AutoModelForCausalLM.from_pretrained( + model_config["name"], torch_dtype=torch_dtype, device_map=device_map, + trust_remote_code=model_config.get("trust_remote_code", True), + attn_implementation=model_config.get("attn_implementation", "sdpa"), ) - - # Create TaH config (populate only from provided keys) - tah_config = TaHConfig(embedding_key=model_config.get('embedding_key', "model.embed_tokens")) - for f in fields(TaHConfig): - if f.name == 'embedding_key': - continue - if f.name in model_config: - setattr(tah_config, f.name, model_config[f.name]) - use_base_model_only = (tah_config.max_iter == 1) - - # load base model - if training_config.get('enable_liger_kernel', False): - if AutoLigerKernelForCausalLM is None: - raise ImportError("liger_kernel is not installed. Please install it using `pip install liger_kernel`.") - base_model = AutoLigerKernelForCausalLM.from_pretrained( - model_config['name'], - torch_dtype=torch_dtype, - device_map=device_map, - trust_remote_code=model_config['trust_remote_code'], - attn_implementation=model_config['attn_implementation'] - ) - accelerator.print("Using Liger Kernel") + if "load_path" in (cfg.iter_decider_kwargs or {}): + iter_decider_path = cfg.iter_decider_kwargs.pop("load_path") + model = TaHForCausalLM(base_model=base, config=cfg) + model.iter_decider = load_iter_decider(iter_decider_path) else: - base_model = AutoModelForCausalLM.from_pretrained( - model_config['name'], - torch_dtype=torch_dtype, - device_map=device_map if not use_base_model_only else None, # cannot use device_map=auto for base model if using TaH - trust_remote_code=model_config['trust_remote_code'], - attn_implementation=model_config['attn_implementation'] - ) - - if tah_config.max_iter == 1: - model = base_model - else: - if "load_path" in tah_config.iter_decider_kwargs: - iter_decider_path = tah_config.iter_decider_kwargs.pop("load_path") - model = TaHForCausalLM(base_model=base_model, config=tah_config, device_map=device_map) - model.iter_decider = load_iter_decider(iter_decider_path) - else: - # regular init - model = TaHForCausalLM(base_model=base_model, config=tah_config, device_map=device_map) - - tokenizer.pad_token = tokenizer.eos_token - + model = TaHForCausalLM(base_model=base, config=cfg) + return model, tokenizer -def create_training_args(training_config: Dict, data_config: Dict, output_dir: str, accelerator: Accelerator, timestamp: str = None) -> TrainingArguments: - """Create training arguments from configuration.""" - accelerator.print("Configuring training arguments...") - - training_args_dict = { - 'output_dir': output_dir, - 'num_train_epochs': training_config['num_train_epochs'], - 'per_device_train_batch_size': training_config['per_device_train_batch_size'], - 'gradient_accumulation_steps': training_config['gradient_accumulation_steps'], - 'gradient_checkpointing': training_config['gradient_checkpointing'], - 'learning_rate': training_config['learning_rate'], - 'warmup_ratio': training_config['warmup_ratio'], - 'weight_decay': training_config['weight_decay'], - 'max_grad_norm': training_config['max_grad_norm'], - 'lr_scheduler_type': training_config['lr_scheduler_type'], - 'lr_scheduler_kwargs': training_config['lr_scheduler_kwargs'], - 'logging_steps': training_config['logging_steps'], - 'save_strategy': training_config['save_strategy'], - 'save_steps': training_config.get('save_steps', 100), - 'save_only_model': training_config['save_only_model'], - 'save_total_limit': training_config['save_total_limit'], - 'report_to': training_config['report_to'], - 'bf16': training_config['bf16'], - # accelerate - 'remove_unused_columns': False, - 'ddp_find_unused_parameters': False, +def _build_training_args(training_config: Dict, data_config: Dict, output_dir: str, timestamp: str) -> TrainingArguments: + args = { + "output_dir": output_dir, + "num_train_epochs": training_config["num_train_epochs"], + "per_device_train_batch_size": training_config["per_device_train_batch_size"], + "gradient_accumulation_steps": training_config["gradient_accumulation_steps"], + "gradient_checkpointing": training_config.get("gradient_checkpointing", False), + "learning_rate": training_config["learning_rate"], + "warmup_ratio": training_config["warmup_ratio"], + "weight_decay": training_config["weight_decay"], + "max_grad_norm": training_config["max_grad_norm"], + "lr_scheduler_type": training_config["lr_scheduler_type"], + "lr_scheduler_kwargs": training_config["lr_scheduler_kwargs"], + "logging_steps": training_config["logging_steps"], + "save_strategy": training_config["save_strategy"], + "save_steps": training_config.get("save_steps", 100), + "save_only_model": training_config["save_only_model"], + "save_total_limit": training_config["save_total_limit"], + "report_to": training_config["report_to"], + "bf16": training_config["bf16"], + "remove_unused_columns": False, + "ddp_find_unused_parameters": False, } - - # set evaluation dataset ratio - if data_config.get('eval_data_ratio', 0.0) > 0 or data_config.get('eval_data_path', None) is not None: - training_args_dict['eval_strategy'] = training_config.get('eval_strategy') - training_args_dict['eval_steps'] = training_config.get('eval_steps') - training_args_dict['per_device_eval_batch_size'] = training_config.get('per_device_eval_batch_size') - training_args_dict['eval_on_start'] = training_config.get('eval_on_start') - - # set wandb related environment variables and run_name parameter - if training_config['report_to'] == "wandb": - # set environment variables - if 'wandb_project' in training_config: - os.environ['WANDB_PROJECT'] = training_config['wandb_project'] - if 'wandb_name' in training_config: - os.environ['WANDB_NAME'] = training_config['wandb_name'] - if 'wandb_entity' in training_config: - os.environ['WANDB_ENTITY'] = training_config['wandb_entity'] - - # use run_name parameter (transformers 4.52.4 supported) - default_run_name = f"training_{timestamp}" if timestamp else f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - training_args_dict['run_name'] = training_config.get('wandb_name', default_run_name) - - return TrainingArguments(**training_args_dict) - - -def train_model(model, tokenizer, processed_train_dataset, processed_eval_dataset, training_args, accelerator: Accelerator, training_config: Dict, resume_from_checkpoint_path: str = None): - """Initialize trainer and start training.""" - # Use custom data collator that handles iter_count field - # Create trainer first (without data_collator) - trainer = CustomTaHTrainer( - model=model, - args=training_args, - train_dataset=processed_train_dataset, - eval_dataset=processed_eval_dataset, - processing_class=tokenizer, - prediction_config=None, - ) - - # Create data collator (noise logic removed from collator) - data_collator = CustomTaHDataCollator( - tokenizer=tokenizer, - padding=True, - ) - - # Set data collator on trainer - trainer.data_collator = data_collator - # Instantiate the LoggerCallback to track iter count - iter_count_callback = LoggerCallback( - trainer=trainer - ) - model.logger_callback = iter_count_callback - trainer.callback_handler.callbacks.insert(0, iter_count_callback) - - accelerator.print("\n--- Starting Training ---") - if resume_from_checkpoint_path is not None: - accelerator.print(f"Resuming training from checkpoint: {resume_from_checkpoint_path}") - trainer.train(resume_from_checkpoint=resume_from_checkpoint_path) - else: - trainer.train() - accelerator.print("--- Training Complete ---") - - # get training history - training_history = { - 'train_loss': [], - 'eval_loss': [], - } - - # get training history from trainer's log_history - if hasattr(trainer.state, 'log_history') and trainer.state.log_history: - accelerator.print(f"Found {len(trainer.state.log_history)} log entries") - for log_entry in trainer.state.log_history: - if 'train_loss' in log_entry: - training_history['train_loss'].append(log_entry['train_loss']) - if 'eval_loss' in log_entry: - training_history['eval_loss'].append(log_entry['eval_loss']) - - accelerator.print(f"Collected {len(training_history['train_loss'])} training loss entries") - accelerator.print(f"Collected {len(training_history['eval_loss'])} evaluation loss entries") - - return trainer, training_history - - -def save_final_model(trainer, tokenizer, output_dir: str, config: Dict, accelerator: Accelerator): - """Save the final model and configuration.""" - # - final_model_path = os.path.join(output_dir, "final_model") - accelerator.print(f"Saving final model to: {final_model_path}") - trainer.model.save_pretrained(final_model_path) - tokenizer.save_pretrained(final_model_path) - - accelerator.print("Model and configuration saved successfully!") - - -def main(config): - """Main training function.""" - # initialize accelerator first + if data_config.get("eval_data_path") or data_config.get("eval_data_ratio", 0.0) > 0: + args.update({ + "eval_strategy": training_config.get("eval_strategy"), + "eval_steps": training_config.get("eval_steps"), + "per_device_eval_batch_size": training_config.get("per_device_eval_batch_size"), + "eval_on_start": training_config.get("eval_on_start"), + }) + + if training_config["report_to"] == "wandb": + for env_key, cfg_key in (("WANDB_PROJECT", "wandb_project"), ("WANDB_NAME", "wandb_name"), ("WANDB_ENTITY", "wandb_entity")): + if cfg_key in training_config: + os.environ[env_key] = training_config[cfg_key] + args["run_name"] = training_config.get("wandb_name", f"training_{timestamp}") + + return TrainingArguments(**args) + + +def main(config: Dict): accelerator = Accelerator( - mixed_precision='bf16', + mixed_precision="bf16", log_with="wandb" if os.environ.get("WANDB_MODE") != "disabled" else None, ) - - # generate timestamp only on main process and broadcast to all processes - timestamp = None - if accelerator.is_main_process: - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - - # broadcast timestamp to all processes - timestamp_list = [timestamp] - broadcast_object_list(timestamp_list) - timestamp = timestamp_list[0] - - model_config = config['model'] - data_config = config['data'] - training_config = config['training'] - - # Extract key paths - output_dir = data_config['output_dir'] - - if 'tah_model_path' in model_config: - # For continued training from pretrained TaH model - base_model_name = model_config['name'].split('/')[-1] - output_dir = os.path.join(output_dir, "continue_training", base_model_name, timestamp) - accelerator.print(f"Continue training mode - using output directory: {output_dir}") + + # Single timestamp shared across all ranks. + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if accelerator.is_main_process else None + ts_holder = [timestamp] + broadcast_object_list(ts_holder) + timestamp = ts_holder[0] + + model_config = config["model"] + data_config = config["data"] + training_config = config["training"] + + # Output dir: continue-training puts everything under continue_training//; + # from-scratch puts under __/. + base_name = (model_config.get("name") or model_config["tah_model_path"]).split("/")[-1] + if "tah_model_path" in model_config: + output_dir = os.path.join(data_config["output_dir"], "continue_training", base_name, timestamp) else: - # For training from scratch - use detailed naming - output_dir = os.path.join(output_dir, (model_config['name'].split('/')[-1] + "_" + model_config['input_updater'][:-7])) - output_dir = output_dir + "_" + str(model_config['input_updater_kwargs'].get('num_layers', '')) - output_dir = output_dir + "_" + model_config['iter_decider'][:-11] - if model_config['adapter'] != 'none': - output_dir = output_dir + "_" + model_config['adapter'] - output_dir = os.path.join(output_dir, timestamp) - - # Load Model and Tokenizer - model, tokenizer = load_model_and_tokenizer(training_config, model_config, accelerator) - - # Optionally freeze specified components and report trainable size - freeze_list = training_config.get('freeze_component', []) + decider = (model_config.get("iter_decider") or "decider").rsplit("Decider", 1)[0] or "decider" + output_dir = os.path.join( + data_config["output_dir"], f"{base_name}_{decider}_{model_config.get('adapter', 'lora')}", timestamp, + ) + + model, tokenizer = _build_model_and_tokenizer(model_config, accelerator) + + freeze_list = training_config.get("freeze_component") or [] if isinstance(freeze_list, str): freeze_list = [freeze_list] if freeze_list: accelerator.print(f"Freezing components: {freeze_list}") freeze_components(model, freeze_list, accelerator) + accelerator.print(f"Trainable parameter size: {compute_trainable_param_size_gb(model):.3f} GB") - trainable_gb = compute_trainable_param_size_gb(model) - accelerator.print(f"Trainable parameter size: {trainable_gb:.3f} GB") - - # Preprocess Dataset - processed_train_dataset, processed_eval_dataset, avg_hard_ratio = preprocess_dataset(training_config, data_config, model_config, accelerator) - - # Calculate and set balanced weights if hard_token_relative_weight is not 1.0 - hard_token_relative_weight = training_config.get('hard_token_relative_weight', 1.0) - if hard_token_relative_weight != 1.0: - # Calculate weights such that: - # 1. p * weight_hard + (1 - p) * weight_easy = 1.0 - # 2. weight_hard / weight_easy = r - weight_easy = 1.0 / (avg_hard_ratio * hard_token_relative_weight + (1 - avg_hard_ratio)) - weight_hard = hard_token_relative_weight * weight_easy - model.weight_hard = weight_hard - model.weight_easy = weight_easy - model.hard_token_relative_weight = hard_token_relative_weight - accelerator.print(f"Calculated balanced weights:") - accelerator.print(f" - Hard token ratio: {avg_hard_ratio:.4f}") - accelerator.print(f" - Weight for hard tokens: {weight_hard:.4f}") - accelerator.print(f" - Weight for easy tokens: {weight_easy:.4f}") - else: - accelerator.print(f"Skipping balanced weights calculation because hard_token_relative_weight is 1.0") + train_ds, eval_ds = preprocess_dataset(training_config, data_config, model_config, accelerator) + + training_args = _build_training_args(training_config, data_config, output_dir, timestamp) - # Create Training Arguments - training_args = create_training_args(training_config, data_config, output_dir, accelerator, timestamp) - - # Print training infos if accelerator.is_main_process: - print(f"Model: {model_config['name']}") - print(f"Output directory: {output_dir}") - print(f"Training epochs: {training_config['num_train_epochs']}") - print(f"Batch size: {training_config['per_device_train_batch_size']}") - print(f"Learning rate: {training_config['learning_rate']}") - print(f"Max length: {data_config.get('max_length', None)}") - print(f"Max length action: {data_config.get('max_length_action', 'cutoff')}") - print("--- Training begins ---\n") - - # Also save the configuration file for reference - config_save_path = os.path.join(output_dir, "training_config.yaml") - os.makedirs(output_dir, exist_ok=True) - with open(config_save_path, 'w', encoding='utf-8') as f: - yaml.dump(config, f, default_flow_style=False, allow_unicode=True) - accelerator.print(f"Configuration saved to: {config_save_path}") - - # Determine resume checkpoint path (if requested) - resume_from_ckpt = training_config.get('resume_from_ckpt', False) - resume_from_checkpoint_path = None - if resume_from_ckpt and ('tah_model_path' in model_config): - resume_from_checkpoint_path = model_config['tah_model_path'] - accelerator.print(f"Resume-from-ckpt enabled. Using checkpoint path: {resume_from_checkpoint_path}") - - # Train Model - trainer, training_history = train_model( - model, - tokenizer, - processed_train_dataset, - processed_eval_dataset, - training_args, - accelerator, - training_config, - resume_from_checkpoint_path=resume_from_checkpoint_path, + os.makedirs(output_dir, exist_ok=True) + with open(os.path.join(output_dir, "training_config.yaml"), "w", encoding="utf-8") as f: + yaml.dump(config, f, default_flow_style=False, allow_unicode=True) + accelerator.print(f"Configuration saved to: {output_dir}/training_config.yaml") + + trainer = CustomTaHTrainer( + model=model, args=training_args, train_dataset=train_ds, eval_dataset=eval_ds, processing_class=tokenizer, ) - - # Save Final Model - save_final_model(trainer, tokenizer, output_dir, config, accelerator) - - return training_history + trainer.data_collator = CustomTaHDataCollator(tokenizer=tokenizer, padding=True) + + callback = LoggerCallback() + model.logger_callback = callback + trainer.callback_handler.callbacks.insert(0, callback) + + accelerator.print("\n--- Starting Training ---") + resume_path = model_config["tah_model_path"] if training_config.get("resume_from_ckpt") and "tah_model_path" in model_config else None + if resume_path: + accelerator.print(f"Resuming optimizer state from: {resume_path}") + trainer.train(resume_from_checkpoint=resume_path) + else: + trainer.train() + accelerator.print("--- Training Complete ---") + + final_dir = os.path.join(output_dir, "final_model") + accelerator.print(f"Saving final model to: {final_dir}") + trainer.model.save_pretrained(final_dir) + tokenizer.save_pretrained(final_dir) + if __name__ == "__main__": - # Load Configuration - parser = argparse.ArgumentParser(description='Train a causal language model with configuration file') - parser.add_argument('--config', type=str, default='config.yaml', help='Path to configuration file') + parser = argparse.ArgumentParser(description="TaH SFT entrypoint") + parser.add_argument("--config", type=str, required=True, help="Path to YAML recipe") args = parser.parse_args() - config = load_config(args.config) - - training_history = main(config) - - # Print training history summary - if training_history: - print("\n--- Training History Summary ---") - print(f"Number of training loss entries: {len(training_history['train_loss'])}") - print(f"Number of evaluation loss entries: {len(training_history['eval_loss'])}") - if training_history['train_loss']: - print(f"Final training loss: {training_history['train_loss'][-1]:.6f}") - if training_history['eval_loss']: - print(f"Final evaluation loss: {training_history['eval_loss'][-1]:.6f}") - print("Full training history returned in training_history variable") \ No newline at end of file + main(_load_yaml(args.config)) diff --git a/tah/__init__.py b/tah/__init__.py new file mode 100644 index 0000000..966298b --- /dev/null +++ b/tah/__init__.py @@ -0,0 +1,21 @@ +"""TaH: Selective Latent Iterations to Improve Reasoning Language Models. + +Top-level re-exports of the most commonly used names. Submodules can also be +imported directly (``from tah.model.tah_model import ...``). +""" +from tah.model.causal_cache import TaHCache +from tah.model.iter_decider import IterLabelDecider, MLPIterDecider +from tah.model.loss import IterDeciderLoss, NextTokenPredLoss +from tah.model.tah_config import TaHConfig +from tah.model.tah_model import TaHCausalLMOutputWithPast, TaHForCausalLM + +__all__ = [ + "IterDeciderLoss", + "IterLabelDecider", + "MLPIterDecider", + "NextTokenPredLoss", + "TaHCache", + "TaHCausalLMOutputWithPast", + "TaHConfig", + "TaHForCausalLM", +] diff --git a/tah/evaluate/__init__.py b/tah/evaluate/__init__.py new file mode 100644 index 0000000..6c156e0 --- /dev/null +++ b/tah/evaluate/__init__.py @@ -0,0 +1,27 @@ +"""TaH eval driver: dataset loading, per-backend setup, multi-job orchestration. + +Most callers want :func:`allocate_gpus_and_run_jobs` (the top-level entry +the CLI uses); the module split is otherwise mainly internal: + +* ``datasets`` — load + standardise benchmark datasets. +* ``backends`` — per-backend (sglang / hf / tah) model + inference fn. +* ``jobs`` — per-job runner, process orchestration, result aggregation. +* ``matheval`` — math benchmark graders (rule-based via ``math_verify``). +* ``codeeval`` — humaneval/mbpp grading via ``evalplus``. +* ``eval_unified`` — backwards-compat shim re-exporting the above entry points. +""" +from tah.evaluate.datasets import load_combined_dataset +from tah.evaluate.jobs import ( + allocate_gpus_and_run_jobs, + combine_job_results, + parse_data_range, + run_single_job, +) + +__all__ = [ + "allocate_gpus_and_run_jobs", + "combine_job_results", + "load_combined_dataset", + "parse_data_range", + "run_single_job", +] diff --git a/tah/evaluate/backends.py b/tah/evaluate/backends.py new file mode 100644 index 0000000..b3e47ad --- /dev/null +++ b/tah/evaluate/backends.py @@ -0,0 +1,204 @@ +"""Per-backend model loaders + inference adapters for the eval driver. + +Each ``setup_*`` function returns ``(model, inference_function)`` where +``inference_function(prompts: list[str]) -> list[(text, seconds)]``. This +shape is what the per-job runner consumes; it doesn't care which backend +produced the strings. + +Backends: + * ``setup_sglang`` — ``sgl.Engine`` for production-throughput inference. + * ``setup_hf`` — vanilla ``AutoModelForCausalLM.generate``. + * ``setup_tah`` — :class:`tah.model.tah_model.TaHForCausalLM` with + its own iter-aware generation helper. +""" +from __future__ import annotations + +import os +import time +from dataclasses import fields +from typing import Callable, Dict, Iterable, List, Tuple + +import torch + + +InferenceFn = Callable[[List[str]], List[Tuple[str, float]]] + + +def time_inference(fn: Callable): + """Run ``fn``, returning ``(result, elapsed_seconds)``. Uses CUDA events + for tighter timing on GPU; falls back to wall clock otherwise.""" + if torch.cuda.is_available(): + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + out = fn() + end.record() + torch.cuda.synchronize() + return out, start.elapsed_time(end) / 1000.0 + t0 = time.time() + out = fn() + return out, time.time() - t0 + + +def warmup(model, tokenizer, backend: str) -> None: + """One throwaway forward to JIT-compile / page caches per backend.""" + print(f"Warming up {backend} model…") + if backend == "sglang": + model.generate(["who are you?"], { + "temperature": 0.6, "max_new_tokens": 100, + "top_p": 0.95, "top_k": 20, "min_p": 0.0, + }) + return + inputs = tokenizer("who are you?", return_tensors="pt") + device = next(model.parameters()).device + inputs = {k: v.to(device) for k, v in inputs.items()} + with torch.no_grad(): + print(model.generate(**inputs, max_new_tokens=100, do_sample=True)) + + +def cleanup(model, backend: str) -> None: + if model is None: + return + if backend == "sglang": + model.shutdown() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +# ──────────────────────────────────────────────────────────────────────────── +# Backend setup +# ──────────────────────────────────────────────────────────────────────────── + + +def setup_sglang(config: Dict, model_path: str, tokenizer, tp_size: int) -> Tuple[object, InferenceFn]: + import sglang as sgl + + sampling = { + "temperature": config["temperature"], + "max_new_tokens": config["max_new_tokens"], + "top_p": config["top_p"], + } + for opt in ("top_k", "min_p"): + if config.get(opt) is not None: + sampling[opt] = config[opt] + + print(f"Loading SGLang engine from: {model_path}") + model = sgl.Engine( + model_path=model_path, + dtype=config.get("dtype", "bfloat16"), + tp_size=tp_size, + mem_fraction_static=config.get("mem_fraction_static", 0.90), + host="127.0.0.1", + port=int(os.getenv("SGLANG_NCCL_PORT", "30000")), + attention_backend=config.get("attention_backend", "triton"), + ) + warmup(model, tokenizer, "sglang") + + def infer(prompts: List[str]) -> List[Tuple[str, float]]: + out: List[Tuple[str, float]] = [] + bs = config["batch_size"] + for i in range(0, len(prompts), bs): + batch = prompts[i:i + bs] + outputs, elapsed = time_inference(lambda: model.generate(batch, sampling)) + out.extend((o["text"], elapsed) for o in outputs) + return out + + return model, infer + + +def setup_hf(config: Dict, model_path: str, tokenizer, tp_size: int = 1) -> Tuple[object, InferenceFn]: + del tp_size # accepted for dispatcher signature uniformity; HF uses device_map="auto" + from transformers import AutoModelForCausalLM + + print(f"Loading Hugging Face model from: {model_path} (visible CUDA: {torch.cuda.device_count()})") + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=getattr(torch, config.get("dtype", "bfloat16")), + device_map="auto", + trust_remote_code=True, + attn_implementation="flash_attention_2" if config.get("use_flash_attention") else None, + low_cpu_mem_usage=True, + ) + gen_cfg = { + "temperature": config["temperature"], + "max_new_tokens": config["max_new_tokens"], + "top_p": config["top_p"], + "do_sample": config["temperature"] > 0.0, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": tokenizer.eos_token_id, + } + for opt in ("top_k", "min_p"): + if config.get(opt) is not None: + gen_cfg[opt] = config[opt] + warmup(model, tokenizer, "hf") + + def infer(prompts: List[str]) -> List[Tuple[str, float]]: + out: List[Tuple[str, float]] = [] + bs = config["batch_size"] + device = next(model.parameters()).device + for i in range(0, len(prompts), bs): + batch = prompts[i:i + bs] + inputs = tokenizer(batch, return_tensors="pt", padding=True, padding_side="left", truncation=True) + inputs = {k: v.to(device) for k, v in inputs.items()} + outputs, elapsed = time_inference(lambda: model.generate(**inputs, **gen_cfg)) + for j, gen in enumerate(outputs): + input_len = inputs["input_ids"][j].shape[0] + text = tokenizer.decode(gen[input_len:], skip_special_tokens=True) + out.append((text, elapsed / len(outputs))) + return out + + return model, infer + + +def setup_tah(config: Dict, model_path: str, tokenizer, tp_size: int = 1) -> Tuple[object, InferenceFn]: + del tp_size # accepted for dispatcher signature uniformity; TaH uses device_map="auto" + from tah.model.tah_config import TaHConfig + from tah.model.tah_model import TaHForCausalLM + from tah.model.utils import TaHForCasualLM_generate + + print(f"Loading TaH model from: {model_path} (visible CUDA: {torch.cuda.device_count()})") + valid = {f.name for f in fields(TaHConfig)} + override = TaHConfig(**{k: v for k, v in config.items() if k in valid}) + dtype = getattr(torch, config.get("dtype", "bfloat16")) + model = TaHForCausalLM.from_pretrained( + model_path, torch_dtype=dtype, device_map="auto", + attn_implementation="sdpa", tah_config=override, + ).to(dtype=dtype) + + def infer(prompts: List[str]) -> List[Tuple[str, float]]: + out: List[Tuple[str, float]] = [] + bs = config["batch_size"] + for i in range(0, len(prompts), bs): + batch = prompts[i:i + bs] + inputs = tokenizer(batch, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to(model.device) for k, v in inputs.items()} + (_, texts), elapsed = time_inference(lambda: TaHForCasualLM_generate( + tah_model=model, tokenizer=tokenizer, model_inputs=inputs, + max_new_tokens=config["max_new_tokens"], + do_sample=config["temperature"] > 0.0, + temperature=config["temperature"], + top_p=config["top_p"], + top_k=config.get("top_k", 0), + min_p=config.get("min_p", 0.0), + verbose=False, + )) + per = elapsed / max(len(texts), 1) + out.extend((t, per) for t in texts) + return out + + return model, infer + + +_SETUP_BY_NAME = {"sglang": setup_sglang, "hf": setup_hf, "tah": setup_tah} + + +def setup_backend( + backend: str, config: Dict, model_path: str, tokenizer, tp_size: int = 1, +) -> Tuple[object, InferenceFn]: + """Dispatch to the named ``setup_*`` and forward all args.""" + try: + setup = _SETUP_BY_NAME[backend] + except KeyError: + raise ValueError(f"unsupported backend {backend!r}; have {sorted(_SETUP_BY_NAME)}") from None + return setup(config, model_path, tokenizer, tp_size) diff --git a/tah/evaluate/codeeval.py b/tah/evaluate/codeeval.py index 01316fc..88c43d6 100644 --- a/tah/evaluate/codeeval.py +++ b/tah/evaluate/codeeval.py @@ -1,3 +1,24 @@ +"""HumanEval / MBPP evaluation glue around `evalplus`. + +Three entry points the eval driver uses: + +* :func:`make_raw_chat_prompt_for_code_evaluation` — turn a problem prompt + into a chat-templated prompt that primes the model to emit a single + self-contained Python script in a markdown code block. +* :func:`sanitize` — strip the model output down to just the function body + for the requested ``entry_point`` (re-exported from + :mod:`evalplus.sanitize`). +* :func:`evaluate` — run a per-sample untrusted-execution check against the + evalplus base + plus test suites in a process pool, then write a + ``*.eval_results.json`` next to the input ``samples.jsonl``. + +Public TaH carried a broader ``evaluate`` surface (interactive overwrite +prompt, `gguf_file`/`num_ctx`/`**model_kwargs` plumbing, a "reasoning" +chat-template branch). None of those are used by the cleaned eval driver, +so this module trims to the actually-called interface. +""" +from __future__ import annotations + import json import multiprocessing import os @@ -14,7 +35,7 @@ from termcolor import cprint from tqdm import tqdm -from evalplus.config import * +from evalplus.config import DEFAULT_GT_TIME_LIMIT_FACTOR, DEFAULT_MIN_TIME_LIMIT from evalplus.data import ( get_human_eval_plus, get_human_eval_plus_hash, @@ -32,367 +53,253 @@ ) from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS from evalplus.gen.util import trusted_exec -# 1st item: the status -# 2nd item (optional): the detailed pass/fail boolean for each input +from evalplus.sanitize import sanitize # noqa: F401 -- re-exported for the runner + +# (status_code, per-input pass/fail booleans) — evalplus's untrusted_check shape. Result = Tuple[str, List[bool]] -# some random words which serves as the splitter -_MAGIC_SPLITTER_ = "-[[]]-this-is-really-our-highest-priority-[[]]-" -# Model instructions for code evaluation -INSTRUCTION_PREFIX = "Please provide a self-contained Python script that solves the following problem in a markdown code block:" -INSTRUCTION_PREFIX_REASONING = "Please think step by step and then provide a self-contained Python script that solves the following problem in a markdown code block:" -RESPONSE_PREFIX = "Below is a Python script with a self-contained function that solves the problem and passes corresponding tests:" - -def get_groundtruth(problems, hashcode, tasks_only_output_not_none): + +_PROMPT_INSTRUCTION = ( + "Please provide a self-contained Python script that solves the following problem in a markdown code block:" +) +_RESPONSE_PREFIX = ( + "Below is a Python script with a self-contained function that solves the problem and passes corresponding tests:" +) +# Magic string that splits the assistant template open from the model's start. +_PROMPT_SPLITTER = "-[[]]-this-is-really-our-highest-priority-[[]]-" + + +def make_raw_chat_prompt_for_code_evaluation(task_prompt: str, reasoning: bool, tokenizer) -> str: + """Render the task prompt into the model's chat template + a fenced markdown + block opener so the model continues with the function body inline. + + ``reasoning=True`` wraps the user message only and lets the model think + freely; ``reasoning=False`` (the default our eval driver passes) primes + the assistant turn so the model emits just the code. + """ + if tokenizer.chat_template is None: + return task_prompt + + user_msg = f"{_PROMPT_INSTRUCTION}\n```\n{task_prompt.strip()}\n```\n" + if reasoning: + return tokenizer.apply_chat_template([{"role": "user", "content": user_msg}], tokenize=False) + + primed = f"\n{_RESPONSE_PREFIX}\n```python\n{_PROMPT_SPLITTER}\n```\n" + rendered = tokenizer.apply_chat_template( + [{"role": "user", "content": user_msg}, {"role": "assistant", "content": primed}], + tokenize=False, + ) + return rendered.split(_PROMPT_SPLITTER)[0] + + +def _get_groundtruth(problems: Dict, hashcode: str, output_not_none_tasks) -> Dict: + """Compute (or load from disk cache) the expected outputs for each problem + by running the canonical solution against the base + plus inputs.""" cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl") if os.path.exists(cache_file): - print(f"Load from ground-truth from {cache_file}") + print(f"Loading cached ground-truth from {cache_file}") with open(cache_file, "rb") as f: return pickle.load(f) os.makedirs(CACHE_DIR, exist_ok=True) - print("Computing expected output...") - tbegin = time.time() - expected_output = {} + print("Computing expected outputs…") + t0 = time.time() + expected = {} for task_id, problem in problems.items(): - oracle = {} - oracle["base"], oracle["base_time"] = trusted_exec( - problem["prompt"] + problem["canonical_solution"], - problem["base_input"], - problem["entry_point"], - record_time=True, - output_not_none=problem["entry_point"] in tasks_only_output_not_none, - ) - - oracle["plus"], oracle["plus_time"] = trusted_exec( - problem["prompt"] + problem["canonical_solution"], - problem["plus_input"], - problem["entry_point"], - record_time=True, - output_not_none=problem["entry_point"] in tasks_only_output_not_none, - ) - expected_output[task_id] = oracle - print(f"Expected outputs computed in {time.time() - tbegin:.2f}s") + solution_src = problem["prompt"] + problem["canonical_solution"] + oracle: Dict[str, Any] = {} + for which in ("base", "plus"): + outputs, t = trusted_exec( + solution_src, + problem[f"{which}_input"], + problem["entry_point"], + record_time=True, + output_not_none=problem["entry_point"] in output_not_none_tasks, + ) + oracle[which] = outputs + oracle[f"{which}_time"] = t + expected[task_id] = oracle + print(f"Expected outputs computed in {time.time() - t0:.2f}s") with open(cache_file, "wb") as f: - pickle.dump(expected_output, f) + pickle.dump(expected, f) + return expected - return expected_output - -def check_correctness( - dataset: str, - completion_id: int, - problem: Dict[str, Any], - solution: str, - expected_output: Dict[str, List], - base_only=False, - fast_check=False, - identifier=None, - min_time_limit: float = DEFAULT_MIN_TIME_LIMIT, - gt_time_limit_factor: float = DEFAULT_GT_TIME_LIMIT_FACTOR, -) -> Dict[str, Result]: # {...}, "base" | "plus" -> (status, details) - ret = { +def _check_correctness( + dataset: str, completion_id: int, problem: Dict[str, Any], solution: str, + expected_output: Dict[str, List], *, base_only: bool, fast_check: bool, + identifier: str, min_time_limit: float, gt_time_limit_factor: float, +) -> Dict[str, Result]: + """Run ``solution`` against base (and optionally plus) test inputs.""" + out: Dict[str, Any] = { "completion_id": completion_id, "task_id": problem["task_id"], "_identifier": identifier, "solution": solution, } - ret["base"] = untrusted_check( - dataset, - solution, - problem["base_input"], - problem["entry_point"], - expected=expected_output["base"], - atol=problem["atol"], - ref_time=expected_output["base_time"], - fast_check=fast_check, - min_time_limit=min_time_limit, - gt_time_limit_factor=gt_time_limit_factor, - ) - - if not base_only: - ret["plus"] = untrusted_check( - dataset, - solution, - problem["plus_input"], - problem["entry_point"], - expected=expected_output["plus"], - atol=problem["atol"], - ref_time=expected_output["plus_time"], - fast_check=fast_check, - min_time_limit=min_time_limit, + for which in (("base",) if base_only else ("base", "plus")): + out[which] = untrusted_check( + dataset, solution, + problem[f"{which}_input"], problem["entry_point"], + expected=expected_output[which], atol=problem["atol"], + ref_time=expected_output[f"{which}_time"], + fast_check=fast_check, min_time_limit=min_time_limit, gt_time_limit_factor=gt_time_limit_factor, ) + return out - return ret +def _failed_inputs(stat: str, details: Optional[List[bool]], inputs: list, *, full: bool) -> list: + """Pick which input rows to surface in the per-task failure log.""" + if stat == PASS or not details: + return [] + if full: + return [inputs[i] for i, ok in enumerate(details) if not ok] + return [inputs[len(details) - 1]] # last failure only -def make_raw_chat_prompt_for_code_evaluation( - task_prompt: str, - reasoning: bool, - tokenizer, -) -> str: - # directly return prompt if it does not have a tokenizer.chat_template - if tokenizer.chat_template is None: - return task_prompt - - task_prompt = f"""\ -{INSTRUCTION_PREFIX_REASONING if reasoning else INSTRUCTION_PREFIX} -``` -{task_prompt.strip()} -``` -""" - response = f"""\ - -{RESPONSE_PREFIX} -```python -{_MAGIC_SPLITTER_} -``` -""" - if reasoning: - task_prompt = tokenizer.apply_chat_template( - [ - {"role": "user", "content": task_prompt}, - ], - tokenize=False, - ) - else: - task_prompt = tokenizer.apply_chat_template( - [ - {"role": "user", "content": task_prompt}, - {"role": "assistant", "content": response}, - ], - tokenize=False, - ).split(_MAGIC_SPLITTER_)[0] - return task_prompt def evaluate( dataset: str, - samples: Optional[str] = None, + samples: str, + *, base_only: bool = False, parallel: Optional[int] = None, - i_just_wanna_run: bool = False, test_details: bool = False, min_time_limit: float = DEFAULT_MIN_TIME_LIMIT, gt_time_limit_factor: float = DEFAULT_GT_TIME_LIMIT_FACTOR, - mini: bool = False, - noextreme: bool = False, - version: str = "default", output_file: Optional[str] = None, - gguf_file: Optional[str] = None, - num_ctx: Optional[int] = None, - **model_kwargs, -): +) -> None: + """Run evalplus untrusted-check on every sample and write pass@k results. + ``samples`` is either a path to a ``.jsonl`` (each line ``{task_id, + solution}``) or a directory containing one. Output JSON path is derived + from ``samples`` (or set explicitly via ``output_file``). + """ n_workers = parallel or max(1, multiprocessing.cpu_count() // 2) if os.path.isdir(samples): result_path = os.path.join(samples, "eval_results.json") else: - assert samples.endswith(".jsonl") - # legacy compatibility - if os.path.exists(samples.replace(".jsonl", "_eval_results.json")): - result_path = samples.replace(".jsonl", "_eval_results.json") - else: - result_path = samples.replace(".jsonl", ".eval_results.json") - - if output_file is not None: + assert samples.endswith(".jsonl"), "samples must be a directory or *.jsonl path" + legacy = samples.replace(".jsonl", "_eval_results.json") + result_path = legacy if os.path.exists(legacy) else samples.replace(".jsonl", ".eval_results.json") + if output_file: result_path = output_file - if os.path.isfile(result_path) and not i_just_wanna_run: - print(f"Load from previous results from {result_path}") + if os.path.isfile(result_path): + print(f"Loading previous results from {result_path}") with open(result_path, "r") as f: results = json.load(f) - results = compatible_eval_result(results) else: if dataset == "humaneval": - problems = get_human_eval_plus( - mini=mini, noextreme=noextreme, version=version - ) - dataset_hash = get_human_eval_plus_hash( - mini=mini, noextreme=noextreme, version=version - ) - expected_output = get_groundtruth(problems, dataset_hash, []) + problems = get_human_eval_plus() + dataset_hash = get_human_eval_plus_hash() + output_not_none_tasks: Tuple[str, ...] = () elif dataset == "mbpp": - problems = get_mbpp_plus(mini=mini, noextreme=noextreme, version=version) - dataset_hash = get_mbpp_plus_hash( - mini=mini, noextreme=noextreme, version=version - ) - expected_output = get_groundtruth( - problems, - dataset_hash, - MBPP_OUTPUT_NOT_NONE_TASKS, - ) + problems = get_mbpp_plus() + dataset_hash = get_mbpp_plus_hash() + output_not_none_tasks = MBPP_OUTPUT_NOT_NONE_TASKS + else: + raise ValueError(f"unsupported code dataset {dataset!r}") - results = { - "date": datetime.now().strftime("%Y-%m-%d %H:%M"), - "hash": dataset_hash, - "eval": {}, - } + expected_output = _get_groundtruth(problems, dataset_hash, output_not_none_tasks) + results = {"date": datetime.now().strftime("%Y-%m-%d %H:%M"), "hash": dataset_hash, "eval": {}} with ProcessPoolExecutor(max_workers=n_workers) as executor: futures = [] completion_id = Counter() n_samples = 0 - eval_results = defaultdict(list) # task_id -> - remainings = set() + eval_results = defaultdict(list) + remaining = set() - print("Reading samples...") + print("Reading samples…") for sample in tqdm(load_solutions(samples)): task_id = sample["task_id"] if task_id not in problems: - warn( - f"Task {task_id} is found in the samples but not found in the dataset" - ) + warn(f"task {task_id} in samples but not in dataset") continue - solution = ( - sample["solution"] - if "solution" in sample - else problems[task_id]["prompt"] + sample["completion"] - ) - remainings.add(sample["_identifier"]) - args = ( - dataset, - completion_id[task_id], - problems[task_id], - solution, + solution = sample.get("solution") or (problems[task_id]["prompt"] + sample["completion"]) + remaining.add(sample["_identifier"]) + futures.append(executor.submit( + _check_correctness, + dataset, completion_id[task_id], problems[task_id], solution, expected_output[task_id], - base_only, - not test_details, # fast_check - sample["_identifier"], - min_time_limit, - gt_time_limit_factor, - ) - futures.append(executor.submit(check_correctness, *args)) + base_only=base_only, fast_check=not test_details, + identifier=sample["_identifier"], + min_time_limit=min_time_limit, gt_time_limit_factor=gt_time_limit_factor, + )) completion_id[task_id] += 1 n_samples += 1 - assert n_samples == len(remainings), "Missing problems in unfinished" - assert len(completion_id) == len(problems), "Missing problems in samples" + assert n_samples == len(remaining), "missing problems in unfinished" + assert len(completion_id) == len(problems), "missing problems in samples" - def stucking_checker(): - while remainings: - last_size = len(remainings) + def _watchdog(): + while remaining: + last = len(remaining) time.sleep(100) - if last_size != len(remainings) or len(remainings) == 0: - continue - # Potential stucking - warn("No samples had finished testing in the last 100s") - warn(f"{len(remainings)} samples to be tested: {remainings}") - - threading.Thread(target=stucking_checker).start() + if last == len(remaining) and remaining: + warn(f"no samples finished in 100s; {len(remaining)} pending: {remaining}") + threading.Thread(target=_watchdog, daemon=True).start() for future in tqdm(as_completed(futures), total=n_samples): - result = future.result() - remainings.remove(result["_identifier"]) - eval_results[result["task_id"]].append(result) + res = future.result() + remaining.discard(res["_identifier"]) + eval_results[res["task_id"]].append(res) - # sort the results for each problem by completion_id + # Sort completions per task and unpack base/plus statuses + failure details. for task_id, task_results in eval_results.items(): task_results.sort(key=lambda x: x["completion_id"]) results["eval"][task_id] = [] for res in task_results: - - def get_failed_tests(stat, details, inputs) -> List[Any]: - if stat == PASS or not details: - return [] - - if test_details: - return [ - inputs[i] for i in range(len(details)) if not details[i] - ] - - # else => simply return the only and the last fail test - return [inputs[len(details) - 1]] - base_stat, base_details = res["base"] - base_fail_tests = get_failed_tests( - base_stat, base_details, problems[task_id]["base_input"] - ) - - # initialize plus tests + base_fails = _failed_inputs(base_stat, base_details, problems[task_id]["base_input"], full=test_details) plus_stat = None - plus_fail_tests = [] - - # with plus tests + plus_fails: list = [] if not base_only: plus_stat, plus_details = res["plus"] - plus_fail_tests = get_failed_tests( - plus_stat, plus_details, problems[task_id]["plus_input"] - ) - + plus_fails = _failed_inputs(plus_stat, plus_details, problems[task_id]["plus_input"], full=test_details) if dataset == "mbpp": - base_fail_tests = mbpp_serialize_inputs(task_id, base_fail_tests) - plus_fail_tests = mbpp_serialize_inputs(task_id, plus_fail_tests) - - results["eval"][task_id].append( - { - "task_id": task_id, - "solution": res["solution"], - "base_status": base_stat, - "plus_status": plus_stat, - "base_fail_tests": base_fail_tests, - "plus_fail_tests": plus_fail_tests, - } - ) - - # Calculate pass@k. + base_fails = mbpp_serialize_inputs(task_id, base_fails) + plus_fails = mbpp_serialize_inputs(task_id, plus_fails) + results["eval"][task_id].append({ + "task_id": task_id, + "solution": res["solution"], + "base_status": base_stat, + "plus_status": plus_stat, + "base_fail_tests": base_fails, + "plus_fail_tests": plus_fails, + }) + + # pass@k from base; pass@k+ from base ∩ plus. total = np.array([len(r) for r in results["eval"].values()]) - base_correct = [] - new_correct = [] - - for res in results["eval"].values(): - bc = sum([r["base_status"] == PASS for r in res]) - base_correct.append(bc) - if not base_only: - new_correct.append( - sum( - [ - res[i]["base_status"] == res[i]["plus_status"] == PASS - for i in range(len(res)) - ] - ) - ) - base_correct = np.array(base_correct) - + base_correct = np.array([ + sum(r["base_status"] == PASS for r in tres) for tres in results["eval"].values() + ]) pass_at_k = { f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean() - for k in [1, 10, 100] - if total.min() >= k + for k in (1, 10, 100) if total.min() >= k } cprint(f"{dataset} (base tests)", "red") for k, v in pass_at_k.items(): cprint(f"{k}:\t{v:.3f}", "red") results["pass_at_k"] = {"base": pass_at_k} - if new_correct: - cprint(f"{dataset}+ (base + extra tests)", "green") - pass_at_k = { - f"pass@{k}": estimate_pass_at_k(total, np.array(new_correct), k).mean() - for k in [1, 10, 100] - if (total >= k).all() + if not base_only: + new_correct = np.array([ + sum(r["base_status"] == r["plus_status"] == PASS for r in tres) + for tres in results["eval"].values() + ]) + pass_at_k_plus = { + f"pass@{k}": estimate_pass_at_k(total, new_correct, k).mean() + for k in (1, 10, 100) if (total >= k).all() } - for k, v in pass_at_k.items(): + cprint(f"{dataset}+ (base + extra tests)", "green") + for k, v in pass_at_k_plus.items(): cprint(f"{k}:\t{v:.3f}", "green") - results["pass_at_k"]["plus"] = pass_at_k - - # save results - if os.path.isfile(result_path) and i_just_wanna_run: - decision = "" - while decision.lower() not in ["y", "n"]: - print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...") - decision = input() - - if decision.lower() == "y": - # mv the file to a backup - new_path = result_path + ".bak" - while os.path.isfile(new_path): - new_path += ".bak" - os.rename(result_path, new_path) - print(f"Backup {result_path} to {new_path}") + results["pass_at_k"]["plus"] = pass_at_k_plus if not os.path.isfile(result_path): with open(result_path, "w") as f: json.dump(results, f) - diff --git a/tah/evaluate/datasets.py b/tah/evaluate/datasets.py new file mode 100644 index 0000000..181f302 --- /dev/null +++ b/tah/evaluate/datasets.py @@ -0,0 +1,140 @@ +"""Dataset loading + per-row standardisation for the eval driver. + +Each entry in ``eval_configs/dataset_configs.json`` describes how to fetch a +benchmark (HF id or local file), which split to use, optional row filter, +and the field names to map onto our standard ``(id, question, answer)`` +schema. ``load_combined_dataset`` reads the entry, normalises every row, +and returns a flat ``list[dict]`` plus a small ``field_mapping`` dict that +downstream code uses to know the answer type (math / livecodebench / +humaneval / mbpp). + +Why standardise: the per-job runner only knows the standard schema, so a +new benchmark only needs a config-file entry to be evaluable. +""" +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + + +_CONFIG_PATH = Path(__file__).with_name("eval_configs") / "dataset_configs.json" + + +def _split_names(names: Union[str, List[str]]) -> List[str]: + """Accept comma-separated string or list; return a clean ``list[str]``.""" + if isinstance(names, str): + items = [n.strip() for n in names.split(",")] + elif isinstance(names, list): + items = [n.strip() if isinstance(n, str) else n for n in names] + else: + raise TypeError(f"dataset_names must be str or list, got {type(names).__name__}") + items = [n for n in items if n] + if not items: + raise ValueError("no dataset names provided") + return items + + +def _load_one(name: str, configs: Dict[str, Dict]) -> Tuple[List[dict], Dict[str, str]]: + """Load one dataset (HF id or local file), apply optional split + filter, + and standardise rows to ``{id, _original_id, question, answer, ...}``.""" + if name not in configs: + raise ValueError(f"dataset {name!r} not in eval_configs/dataset_configs.json (have {sorted(configs)})") + cfg = configs[name] + path = cfg["path"] + + # Local JSON/JSONL takes a different code path because mbpp/humaneval files + # contain test scaffolding that load_dataset can't infer schemas for. + if path.endswith((".json", ".jsonl")): + if name in ("mbpp", "humaneval"): + with open(path, "r", encoding="utf-8") as f: + rows = [json.loads(ln) for ln in f if ln.strip()] + split = "train" + ds_obj: Dict[str, Any] = {"train": rows} + else: + from datasets import load_dataset + ds_obj = load_dataset("json", data_files=path) + else: + from datasets import load_dataset + if cfg.get("subset"): + ds_obj = load_dataset(path, cfg["subset"]) + elif cfg.get("version_tag"): + ds_obj = load_dataset(path, version_tag=cfg["version_tag"]) + else: + ds_obj = load_dataset(path) + + split = cfg.get("split_name", "test") + if split not in ds_obj: + for fallback in ("train", "test"): + if fallback in ds_obj: + split = fallback + break + else: + raise ValueError(f"split {cfg.get('split_name')!r} not in {list(ds_obj)}") + ds = ds_obj[split] + print(f" {name}: split={split!r}, {len(ds)} rows") + + if "filter" in cfg: + f = cfg["filter"] + ds = ds.filter(lambda x, k=f["key"], v=f["value"]: x.get(k) in v) + + id_field = cfg["id_field"] + q_field = cfg["question_field"] + a_field = cfg["answer_field"] + template = cfg.get("prompt_template", "{question}") + entry_point_field = cfg.get("entry_point") + + standardised: List[dict] = [] + for idx, row in enumerate(ds): + original_id = str(row[id_field]) if row.get(id_field) is not None else str(idx) + question = template.replace("{question}", str(row.get(q_field, "")).strip()) + item = { + "id": f"{name}_{original_id}", + "_original_id": original_id, + "question": question, + "answer": str(row.get(a_field, "")).strip(), + "_source_dataset": name, + } + if entry_point_field: + item["entry_point"] = row.get(entry_point_field) + # Carry remaining columns under _original_ for downstream debugging. + for k, v in row.items(): + if k not in (id_field, q_field, a_field) and not k.startswith("_"): + item[f"_original_{k}"] = v + standardised.append(item) + return standardised, { + "id_field": "id", + "question_field": "question", + "answer_field": "answer", + "answer_type": cfg["answer_type"], + "prompt_template": "{question}", + } + + +def load_combined_dataset(dataset_names: Union[str, List[str]]) -> Tuple[List[dict], Dict]: + """Resolve one or more benchmark names → flat row list + field mapping. + + Field mapping is derived from the *first* dataset's ``answer_type`` and + is stamped onto every row (combined runs of mixed answer types are not + expected; we warn but proceed). + """ + names = _split_names(dataset_names) + with open(_CONFIG_PATH, "r", encoding="utf-8") as f: + configs = json.load(f) + + print(f"Loading {len(names)} dataset(s): {names}") + combined: List[dict] = [] + answer_types: List[str] = [] + field_mapping: Dict[str, Any] = {} + for name in names: + rows, fm = _load_one(name, configs) + combined.extend(rows) + answer_types.append(fm["answer_type"]) + if not field_mapping: + field_mapping = dict(fm) + + if len(set(answer_types)) > 1: + print(f"WARNING: multiple answer_types in combined run: {set(answer_types)} — using {answer_types[0]!r}") + field_mapping["dataset_names"] = names + print(f"Combined dataset: {len(combined)} rows total\n") + return combined, field_mapping diff --git a/tah/evaluate/eval_unified.py b/tah/evaluate/eval_unified.py index e5ce40b..b82870c 100644 --- a/tah/evaluate/eval_unified.py +++ b/tah/evaluate/eval_unified.py @@ -1,1422 +1,12 @@ -import os -import json -import yaml -import csv -import time -from pathlib import Path -from typing import List, Dict, Tuple -from tqdm import tqdm -import math -import multiprocessing as mp -from multiprocessing import Process, Queue -mp.set_start_method("spawn", force=True) -import pandas as pd -from transformers.utils import logging as hf_logging -import logging as pylog - -# some constants -SYSTEM_PROMPT = """ -You are a helpful assistant. To answer the user's question, you first think about the reasoning process and then provide the user with the answer. -The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . +"""Backwards-compat shim: ``tah.evaluate.eval_unified`` was the original +single-file driver. The driver has been split into ``datasets.py``, +``backends.py``, and ``jobs.py``; this module re-exports the public entry +points so existing callers keep working. """ -USER_PROMPT = """ Please reason step by step, and put your final answer within \\boxed{}.""" -QWEN3_EOS_TOKEN_ID = 151645 -DEEPSEEK_R1_EOS_TOKEN_ID = 151643 - - -def load_datasets_with_config(dataset_names) -> Tuple[object, Dict]: - """Load single or multiple datasets using their configurations - - Args: - dataset_names: Can be: - - A single dataset name (string) - - Multiple dataset names (comma-separated string) - - A list of dataset names - - Returns: - Tuple of (dataset, field_mapping) where field_mapping uses standard internal format - """ - from datasets import load_dataset - - # Parse dataset names - if isinstance(dataset_names, str): - # Handle comma-separated string or single dataset name - dataset_names_list = [name.strip() for name in dataset_names.split(',')] - elif isinstance(dataset_names, list): - # Handle list of dataset names - dataset_names_list = dataset_names - else: - raise ValueError(f"Invalid dataset_names type: {type(dataset_names)}. Expected str or list.") - - # Remove empty strings - dataset_names_list = [name for name in dataset_names_list if name] - - if not dataset_names_list: - raise ValueError("No dataset names provided") - - # Load and combine datasets (works for both single and multiple datasets) - combined_data = [] - answer_types = [] # Collect answer types from all datasets - - print(f"Loading and combining {len(dataset_names_list)} datasets: {dataset_names_list}") - - for i, dataset_name in enumerate(dataset_names_list): - print(f"\n[{i+1}/{len(dataset_names_list)}] Loading dataset: {dataset_name}") - - # Load dataset configuration and data - script_dir = Path(__file__).parent - config_file_path = script_dir / "eval_configs" / "dataset_configs.json" - with open(config_file_path, 'r', encoding='utf-8') as f: - dataset_configs = json.load(f) - - if dataset_name not in dataset_configs: - available_datasets = list(dataset_configs.keys()) - raise ValueError(f"Dataset '{dataset_name}' not found in dataset configs. Available datasets: {available_datasets}") - - dataset_config = dataset_configs[dataset_name] - - # Load dataset - dataset_path = dataset_config['path'] - # Support both 'subset' and 'dataset_config' as HF config name - subset = dataset_config.get('subset', None) - version_tag = dataset_config.get('version_tag', None) - - # Load dataset from JSON/JSONL local file - if dataset_path.endswith('.json') or dataset_path.endswith('.jsonl'): - if dataset_name in ["mbpp", "humaneval"]: - with open(dataset_path, "r", encoding="utf-8") as f: - records = [json.loads(line) for line in f if line.strip()] - # Mimic the structure of `load_dataset` for local files, - # which usually returns a dict like {"train": Dataset(...)}. - dataset_obj = {"train": records} - else: - dataset_obj = load_dataset('json', data_files=dataset_path) - else: - # For HuggingFace datasets, use dataset_config or subset as configuration name - if subset: - dataset_obj = load_dataset(dataset_path, subset) - elif version_tag: - dataset_obj = load_dataset(dataset_path, version_tag = version_tag) - else: - dataset_obj = load_dataset(dataset_path) - - split_name = dataset_config.get('split_name', 'test') - - # Try to get train split first, if not available try test split - if split_name in dataset_obj: - dataset = dataset_obj[split_name] - print(f"Using '{split_name}' split with {len(dataset)} samples") - elif "train" in dataset_obj: - dataset = dataset_obj["train"] - print(f"Using 'train' split with {len(dataset)} samples") - elif "test" in dataset_obj: - dataset = dataset_obj["test"] - print(f"Using 'test' split with {len(dataset)} samples") - else: - available_splits = list(dataset_obj.keys()) - raise ValueError(f"Split '{split_name}' not found. Available splits: {available_splits}") - - # Apply filter if specified - if 'filter' in dataset_config: - filter_config = dataset_config['filter'] - filter_key = filter_config['key'] - filter_values = filter_config['value'] - dataset = dataset.filter(lambda x: x.get(filter_key) in filter_values) - - # Get field mapping for this dataset - original_field_mapping = { - 'id_field': dataset_config['id_field'], - 'question_field': dataset_config['question_field'], - 'answer_field': dataset_config['answer_field'], - 'answer_type': dataset_config['answer_type'], - 'prompt_template': dataset_config.get('prompt_template', '{question}') - } - # Optional extra fields from config (e.g., entry_point for code datasets) - entry_point_field = dataset_config.get('entry_point', None) - - # print(f"Original field mapping for {dataset_name}: {original_field_mapping}") - - # Collect answer type (should be same for all datasets in practice) - answer_types.append(original_field_mapping['answer_type']) - - # Convert dataset to list and standardize field names - dataset_list = list(dataset) - for idx, item in enumerate(dataset_list): - # Create new standardized item - standardized_item = {} - - # Convert ID field to standard format - id_field = original_field_mapping['id_field'] - if id_field in item and item[id_field] is not None: - original_id = str(item[id_field]) - # Keep both standardized id (for this pipeline) and original_id (for downstream tools) - standardized_item['id'] = f"{dataset_name}_{original_id}" - standardized_item['_original_id'] = original_id - else: - # Generate ID if not present - generated_id = f"{dataset_name}_{idx}" - standardized_item['id'] = generated_id - standardized_item['_original_id'] = generated_id - - # Convert question field to standard format - question_field = original_field_mapping['question_field'] - question_text = str(item.get(question_field, '')).strip() - - # Apply prompt template if specified - prompt_template = original_field_mapping['prompt_template'] - if prompt_template and '{question}' in prompt_template: - question_text = prompt_template.replace('{question}', question_text) - - standardized_item['question'] = question_text - - # Convert answer field to standard format - answer_field = original_field_mapping['answer_field'] - standardized_item['answer'] = str(item.get(answer_field, '')).strip() - - # Optionally keep entry_point in standardized format if specified in config - if entry_point_field: - standardized_item['entry_point'] = item.get(entry_point_field) - - # Add source dataset information - standardized_item['_source_dataset'] = dataset_name - - # Copy any other fields that might be useful - for key, value in item.items(): - if key not in [id_field, question_field, answer_field] and not key.startswith('_'): - standardized_item[f'_original_{key}'] = value - - combined_data.append(standardized_item) - - print(f"Added {len(dataset_list)} problems from {dataset_name} (converted to standard format)") - - print(f"\nTotal combined dataset size: {len(combined_data)} problems") - - # Verify all datasets have the same answer_type - unique_answer_types = set(answer_types) - if len(unique_answer_types) > 1: - print(f"Warning: Multiple answer types found: {unique_answer_types}") - print("Using the first answer type as default") - - # Create combined field mapping using standard format - combined_field_mapping = { - 'id_field': 'id', - 'question_field': 'question', - 'answer_field': 'answer', - 'answer_type': answer_types[0] if answer_types else 'string', - 'prompt_template': '{question}', # Already applied during conversion - 'dataset_names': dataset_names_list - } - - # print(f"Using standardized field mapping: {combined_field_mapping}") - - return combined_data, combined_field_mapping - -def combine_job_results(output_dir: Path, job_nums: int, del_job_dir: bool = False): - """Combine results from all job directories""" - all_results = [] - problem_stats = {} - # Map (problem_id, sample_idx) -> output_tokens for truncating trackers - sample_output_tokens_map = {} - - # Initialize iter count distribution tracking - iter_count_distribution = {i: 0 for i in range(1, 6)} # iter_count 1 to 5 - - # Prepare combined output directory early (for new aggregated artifacts) - combined_dir = output_dir - combined_dir.mkdir(parents=True, exist_ok=True) - # Prepare aggregated samples.jsonl (truncate if exists) - samples_jsonl_path = combined_dir / "samples.jsonl" - with open(samples_jsonl_path, 'w', encoding='utf-8'): - pass - # Collect all tracker csv files for later concatenation - all_tracker_files = [] - - # Collect results from all job directories - for job_id in range(job_nums): - job_dir = output_dir / f'job_{job_id}' - - # Read detailed results - results_file = job_dir / "detailed_results.csv" - if results_file.exists(): - with open(results_file, 'r', encoding='utf-8') as f: - reader = csv.DictReader(f) - for row in reader: - # Convert string values back to appropriate types - row['is_correct'] = row['is_correct'] == 'True' - # Handle both old and new field names for backward compatibility - if 'has_boxed_answer' in row: - row['has_answer'] = row['has_boxed_answer'] == 'True' - del row['has_boxed_answer'] # Remove old field - elif 'has_answer' in row: - row['has_answer'] = row['has_answer'] == 'True' - row['sample_idx'] = int(row['sample_idx']) - row['input_tokens'] = int(row['input_tokens']) - row['output_tokens'] = int(row['output_tokens']) - row['processing_time'] = float(row['processing_time']) - all_results.append(row) - # Record output length for this sample for later tracker truncation - sample_output_tokens_map[(row['problem_id'], row['sample_idx'])] = row['output_tokens'] - - # Read problem statistics - stats_file = job_dir / "evaluation_stats.csv" - if stats_file.exists(): - with open(stats_file, 'r', encoding='utf-8') as f: - reader = csv.reader(f) - header = next(reader) # Read header to check format - for row in reader: - if row and row[0] != '' and row[0] != 'Total Accuracy': - # Handle both old format (without avg_output_length) and new format - if len(row) >= 5 and len(header) >= 5: - problem_stats[row[0]] = { - "accuracy": f"{float(row[1]):.3f}", - "correct_count": int(row[2]), - "total_samples": int(row[3]), - "avg_output_length": float(row[4]) - } - else: - problem_stats[row[0]] = { - "accuracy": f"{float(row[1]):.3f}", - "correct_count": int(row[2]), - "total_samples": int(row[3]), - "avg_output_length": 0.0 # Default value for old format - } - - # Process tracker CSV files for iter count distribution and aggregate samples/tracker - details_dir = job_dir / 'details' - if details_dir.exists(): - for problem_dir in details_dir.iterdir(): - if problem_dir.is_dir(): - # Aggregate sample_*.json into combined samples.jsonl with added fields - sample_json_files = sorted(problem_dir.glob('sample_*.json')) - for sample_json_file in sample_json_files: - with open(sample_json_file, 'r', encoding='utf-8') as f_json: - sample_obj = json.load(f_json) - problem_id = problem_dir.name - try: - sample_idx = int(sample_json_file.stem.split('_')[-1]) - except Exception: - sample_idx = -1 - # Add required fields - sample_obj['id'] = problem_id - sample_obj['sample'] = sample_idx - # Append to JSONL - with open(samples_jsonl_path, 'a', encoding='utf-8') as out_f: - out_f.write(json.dumps(sample_obj, ensure_ascii=False) + "\n") - - # Look for tracker CSV files - for tracker_file in problem_dir.glob('*_tracker.csv'): - df = pd.read_csv(tracker_file) - # Truncate tracker rows by output length based on iter_depth==0 counts - if 'iter_depth' in df.columns: - # Parse sample index from filename like sample_{idx}_tracker.csv - sample_idx = -1 - stem_parts = tracker_file.stem.split('_') - if len(stem_parts) >= 3 and stem_parts[0] == 'sample' and stem_parts[-1] == 'tracker': - sample_idx = int(stem_parts[1]) - - output_len = sample_output_tokens_map.get((problem_dir.name, sample_idx)) - if isinstance(output_len, int) and output_len >= 0: - depth0_cum = (df['iter_depth'] == 0).cumsum() - df = df[depth0_cum <= output_len] - # Persist truncated tracker back to file so later combination uses it - df.to_csv(tracker_file, index=False) - - # Use (possibly truncated) df to accumulate iter count distribution - iter_depth_counts = df['iter_depth'].value_counts().to_dict() - for iter_count in range(1, 6): - current_depth = iter_count - 1 - next_depth = iter_count - current_count = iter_depth_counts.get(current_depth, 0) - next_count = iter_depth_counts.get(next_depth, 0) - tokens_with_this_iter_count = current_count - next_count - if tokens_with_this_iter_count > 0: - iter_count_distribution[iter_count] += tokens_with_this_iter_count - all_tracker_files.append(tracker_file) - - # Calculate statistics for each problem from all_results - problem_output_stats = {} - for result in all_results: - problem_id = result['problem_id'] - if problem_id not in problem_output_stats: - problem_output_stats[problem_id] = [] - problem_output_stats[problem_id].append(result['output_tokens']) - - # Calculate average output length for each problem - for problem_id in problem_stats: - if problem_id in problem_output_stats: - avg_output_length = sum(problem_output_stats[problem_id]) / len(problem_output_stats[problem_id]) - problem_stats[problem_id]["avg_output_length"] = avg_output_length - else: - problem_stats[problem_id]["avg_output_length"] = 0.0 - - # Calculate overall statistics - total_correct = sum(1 for r in all_results if r['is_correct']) - total_accuracy = total_correct / len(all_results) if all_results else 0 - overall_avg_output_length = sum(r['output_tokens'] for r in all_results) / len(all_results) if all_results else 0 - - # Save combined statistics - combined_dir = output_dir - combined_dir.mkdir(parents=True, exist_ok=True) - - stats_file = combined_dir / "evaluation_stats.csv" - with open(stats_file, 'w', newline='', encoding='utf-8') as f: - writer = csv.writer(f) - writer.writerow(["problem_id", "accuracy", "correct_count", "total_samples", "avg_output_length"]) - - for problem_id, stats in sorted(problem_stats.items()): - writer.writerow([problem_id, stats['accuracy'], stats['correct_count'], stats['total_samples'], f"{stats['avg_output_length']:.2f}"]) - - writer.writerow([]) - writer.writerow(["Total Accuracy", f"{total_accuracy:.3f}", total_correct, len(all_results), f"{overall_avg_output_length:.2f}"]) - - # Save combined detailed results - results_file = combined_dir / "detailed_results.csv" - with open(results_file, 'w', newline='', encoding='utf-8') as f: - fieldnames = ["problem_id", "sample_idx", "correct_answer", "predicted_answer", - "has_answer", "is_correct", "input_tokens", "output_tokens", - "processing_time"] - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - - for result in sorted(all_results, key=lambda x: (x['problem_id'], x['sample_idx'])): - writer.writerow(result) - - # Save iter count distribution - iter_count_file = combined_dir / "iter_count_distribution.csv" - with open(iter_count_file, 'w', newline='', encoding='utf-8') as f: - writer = csv.writer(f) - writer.writerow(["iter_count", "token_count"]) - - total_tokens = sum(iter_count_distribution.values()) - for iter_count in sorted(iter_count_distribution.keys()): - token_count = iter_count_distribution[iter_count] - writer.writerow([iter_count, token_count]) - - writer.writerow([]) - writer.writerow(["Total Tokens", total_tokens]) - - print(f"\nCombined results from {job_nums} jobs") - print(f"Overall accuracy: {total_accuracy:.4f}") - print(f"Total problems: {len(problem_stats)}") - print(f"Combined statistics saved to: {stats_file}") - print(f"Combined detailed results saved to: {results_file}") - print(f"Iter count distribution saved to: {iter_count_file}") - - # Print iter count distribution - print("\n=== Iter Count Distribution ===") - total_tokens = sum(iter_count_distribution.values()) - for iter_count in sorted(iter_count_distribution.keys()): - token_count = iter_count_distribution[iter_count] - if token_count <= 0: - continue - percentage = (token_count / total_tokens * 100) if total_tokens > 0 else 0 - print(f"Iter count {iter_count}: {token_count} tokens ({percentage:.2f}%)") - print(f"Total tokens: {total_tokens}") - print("==============================\n") - - # Concatenate all tracker CSV files into a single CSV if any exist - if all_tracker_files: - combined_tracker_path = combined_dir / 'all_trackers.csv' - # Write header from the first file, then append rows from all files - with open(combined_tracker_path, 'w', newline='', encoding='utf-8') as out_f: - writer = None - header_written = False - for idx, tf in enumerate(all_tracker_files): - with open(tf, 'r', encoding='utf-8') as in_f: - reader = csv.reader(in_f) - rows = list(reader) - if not rows: - continue - # Augment header with data_id on first write - if not header_written: - writer = csv.writer(out_f) - header = rows[0] + ['data_id'] - writer.writerow(header) - header_written = True - data_id = tf.parent.name - for row in rows[1:]: - writer.writerow(row + [data_id]) - print(f"Combined tracker CSV saved to: {combined_tracker_path}") - - # # After combining, remove per-job directories to reduce disk usage - if del_job_dir: - import shutil - for job_id in range(job_nums): - job_dir = output_dir / f'job_{job_id}' - if job_dir.exists(): - shutil.rmtree(job_dir, ignore_errors=True) - print("Removed per-job directories after combining results") - - -def _time_inference(func, cuda_available=True): - """Common timing wrapper for inference""" - import torch - if cuda_available: - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - result = func() - end_event.record() - torch.cuda.synchronize() - elapsed_time = start_event.elapsed_time(end_event) / 1000.0 - else: - start_time = time.time() - result = func() - end_time = time.time() - elapsed_time = end_time - start_time - return result, elapsed_time - -def _warmup_model(model, tokenizer, backend, tp_size=1): - """Common warmup function for all backends""" - import torch - print(f"Warming up {backend} model...") - if backend == 'sglang': - _ = model.generate(["who are you?"], {"temperature": 0.6, "max_new_tokens": 100, "top_p": 0.95, "top_k": 20, "min_p": 0.0}) - elif backend == 'hf' or backend == 'tah': - warmup_input = tokenizer("who are you?", return_tensors="pt") - # Move inputs to the same device as the model - try: - model_device = next(model.parameters()).device - warmup_input = {k: v.to(model_device) for k, v in warmup_input.items()} - except StopIteration: - # If model has no parameters, try to use first available CUDA device - if torch.cuda.is_available(): - warmup_input = {k: v.cuda() for k, v in warmup_input.items()} - with torch.no_grad(): - output = model.generate(**warmup_input, max_new_tokens=100, do_sample=True) - print(output) - -def _cleanup_resources(model, backend): - """Common cleanup function for all backends""" - import torch - if model is not None: - if backend == 'sglang': - model.shutdown() - else: - del model - if torch.cuda.is_available(): - torch.cuda.empty_cache() - -def run_single_job(config: Dict, combined_dataset_name: str, output_dir: str, timestamp: str, model_path: str, job_id: int, job_nums: int, start_idx: int, end_idx: int, tp_size: int, backend: str, data_range=None, problems_data=None, field_mapping=None, unified_code_solutions_file=None): - """Run inference for a single job""" - # Lazy import of torch and related libraries to ensure CUDA_VISIBLE_DEVICES is respected. - import torch - from transformers import AutoTokenizer, AutoModelForCausalLM - - import tah.evaluate.matheval as matheval - import tah.evaluate.codeeval as codeeval - from tah.model.tah_config import TaHConfig - - if backend == 'sglang': - try: - import sglang as sgl - except ImportError: - raise ImportError("sglang backend requires sglang to be installed.") - elif backend == 'tah': - try: - from tah.model.recurrent_transformer import TaHForCausalLM - from tah.model.utils import TaHForCasualLM_generate - from tah.model.tracker import TaHTracker - except ImportError: - raise ImportError("tah backend requires TaH components to be installed.") - - # Update output directory for this job (include task suffix if data_range is provided) - task_suffix = "" - if data_range: - range_start, range_end = parse_data_range(data_range) - task_suffix = f"TASK_{range_start}_{range_end}" - - output_dir = Path(output_dir) / (combined_dataset_name + "_" + backend) / timestamp - if task_suffix: - output_dir = output_dir / task_suffix - output_dir = output_dir / f'job_{job_id}' - detail_dir = output_dir / 'details' - output_dir.mkdir(parents=True, exist_ok=True) - detail_dir.mkdir(parents=True, exist_ok=True) - - problems = list(problems_data) - print(f"Job {job_id+1}/{job_nums}: Processing {len(problems)} problems") - print(f"Backend: {backend}") - print(f"Combined datasets: {combined_dataset_name}") - - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path) - # tokenizer.eos_token_id = QWEN3_EOS_TOKEN_ID - # if "r1" in model_path.lower(): - # tokenizer.eos_token_id = DEEPSEEK_R1_EOS_TOKEN_ID - # elif "qwen3" in model_path.lower(): - # tokenizer.eos_token_id = QWEN3_EOS_TOKEN_ID - # else: - # tokenizer.eos_token_id = tokenizer.eos_token_id - - # Set padding token if not exists - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - # Initialize model and inference function based on backend - model = None - if backend == 'sglang': - # SGLang backend - print(f"Loading SGLang engine from: {model_path}") - - try: - # Set sampling parameters for SGLang - sampling_params = { - "temperature": config['temperature'], - "max_new_tokens": config['max_new_tokens'], - "top_p": config['top_p'], - } - if config.get('top_k', None) is not None: - sampling_params["top_k"] = config['top_k'] - if config.get('min_p', None) is not None: - sampling_params["min_p"] = config['min_p'] - - # Create LLM engine with specified tp_size - model = sgl.Engine( - model_path=model_path, - dtype=config.get('dtype', 'bfloat16'), - tp_size=tp_size, - mem_fraction_static=config.get('mem_fraction_static', 0.90), - host="127.0.0.1", - port=int(os.getenv("SGLANG_NCCL_PORT", 30000)), - attention_backend=config.get('attention_backend', 'triton'), - ) - - _warmup_model(model, tokenizer, backend, tp_size) - - except Exception as e: - print(f"Error loading SGLang engine: {e}") - raise - - def inference_function(inputs): - """SGLang inference function""" - batch_outputs = [] - for i in range(0, len(inputs), config['batch_size']): - batch = inputs[i:i + config['batch_size']] - - def generate_batch(): - return model.generate(batch, sampling_params) - - outputs, elapsed_time = _time_inference(generate_batch) - batch_outputs.extend([(out['text'], elapsed_time) for out in outputs]) - - return batch_outputs - - elif backend == 'hf': - # Hugging Face backend - print(f"Loading Hugging Face model from: {model_path}") - print(f"Process CUDA devices available: {torch.cuda.device_count()}") - - try: - # In multiprocess environment, always use device_map="auto" to properly handle GPU allocation - model = AutoModelForCausalLM.from_pretrained( - model_path, - torch_dtype=getattr(torch, config.get('dtype', 'bfloat16')), - device_map="auto", # Let transformers handle device allocation based on visible GPUs - trust_remote_code=True, - attn_implementation="flash_attention_2" if config.get('use_flash_attention', False) else None, - low_cpu_mem_usage=True # Important for proper weight initialization - ) - - # Set sampling parameters for Hugging Face - generation_config = { - "temperature": config['temperature'], - "max_new_tokens": config['max_new_tokens'], - "top_p": config['top_p'], - "do_sample": True if config['temperature'] > 0.0 else False, - "pad_token_id": tokenizer.eos_token_id, - "eos_token_id": tokenizer.eos_token_id, - } - - # Add min_p if supported - if config.get('min_p', None) is not None: - generation_config["min_p"] = config['min_p'] - - # Add top_k if supported - if config.get('top_k', None) is not None: - generation_config["top_k"] = config['top_k'] - - _warmup_model(model, tokenizer, backend, tp_size) - - except Exception as e: - print(f"Error loading Hugging Face model: {e}") - raise - - def inference_function(inputs): - """Hugging Face inference function""" - batch_outputs = [] - for i in range(0, len(inputs), config['batch_size']): - batch = inputs[i:i + config['batch_size']] - - # Tokenize batch - batch_inputs = tokenizer(batch, return_tensors="pt", padding=True, padding_side="left", truncation=True) - # Move inputs to the same device as the model - try: - model_device = next(model.parameters()).device - batch_inputs = {k: v.to(model_device) for k, v in batch_inputs.items()} - except StopIteration: - # If model has no parameters, try to use first available CUDA device - if torch.cuda.is_available(): - batch_inputs = {k: v.cuda() for k, v in batch_inputs.items()} - - def generate_batch(): - with torch.no_grad(): - return model.generate(**batch_inputs, **generation_config) - - outputs, elapsed_time = _time_inference(generate_batch, torch.cuda.is_available()) - - # Decode outputs - for j, output in enumerate(outputs): - # Remove input tokens from output - input_length = batch_inputs['input_ids'][j].shape[0] - generated_tokens = output[input_length:] - output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) - - batch_outputs.append((output_text, elapsed_time / len(outputs))) - - return batch_outputs - - elif backend == 'tah': - # TaH backend - print(f"Loading TaH model from: {model_path}") - print(f"Process CUDA devices available: {torch.cuda.device_count()}") - - max_iter = config.get('max_iter', 3) - embedding_key = config.get('embedding_key', None) - iter_decider = config.get('iter_decider', None) - iter_decider_kwargs = config.get('iter_decider_kwargs', None) - eval_iter_decider = config.get('eval_iter_decider', None) - eval_iter_decider_kwargs = config.get('eval_iter_decider_kwargs', None) - use_tracker = config.get('use_tracker', False) - tracker_kwargs = config.get('tracker_kwargs', None) - prompt_iter_count = config.get('prompt_iter_count', None) - - override_config = TaHConfig( - embedding_key=embedding_key, - max_iter=max_iter, - iter_decider=iter_decider, - iter_decider_kwargs=iter_decider_kwargs, - eval_iter_decider=eval_iter_decider, - eval_iter_decider_kwargs=eval_iter_decider_kwargs, - ) - - try: - # In multiprocess environment, always use device_map="auto" - model = TaHForCausalLM.from_pretrained( - model_path, - torch_dtype=getattr(torch, config.get('dtype', 'bfloat16')), - device_map="auto", # Let the library handle device allocation - attn_implementation="sdpa", - # low_cpu_mem_usage=True, - tah_config=override_config, - ) - - - except Exception as e: - print(f"Error loading TaH model: {e}") - raise - - model = model.to(dtype=model.dtype) - - tracker = None - if use_tracker: - tracker = TaHTracker(top_k=tracker_kwargs.get('top_k', 5)) - tracker.attach(model) - - def inference_function(inputs): - """TaH inference function""" - batch_outputs = [] - for i in range(0, len(inputs), config['batch_size']): - batch = inputs[i:i + config['batch_size']] - - # Tokenize all inputs in the batch at once - input_tokens = tokenizer(batch, return_tensors="pt", padding=True, padding_side="left") - - # Move inputs to the same device as the model - model_device = model.device - input_tokens = {k: v.to(model_device) for k, v in input_tokens.items()} - - if prompt_iter_count is not None: - input_ids = input_tokens["input_ids"] - batch_size, seq_len = input_ids.shape - iter_count = prompt_iter_count * torch.ones( - batch_size, - seq_len, - dtype=torch.long, - device=model_device, - ) - else: - iter_count = None - - # Record the number of tracker records before generation if tracker is enabled - prev_record_len = len(tracker.records) if tracker else 0 - - def generate_batch(): - with torch.no_grad(): - return TaHForCasualLM_generate( - tah_model=model, - tokenizer=tokenizer, - model_inputs=input_tokens, - iter_count=iter_count, - max_new_tokens=config['max_new_tokens'], - do_sample=True if config['temperature'] > 0.0 else False, - temperature=config['temperature'], - top_p=config['top_p'], - top_k=config.get('top_k', 0), - min_p=config.get('min_p', 0.0), - verbose=False - ) - - (_, output_texts), elapsed_time = _time_inference(generate_batch, torch.cuda.is_available()) - - # Process tracker records if enabled - if tracker: - new_records = tracker.records[prev_record_len:] - records_by_batch = {} - for rec in new_records: - bidx = rec.get("batch_idx", 0) - records_by_batch.setdefault(bidx, []).append(rec) - - for j, output_text in enumerate(output_texts): - sample_records = records_by_batch.get(j, []) - batch_outputs.append((output_text, elapsed_time / len(output_texts), sample_records)) - else: - for output_text in output_texts: - batch_outputs.append((output_text, elapsed_time / len(output_texts))) - - return batch_outputs - - else: - raise ValueError(f"Unsupported backend: {backend}. Choose 'sglang', 'hf', or 'tah'.") - - def cleanup_function(): - """Common cleanup function""" - _cleanup_resources(model, backend) - - # No longer require batch_size to be a multiple of repeat_size - print(f"Processing with batch_size={config['batch_size']}, repeat_size={config['repeat_size']}") - - # Store all results - all_results = [] - problem_stats = {} - - # Create intermediate results file - intermediate_stats_file = output_dir / "intermediate_stats.csv" - - # Prepare detailed results CSV file and write header - results_file = output_dir / "detailed_results.csv" - fieldnames = ["problem_id", "sample_idx", "correct_answer", "predicted_answer", - "has_answer", "is_correct", "input_tokens", "output_tokens", - "processing_time"] - with open(results_file, 'w', newline='', encoding='utf-8') as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - - # For code datasets, use unified file or create job-specific file - answer_type = field_mapping.get('answer_type', 'boxed') - is_code_dataset = answer_type in ['livecodebench', 'humaneval', 'mbpp'] - code_solutions_file = unified_code_solutions_file # Use unified file if provided - - # Prepare all problem data first - problem_data = [] - for idx, item in enumerate(problems): - # Prefer original index if provided (for interleaved assignment), fallback to sequential - actual_idx = item.get('_original_index', (start_idx + idx)) - - # Get problem ID using dynamic field mapping - id_field = field_mapping['id_field'] - if id_field in item and item[id_field] is not None: - problem_id = str(item[id_field]) - else: - problem_id = f"problem_{actual_idx}" - - # Get problem text using dynamic field mapping - question_field = field_mapping['question_field'] - problem_text = str(item.get(question_field, '')).strip() - - # Apply prompt template if specified - prompt_template = field_mapping['prompt_template'] - if prompt_template and '{question}' in prompt_template: - problem_text = prompt_template.replace('{question}', problem_text) - - # Get correct answer using dynamic field mapping - answer_field = field_mapping['answer_field'] - correct_answer = str(item.get(answer_field, '')).strip() - - # Create problem-specific directory - problem_dir = detail_dir / problem_id - problem_dir.mkdir(parents=True, exist_ok=True) - - # Prepare problem data dict - prob_dict = { - 'problem_id': problem_id, - # Preserve the original task id for downstream evaluators (e.g., evalplus) - 'original_problem_id': item.get('_original_id', problem_id), - 'problem_text': problem_text, - 'correct_answer': correct_answer, - 'problem_dir': problem_dir, - 'actual_idx': actual_idx, - } - - if is_code_dataset: - prob_dict['entry_point'] = item['entry_point'] - - problem_data.append(prob_dict) - - # Prepare all inputs upfront (each problem repeated repeat_size times) - all_inputs = [] - input_to_problem_mapping = [] # Track which input belongs to which problem and sample - - for prob_idx, prob_data in enumerate(problem_data): - problem_text = prob_data['problem_text'] - # Create repeat_size copies of this problem - for sample_idx in range(config['repeat_size']): - if is_code_dataset: - input_text = codeeval.make_raw_chat_prompt_for_code_evaluation(task_prompt=problem_text, reasoning=False, tokenizer=tokenizer) - else: - messages = [{"role": "user", "content": problem_text}] - input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - all_inputs.append(input_text) - input_to_problem_mapping.append((prob_idx, sample_idx)) - - # Process all inputs in batches using the specified batch_size - total_batches = math.ceil(len(all_inputs) / config['batch_size']) - - for batch_idx in tqdm(range(total_batches), desc=f"Job {job_id} processing inference batches", position=job_id, leave=True): - batch_start = batch_idx * config['batch_size'] - batch_end = min(batch_start + config['batch_size'], len(all_inputs)) - batch_inputs = all_inputs[batch_start:batch_end] - - # Run inference for this batch - batch_outputs = inference_function(batch_inputs) - - # Process and save results for this batch immediately - with open(results_file, 'a', newline='', encoding='utf-8') as f_results: - writer = csv.DictWriter(f_results, fieldnames=fieldnames) - for i, inference_output in enumerate(batch_outputs): - input_idx = batch_start + i - prob_idx, sample_idx = input_to_problem_mapping[input_idx] - - prob_data = problem_data[prob_idx] - problem_id = prob_data['problem_id'] - # original_problem_id = prob_data['original_problem_id'] - problem_text = prob_data['problem_text'] - correct_answer = prob_data['correct_answer'] - problem_dir = prob_data['problem_dir'] - - # Unpack output data - if isinstance(inference_output, tuple) and len(inference_output) == 3: - output_text, proc_time, sample_tracker_records = inference_output - else: - output_text, proc_time = inference_output - sample_tracker_records = None - - # Extract answer based on dataset type - # Check if this is a code evaluation dataset - answer_type = field_mapping.get('answer_type', 'boxed') - is_code_dataset = answer_type in ['livecodebench', 'humaneval', 'mbpp'] - - if is_code_dataset: - # For code datasets, skip evaluation during generation - # Save to jsonl for later batch evaluation - predicted_answer = "pending_code_eval" - has_answer = False - is_correct = False - else: - # Math evaluation path (original logic) - result_eval = matheval.evaluator_map[combined_dataset_name].rule_judge(output_text, correct_answer) - if result_eval[1] == "No extracted answer": - predicted_answer = "" - has_answer = False - else: - predicted_answer = result_eval[1] - has_answer = True - is_correct = result_eval[0] - - # Calculate token counts - input_tokens = len(tokenizer.encode(problem_text)) - output_tokens = len(tokenizer.encode(output_text)) - - result_dict = { - "problem_id": problem_id, - "sample_idx": sample_idx, - "correct_answer": correct_answer, - "predicted_answer": predicted_answer, - "has_answer": has_answer, - "is_correct": is_correct, - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "processing_time": proc_time, - "full_output": output_text - } - - all_results.append(result_dict) - - # Save detailed output in problem-specific directory - detail_file = problem_dir / f"sample_{sample_idx}.json" - detail_data = { - "problem": problem_text, - "output": output_text, - "correct_answer": correct_answer, - "predicted_answer": predicted_answer, - "is_correct": is_correct, - } - - # For code datasets, also save extracted code - if is_code_dataset: - entry_point = prob_data['entry_point'] - extracted_code = codeeval.sanitize(output_text, entry_point) - detail_data["extracted_code"] = extracted_code - detail_data["entry_point"] = entry_point - - with open(detail_file, 'w', encoding='utf-8') as f_detail: - json.dump(detail_data, f_detail, ensure_ascii=False, indent=2) - - # Save tracker records if available - if sample_tracker_records: - tracker_file = problem_dir / f"sample_{sample_idx}_tracker.csv" - pd.DataFrame(sample_tracker_records).to_csv(tracker_file, index=False) - - # Write to detailed_results.csv - row_to_write = {k: v for k, v in result_dict.items() if k != 'full_output'} - writer.writerow(row_to_write) - - # For code datasets, save solution to JSONL for batch evaluation - if is_code_dataset and code_solutions_file: - import fcntl - # Use the original problem id so that it matches external evaluators' expectations - original_problem_id = prob_data.get('original_problem_id', problem_id) - solution_entry = { - "task_id": original_problem_id, - "solution": str(extracted_code) - } - # Use file lock to avoid conflicts when multiple jobs write simultaneously - with open(code_solutions_file, 'a', encoding='utf-8') as f_code: - fcntl.flock(f_code.fileno(), fcntl.LOCK_EX) # Exclusive lock - try: - f_code.write(json.dumps(solution_entry, ensure_ascii=False) + '\n') - finally: - fcntl.flock(f_code.fileno(), fcntl.LOCK_UN) # Unlock - - # Group results by problem_id to calculate statistics - results_by_problem = {} - for r in all_results: - pid = r['problem_id'] - if pid not in results_by_problem: - results_by_problem[pid] = [] - results_by_problem[pid].append(r) - - # Calculate stats for each problem - for problem_id, results in results_by_problem.items(): - correct_count = sum(1 for r in results if r['is_correct']) - total_samples = len(results) - accuracy = correct_count / total_samples if total_samples > 0 else 0 - avg_output_length = sum(r['output_tokens'] for r in results) / total_samples if total_samples > 0 else 0 - problem_stats[problem_id] = { - "accuracy": f"{accuracy:.3f}", - "correct_count": correct_count, - "total_samples": total_samples, - "avg_output_length": avg_output_length - } - - # Save intermediate statistics after processing all problems - with open(intermediate_stats_file, 'w', newline='', encoding='utf-8') as f: - writer = csv.writer(f) - writer.writerow(["problem_id", "accuracy", "correct_count", "total_samples", "avg_output_length"]) - - for pid, stats in problem_stats.items(): - writer.writerow([pid, stats['accuracy'], stats['correct_count'], stats['total_samples'], f"{stats['avg_output_length']:.2f}"]) - - # Calculate overall statistics - total_correct = sum(r['is_correct'] for r in all_results) - total_accuracy = total_correct / len(all_results) if all_results else 0 - overall_avg_output_length = sum(r['output_tokens'] for r in all_results) / len(all_results) if all_results else 0 - - # Save statistics to CSV - stats_file = output_dir / "evaluation_stats.csv" - with open(stats_file, 'w', newline='', encoding='utf-8') as f: - writer = csv.writer(f) - writer.writerow(["problem_id", "accuracy", "correct_count", "total_samples", "avg_output_length"]) - - for problem_id, stats in problem_stats.items(): - writer.writerow([problem_id, stats['accuracy'], stats['correct_count'], stats['total_samples'], f"{stats['avg_output_length']:.2f}"]) - - writer.writerow([]) - writer.writerow(["Total Accuracy", f"{total_accuracy:.3f}", total_correct, len(all_results), f"{overall_avg_output_length:.2f}"]) - - print(f"\nJob {job_id} completed!") - print(f"Job accuracy: {total_accuracy:.4f}") - - # Clean up resources - cleanup_function() - - -def _is_port_available(port: int) -> bool: - """Check if a port is available for binding""" - import socket - # Create a socket and try to bind to the port - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(('127.0.0.1', port)) - return True - -def _run_job_process(job_args: Tuple, result_queue: Queue): - """Run a single job in a separate process with isolated GPU environment""" - (job_id, config, combined_dataset_name, output_dir, timestamp, model_path, - job_nums, start_idx, end_idx, tp_size, backend, data_range, gpu_devices, problems_data, field_mapping, unified_code_solutions_file) = job_args - - # initialize logger in generated process - lvl_name = (config.get("_logger_level") or "WARNING").upper() - hf_level = getattr(hf_logging, lvl_name, hf_logging.WARNING) - std_level = getattr(pylog, lvl_name, pylog.WARNING) - hf_logging.set_verbosity(hf_level) - hf_logging.enable_default_handler() - hf_logging.enable_propagation() - pylog.basicConfig( - level=std_level, - format="%(asctime)s %(levelname)s %(name)s: %(message)s" - ) - - # Set GPU environment for this process - this is isolated per process - gpu_str = ','.join(map(str, gpu_devices)) - - # Set unique NCCL port for each job to avoid conflicts - # Use a base port and add job_id to ensure uniqueness - base_port = 29555 - max_retries = 100 - unique_port = None - - # Try to find an available port - for retry in range(max_retries): - port_candidate = base_port + job_id * max_retries + retry - if port_candidate > 65535: - port_candidate = 30514 + ((job_id + retry) % 100) - - if _is_port_available(port_candidate): - unique_port = port_candidate - break - - if unique_port is None: - raise RuntimeError(f"Job {job_id}: Could not find an available port after {max_retries} attempts") - - os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str - os.environ['MASTER_PORT'] = str(unique_port) - os.environ['SGLANG_NCCL_PORT'] = str(unique_port) - os.environ['MASTER_ADDR'] = '127.0.0.1' - - # Set random seed for this process - import torch - from tah.model.utils import set_all_seeds - seed = config.get("_random_seed", 420) - set_all_seeds(seed) - # Force CUDA to reinitialize in this process - if torch.cuda.is_available(): - torch.cuda.init() - - try: - print(f"\nJob {job_id}: Starting with GPUs {gpu_devices} (CUDA_VISIBLE_DEVICES={gpu_str})") - print(f"Job {job_id}: Using NCCL port {unique_port}") - print(f"Job {job_id}: Processing indices {start_idx} to {end_idx-1}") - print(f"Job {job_id}: Available CUDA devices in process: {torch.cuda.device_count()}") - - # Run the actual job - run_single_job( - config=config, - combined_dataset_name=combined_dataset_name, - output_dir=output_dir, - timestamp=timestamp, - model_path=model_path, - job_id=job_id, - job_nums=job_nums, - start_idx=start_idx, - end_idx=end_idx, - tp_size=tp_size, - backend=backend, - data_range=data_range, - problems_data=problems_data, - field_mapping=field_mapping, - unified_code_solutions_file=unified_code_solutions_file - ) - - print(f"Job {job_id}: Completed successfully") - result_queue.put((job_id, True, f"Job {job_id} completed successfully")) - - except Exception as e: - import traceback - error_msg = f"Job {job_id} failed: {str(e)}\n{traceback.format_exc()}" - print(f"\nError in {error_msg}") - result_queue.put((job_id, False, error_msg)) - finally: - # Clean up environment variables - if 'MASTER_PORT' in os.environ: - del os.environ['MASTER_PORT'] - if 'MASTER_ADDR' in os.environ: - del os.environ['MASTER_ADDR'] - if 'SGLANG_NCCL_PORT' in os.environ: - del os.environ['SGLANG_NCCL_PORT'] - if 'NCCL_SOCKET_IFNAME' in os.environ: - del os.environ['NCCL_SOCKET_IFNAME'] - -def allocate_gpus_and_run_jobs(args): - """Allocate GPUs to jobs and run them using multiprocessing""" - # Set multiprocessing start method to 'spawn' for CUDA compatibility - mp.set_start_method('spawn', force=True) - # Create a unified timestamp for this run - timestamp = time.strftime("%Y%m%d_%H%M%S") - - # Calculate GPU allocation - gpus_per_job = args.tp_size_per_job - - print(f"Running {args.job_nums} jobs with {gpus_per_job} GPUs per job") - - # Get the current CUDA_VISIBLE_DEVICES setting - current_cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '') - if current_cuda_devices: - available_gpus = [int(x.strip()) for x in current_cuda_devices.split(',') if x.strip()] - else: - available_gpus = [0,1,2,3,4,5,6,7] - - print(f"Available GPU devices: {available_gpus}") - - # Load dataset to get total problem count and apply optional data_range / data_ids - dataset_names = [name.strip() for name in args.dataset_name.split(',')] - dataset, field_mapping = load_datasets_with_config(dataset_names) - combined_dataset_name = "_".join(dataset_names) - - total_problems = len(dataset) - - # Unified problem selection: get selected problem indices regardless of selection method - if getattr(args, 'data_ids', None): - # Select by specific problem IDs - valid_ids = set(str(item.get('id')) for item in dataset) - raw_ids = [s.strip() for s in str(args.data_ids).split(',') if s.strip() != ''] - seen = set() - selected_problems = [] - for pid in raw_ids: - if pid in valid_ids and pid not in seen: - seen.add(pid) - selected_problems.append(pid) - if not selected_problems: - raise ValueError("--data_ids did not match any problem IDs. Use standardized IDs like _.") - else: - # Select by data range - convert to list of indices - range_start, range_end = parse_data_range(args.data_range, total_problems) - selected_problems = list(range(range_start, range_end)) - - problems_per_job = math.ceil(len(selected_problems) / args.job_nums) - - # Load configuration once for all jobs - eval_config = load_config(args.eval_config) - eval_config["_logger_level"] = args.logger_level - # Pass random seed from CLI args into config so worker processes can read it - eval_config["_random_seed"] = getattr(args, "random_seed", 420) - - # Save the yaml configuration file to output directory - # Create output directory structure first - task_suffix = "" - if args.data_range and not getattr(args, 'data_ids', None): - range_start, range_end = parse_data_range(args.data_range, total_problems) - task_suffix = f"TASK_{range_start}_{range_end}" - - default_output_dir = Path(args.model_path) / "eval_results" - if args.output_dir is None: - args.output_dir = default_output_dir - - combined_output_dir = Path(args.output_dir) / (combined_dataset_name + "_" + args.backend) / timestamp - if task_suffix: - combined_output_dir = combined_output_dir / task_suffix - combined_output_dir.mkdir(parents=True, exist_ok=True) - - # Save the original eval config file to output directory - import shutil - config_filename = Path(args.eval_config).name - saved_config_path = combined_output_dir / config_filename - shutil.copy2(args.eval_config, saved_config_path) - print(f"Saved evaluation config to: {saved_config_path}") - - # For code datasets, prepare unified code_solutions.jsonl file - answer_type = field_mapping.get('answer_type', 'boxed') - is_code_dataset = answer_type in ['livecodebench', 'humaneval', 'mbpp'] - unified_code_solutions_file = None - - if is_code_dataset: - unified_code_solutions_file = combined_output_dir / "code_solutions.jsonl" - # Create empty file - with open(unified_code_solutions_file, 'w', encoding='utf-8') as f: - pass - print(f"Created unified code solutions file: {unified_code_solutions_file}") - - # Prepare job arguments by splitting selected problems across jobs - job_args_list = [] - for job_id in range(args.job_nums): - # Split selected problems across jobs using interleaving - job_problems_indices = selected_problems[job_id::args.job_nums] - - if not job_problems_indices: - continue - - gpu_start = job_id * gpus_per_job - job_gpus = available_gpus[gpu_start:gpu_start + gpus_per_job] - - # Extract actual problem data for this job - if getattr(args, 'data_ids', None): - # For data_ids mode, filter by problem IDs - id_to_item = {str(item.get('id')): item for item in dataset} - job_problems_data = [id_to_item[str(pid)] for pid in job_problems_indices if str(pid) in id_to_item] - start_idx = 0 - end_idx = len(job_problems_indices) - else: - # For data_range mode, use indices directly and attach original index for bookkeeping - job_problems_data = [] - for i in job_problems_indices: - item = dict(dataset[i]) - item['_original_index'] = i - job_problems_data.append(item) - start_idx = 0 - end_idx = len(job_problems_indices) - - # Create unified job_args with pre-processed data - job_args = ( - job_id, eval_config, combined_dataset_name, args.output_dir, timestamp, args.model_path, - args.job_nums, start_idx, end_idx, gpus_per_job, args.backend, args.data_range, - job_gpus, job_problems_data, field_mapping, unified_code_solutions_file - ) - - job_args_list.append(job_args) - - print(f"\nPrepared {len(job_args_list)} jobs for execution") - - # Create result queue for inter-process communication - result_queue = mp.Queue() - - # Execute jobs using multiprocessing with limited concurrency - completed_jobs = 0 - failed_jobs = 0 - active_processes = [] - - # Start processes in batches based on max_concurrent_jobs - job_idx = 0 - while job_idx < len(job_args_list) or active_processes: - # Start new processes if we have capacity - while job_idx < len(job_args_list): - job_args = job_args_list[job_idx] - p = Process(target=_run_job_process, args=(job_args, result_queue), name=f"sgl-{job_idx}") - p.start() - active_processes.append((p, job_args[0])) # Store process with job_id - print(f"Started job {job_args[0]}") - job_idx += 1 - - # Check for completed processes - still_active = [] - for p, job_id in active_processes: - if p.is_alive(): - still_active.append((p, job_id)) - else: - p.join(timeout=1) # Ensure process is cleaned up - if p.exitcode != 0 and p.exitcode is not None: - print(f"Process for job {job_id} exited with code {p.exitcode}") - - active_processes = still_active - - # Process results from queue (non-blocking) - while not result_queue.empty(): - job_id_result, success, message = result_queue.get_nowait() - if success: - completed_jobs += 1 - print(f"\n✓ Job {job_id_result} finished successfully") - else: - failed_jobs += 1 - print(f"\n✗ Job {job_id_result} failed: {message}") - - # Small sleep to prevent busy waiting - if active_processes: - time.sleep(0.1) - - # Final check for any remaining results - while not result_queue.empty(): - job_id_result, success, message = result_queue.get_nowait() - if success: - completed_jobs += 1 - print(f"\n✓ Job {job_id_result} finished successfully") - else: - failed_jobs += 1 - print(f"\n✗ Job {job_id_result} failed: {message}") - - print(f"\nAll jobs completed!") - print(f"Successful jobs: {completed_jobs}") - print(f"Failed jobs: {failed_jobs}") - - if failed_jobs > 0: - print(f"Warning: {failed_jobs} jobs failed. Results may be incomplete.") - - # Combine results - print("\nCombining results from all jobs...") - combined_output_dir = Path(args.output_dir) / (combined_dataset_name + "_" + args.backend) / timestamp - if args.data_range and not getattr(args, 'data_ids', None): - range_start, range_end = parse_data_range(args.data_range, total_problems) - combined_output_dir = combined_output_dir / f"TASK_{range_start}_{range_end}" - combine_job_results(combined_output_dir, len(job_args_list), args.del_job_dir) - - # For code datasets, run batch evaluation using the unified file - if is_code_dataset and unified_code_solutions_file and unified_code_solutions_file.exists(): - print(f"\n{'='*60}") - print(f"Starting code evaluation for {combined_dataset_name}...") - print(f"Solutions file: {unified_code_solutions_file}") - print(f"Total lines: {sum(1 for _ in open(unified_code_solutions_file))}") - print(f"{'='*60}\n") - - # Import codeeval - from tah.evaluate.codeeval import evaluate as code_evaluate - - # Determine dataset name for evalplus (humaneval or mbpp) - answer_type = field_mapping.get('answer_type', 'boxed') - evalplus_dataset = answer_type if answer_type in ['humaneval', 'mbpp'] else 'humaneval' - - # Call codeeval.evaluate - code_evaluate( - dataset=evalplus_dataset, - samples=str(unified_code_solutions_file), - ) - - print(f"\n{'='*60}") - print(f"Code evaluation completed!") - print(f"Results saved to: {str(unified_code_solutions_file).replace('.jsonl', '.eval_results.json')}") - print(f"{'='*60}\n") - - -def load_config(config_path: str) -> Dict: - """Load YAML configuration file""" - with open(config_path, 'r') as f: - return yaml.safe_load(f) - -# Utility to parse data_range argument - -def parse_data_range(data_range_list, total_problems: int = None) -> Tuple[int, int]: - """Parse data_range list and return (start_idx, end_idx). - - data_range_list format examples: - - [200] -> start=0, end=200 - - [100, 200] -> start=100, end=200 - - If total_problems is provided, end_idx will be clipped to this value. - The returned end_idx is exclusive (i.e. slice compatible). - """ - if not data_range_list: - return 0, total_problems if total_problems is not None else 0 - - # Handle single value (treated as end index) - if len(data_range_list) == 1: - start_idx, end_idx = 0, data_range_list[0] - elif len(data_range_list) == 2: - start_idx, end_idx = data_range_list[0], data_range_list[1] - else: - raise ValueError(f"Invalid data_range: expected 1 or 2 values, got {len(data_range_list)}") - - if total_problems is not None: - end_idx = min(end_idx, total_problems) - - if start_idx < 0 or end_idx <= start_idx: - raise ValueError(f"Invalid data_range: start={start_idx}, end={end_idx}") - - return start_idx, end_idx \ No newline at end of file +from tah.evaluate.datasets import load_combined_dataset as load_datasets_with_config # noqa: F401 +from tah.evaluate.jobs import ( # noqa: F401 + allocate_gpus_and_run_jobs, + combine_job_results, + parse_data_range, + run_single_job, +) diff --git a/tah/evaluate/jobs.py b/tah/evaluate/jobs.py new file mode 100644 index 0000000..4ef0724 --- /dev/null +++ b/tah/evaluate/jobs.py @@ -0,0 +1,624 @@ +"""Multi-job evaluation runner. + +Three layers of orchestration: + +* :func:`run_single_job` — load a backend, run inference over an assigned + slice of problems, score each output, write per-problem CSV/JSON files. +* :func:`_run_job_process` — process wrapper that pins ``CUDA_VISIBLE_DEVICES`` + and a fresh NCCL port, then calls ``run_single_job``. +* :func:`allocate_gpus_and_run_jobs` — top-level entry point. Splits the + selected problems across jobs, fans out one process per job, joins, then + combines per-job outputs. + +:func:`combine_job_results` aggregates the per-job CSV/JSON files into a +single set of files at the run-level directory. + +:func:`parse_data_range` parses the ``--data_range`` CLI argument format. +""" +from __future__ import annotations + +import csv +import fcntl +import json +import logging as pylog +import math +import os +import shutil +import socket +import time +import traceback +from multiprocessing import Process, Queue +import multiprocessing as mp +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import yaml +from tqdm import tqdm +from transformers.utils import logging as hf_logging + +from tah.evaluate import codeeval, matheval +from tah.evaluate.backends import cleanup, setup_backend +from tah.evaluate.datasets import load_combined_dataset + +# Spawn is required for CUDA-capable subprocess workers. +mp.set_start_method("spawn", force=True) + + +_CSV_FIELDS = ( + "problem_id", "sample_idx", "correct_answer", "predicted_answer", + "has_answer", "is_correct", "input_tokens", "output_tokens", "processing_time", +) +_CODE_ANSWER_TYPES = frozenset({"livecodebench", "humaneval", "mbpp"}) + + +# ──────────────────────────────────────────────────────────────────────────── +# Per-problem prompt building + scoring +# ──────────────────────────────────────────────────────────────────────────── + + +def _build_problem(item: dict, idx: int, field_mapping: Dict, detail_dir: Path) -> dict: + """Pull standard fields out of an item; also create the per-problem dir.""" + pid = str(item.get(field_mapping["id_field"]) or f"problem_{idx}") + text = str(item.get(field_mapping["question_field"], "")).strip() + template = field_mapping.get("prompt_template") + if template and "{question}" in template: + text = template.replace("{question}", text) + answer = str(item.get(field_mapping["answer_field"], "")).strip() + problem_dir = detail_dir / pid + problem_dir.mkdir(parents=True, exist_ok=True) + return { + "problem_id": pid, + "original_problem_id": item.get("_original_id", pid), + "problem_text": text, + "correct_answer": answer, + "problem_dir": problem_dir, + "entry_point": item.get("entry_point"), + } + + +def _make_prompt(text: str, tokenizer, *, is_code: bool) -> str: + if is_code: + return codeeval.make_raw_chat_prompt_for_code_evaluation( + task_prompt=text, reasoning=False, tokenizer=tokenizer, + ) + return tokenizer.apply_chat_template( + [{"role": "user", "content": text}], tokenize=False, add_generation_prompt=True, + ) + + +def _score_one(output_text: str, correct: str, dataset_name: str, *, is_code: bool) -> Tuple[str, bool, bool]: + """Return ``(predicted_answer, has_answer, is_correct)``. + + Code datasets defer scoring until after generation finishes (run via + evalplus on the unified solutions file), so we just stamp placeholders. + """ + if is_code: + return "pending_code_eval", False, False + is_correct, predicted = matheval.evaluator_map[dataset_name].rule_judge(output_text, correct) + if predicted == "No extracted answer": + return "", False, bool(is_correct) + return predicted, True, bool(is_correct) + + +def _append_code_solution(path: Path, task_id: str, code: str) -> None: + """Append a {task_id, solution} entry to the unified jsonl with file lock.""" + with open(path, "a", encoding="utf-8") as f: + fcntl.flock(f.fileno(), fcntl.LOCK_EX) + try: + f.write(json.dumps({"task_id": task_id, "solution": str(code)}, ensure_ascii=False) + "\n") + finally: + fcntl.flock(f.fileno(), fcntl.LOCK_UN) + + +# ──────────────────────────────────────────────────────────────────────────── +# Per-job runner +# ──────────────────────────────────────────────────────────────────────────── + + +def _job_output_dir(output_dir: str, combined_dataset_name: str, backend: str, + timestamp: str, job_id: int, data_range) -> Path: + """``/_/[/TASK_a_b]/job_/``.""" + out = Path(output_dir) / f"{combined_dataset_name}_{backend}" / timestamp + if data_range: + a, b = parse_data_range(data_range) + out = out / f"TASK_{a}_{b}" + return out / f"job_{job_id}" + + +def _build_prompts( + problems: List[dict], tokenizer, repeat_size: int, is_code: bool, +) -> Tuple[List[str], List[Tuple[int, int]]]: + """Render each problem ``repeat_size`` times and return + ``(prompts, prompt_to_problem)`` where ``prompt_to_problem[i] = (problem_idx, sample_idx)``.""" + prompts: List[str] = [] + prompt_to_problem: List[Tuple[int, int]] = [] + for pi, p in enumerate(problems): + rendered = _make_prompt(p["problem_text"], tokenizer, is_code=is_code) + for s in range(repeat_size): + prompts.append(rendered) + prompt_to_problem.append((pi, s)) + return prompts, prompt_to_problem + + +def _process_batch( + batch_outputs: List[Tuple[str, float]], + batch_offset: int, + prompts: List[str], + prompt_to_problem: List[Tuple[int, int]], + problems: List[dict], + tokenizer, + *, + combined_dataset_name: str, + is_code: bool, + unified_code_solutions_file: Optional[Path], + writer, + all_results: List[dict], +) -> None: + """For each (text, time) in ``batch_outputs``, score it, write the + per-sample json under the problem's dir, append a row to the open CSV + writer + the in-memory ``all_results`` list.""" + for j, (output_text, proc_time) in enumerate(batch_outputs): + pi, sample_idx = prompt_to_problem[batch_offset + j] + p = problems[pi] + pred, has_ans, ok = _score_one(output_text, p["correct_answer"], combined_dataset_name, is_code=is_code) + + detail = { + "problem": p["problem_text"], + "output": output_text, + "correct_answer": p["correct_answer"], + "predicted_answer": pred, + "is_correct": ok, + } + if is_code: + extracted = codeeval.sanitize(output_text, p["entry_point"]) + detail["extracted_code"] = extracted + detail["entry_point"] = p["entry_point"] + if unified_code_solutions_file is not None: + _append_code_solution(unified_code_solutions_file, p["original_problem_id"], extracted) + with open(p["problem_dir"] / f"sample_{sample_idx}.json", "w", encoding="utf-8") as fd: + json.dump(detail, fd, ensure_ascii=False, indent=2) + + row = { + "problem_id": p["problem_id"], + "sample_idx": sample_idx, + "correct_answer": p["correct_answer"], + "predicted_answer": pred, + "has_answer": has_ans, + "is_correct": ok, + "input_tokens": len(tokenizer.encode(p["problem_text"])), + "output_tokens": len(tokenizer.encode(output_text)), + "processing_time": proc_time, + } + writer.writerow(row) + all_results.append(row) + + +def run_single_job( + *, config: Dict, combined_dataset_name: str, output_dir: str, timestamp: str, + model_path: str, job_id: int, job_nums: int, start_idx: int, end_idx: int, + tp_size: int, backend: str, data_range, problems_data, field_mapping: Dict, + unified_code_solutions_file: Optional[Path], +) -> None: + """Run inference for one job (one process) and write per-problem files.""" + from transformers import AutoTokenizer + + out = _job_output_dir(output_dir, combined_dataset_name, backend, timestamp, job_id, data_range) + detail_dir = out / "details" + detail_dir.mkdir(parents=True, exist_ok=True) + print(f"Job {job_id+1}/{job_nums}: {len(problems_data)} problems, backend={backend}, datasets={combined_dataset_name}") + + tokenizer = AutoTokenizer.from_pretrained(model_path) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + model, infer = setup_backend(backend, config, model_path, tokenizer, tp_size=tp_size) + + is_code = field_mapping.get("answer_type", "boxed") in _CODE_ANSWER_TYPES + problems = [ + _build_problem(item, item.get("_original_index", start_idx + i), field_mapping, detail_dir) + for i, item in enumerate(problems_data) + ] + prompts, prompt_to_problem = _build_prompts(problems, tokenizer, config["repeat_size"], is_code) + print(f"Processing batch_size={config['batch_size']} repeat_size={config['repeat_size']} → {len(prompts)} prompts") + + results_file = out / "detailed_results.csv" + with open(results_file, "w", newline="", encoding="utf-8") as f: + csv.DictWriter(f, fieldnames=_CSV_FIELDS).writeheader() + + all_results: List[dict] = [] + bs = config["batch_size"] + for bi in tqdm(range(math.ceil(len(prompts) / bs)), + desc=f"Job {job_id} batches", position=job_id, leave=True): + s, e = bi * bs, min((bi + 1) * bs, len(prompts)) + batch_outputs = infer(prompts[s:e]) + with open(results_file, "a", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=_CSV_FIELDS) + _process_batch( + batch_outputs, batch_offset=s, + prompts=prompts, prompt_to_problem=prompt_to_problem, + problems=problems, tokenizer=tokenizer, + combined_dataset_name=combined_dataset_name, is_code=is_code, + unified_code_solutions_file=unified_code_solutions_file, + writer=writer, all_results=all_results, + ) + + _write_job_stats(out, all_results) + print(f"Job {job_id} done; accuracy={_overall_accuracy(all_results):.4f}") + cleanup(model, backend) + + +def _overall_accuracy(rows: List[dict]) -> float: + return sum(1 for r in rows if r["is_correct"]) / len(rows) if rows else 0.0 + + +def _write_job_stats(out: Path, all_results: List[dict]) -> None: + """Write evaluation_stats.csv with per-problem stats + a totals row.""" + by_problem: Dict[str, List[dict]] = {} + for r in all_results: + by_problem.setdefault(r["problem_id"], []).append(r) + + rows = [] + for pid, rs in by_problem.items(): + n = len(rs) + c = sum(1 for r in rs if r["is_correct"]) + avg_out = sum(r["output_tokens"] for r in rs) / n if n else 0.0 + rows.append((pid, c / n if n else 0.0, c, n, avg_out)) + + total_n = len(all_results) + total_c = sum(1 for r in all_results if r["is_correct"]) + total_avg_out = sum(r["output_tokens"] for r in all_results) / total_n if total_n else 0.0 + + with open(out / "evaluation_stats.csv", "w", newline="", encoding="utf-8") as f: + w = csv.writer(f) + w.writerow(["problem_id", "accuracy", "correct_count", "total_samples", "avg_output_length"]) + for pid, acc, c, n, avg in rows: + w.writerow([pid, f"{acc:.3f}", c, n, f"{avg:.2f}"]) + w.writerow([]) + w.writerow(["Total Accuracy", f"{(total_c/total_n if total_n else 0):.3f}", total_c, total_n, f"{total_avg_out:.2f}"]) + + +# ──────────────────────────────────────────────────────────────────────────── +# Process wrapper (one per job) +# ──────────────────────────────────────────────────────────────────────────── + + +def _free_port(start: int = 29555, max_tries: int = 100) -> int: + """Find a port we can bind to from ``start``, wrapping into the 30k range if needed.""" + for i in range(max_tries): + p = start + i + if p > 65535: + p = 30514 + (i % 100) + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("127.0.0.1", p)) + return p + except OSError: + continue + raise RuntimeError(f"no free port after {max_tries} tries from {start}") + + +def _setup_logging(level_name: str) -> None: + level = level_name.upper() + hf_logging.set_verbosity(getattr(hf_logging, level, hf_logging.WARNING)) + hf_logging.enable_default_handler() + hf_logging.enable_propagation() + pylog.basicConfig( + level=getattr(pylog, level, pylog.WARNING), + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + +def _run_job_process(job_args: Tuple, result_queue: Queue) -> None: + """One job per Process; pins CUDA_VISIBLE_DEVICES + NCCL port for the worker.""" + (job_id, config, combined_dataset_name, output_dir, timestamp, model_path, + job_nums, start_idx, end_idx, tp_size, backend, data_range, gpu_devices, + problems_data, field_mapping, unified_code_solutions_file) = job_args + + _setup_logging(config.get("_logger_level") or "WARNING") + + port = _free_port(start=29555 + job_id * 100) + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_devices)) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["SGLANG_NCCL_PORT"] = str(port) + + import torch + from tah.model.utils import set_all_seeds + set_all_seeds(config.get("_random_seed", 420)) + if torch.cuda.is_available(): + torch.cuda.init() + + try: + print(f"\nJob {job_id}: GPUs={gpu_devices}, NCCL port={port}, indices {start_idx}..{end_idx-1}") + run_single_job( + config=config, combined_dataset_name=combined_dataset_name, + output_dir=output_dir, timestamp=timestamp, model_path=model_path, + job_id=job_id, job_nums=job_nums, start_idx=start_idx, end_idx=end_idx, + tp_size=tp_size, backend=backend, data_range=data_range, + problems_data=problems_data, field_mapping=field_mapping, + unified_code_solutions_file=unified_code_solutions_file, + ) + result_queue.put((job_id, True, "ok")) + except Exception as e: + result_queue.put((job_id, False, f"Job {job_id} failed: {e}\n{traceback.format_exc()}")) + finally: + for k in ("MASTER_PORT", "MASTER_ADDR", "SGLANG_NCCL_PORT", "NCCL_SOCKET_IFNAME"): + os.environ.pop(k, None) + + +# ──────────────────────────────────────────────────────────────────────────── +# Top-level orchestrator +# ──────────────────────────────────────────────────────────────────────────── + + +def parse_data_range(data_range_list, total_problems: Optional[int] = None) -> Tuple[int, int]: + """Parse ``--data_range`` argument: ``[end]`` or ``[start, end]`` (end exclusive).""" + if not data_range_list: + return 0, total_problems if total_problems is not None else 0 + if len(data_range_list) == 1: + s, e = 0, data_range_list[0] + elif len(data_range_list) == 2: + s, e = data_range_list[0], data_range_list[1] + else: + raise ValueError(f"data_range expects 1 or 2 values, got {len(data_range_list)}") + if total_problems is not None: + e = min(e, total_problems) + if s < 0 or e <= s: + raise ValueError(f"invalid data_range: start={s}, end={e}") + return s, e + + +def _select_problems(args, dataset: List[dict]) -> List: + """Resolve --data_ids or --data_range into a list of items / indices to evaluate.""" + if getattr(args, "data_ids", None): + valid_ids = {str(item.get("id")) for item in dataset} + seen, out = set(), [] + for pid in (s.strip() for s in str(args.data_ids).split(",") if s.strip()): + if pid in valid_ids and pid not in seen: + seen.add(pid) + out.append(pid) + if not out: + raise ValueError("--data_ids matched no problem IDs (use _)") + return out + a, b = parse_data_range(args.data_range, len(dataset)) + return list(range(a, b)) + + +def _build_job_args(args, dataset, selected, eval_config, combined_dataset_name, timestamp, + available_gpus, gpus_per_job, field_mapping, unified_code_solutions_file): + """Materialise per-job tuples for _run_job_process. Splits ``selected`` interleaved across jobs.""" + by_id = {str(item.get("id")): item for item in dataset} if getattr(args, "data_ids", None) else None + job_args_list = [] + for job_id in range(args.job_nums): + slice_ = selected[job_id::args.job_nums] + if not slice_: + continue + gpu_start = job_id * gpus_per_job + job_gpus = available_gpus[gpu_start:gpu_start + gpus_per_job] + if by_id is not None: + job_problems = [by_id[str(p)] for p in slice_ if str(p) in by_id] + else: + job_problems = [] + for i in slice_: + item = dict(dataset[i]) + item["_original_index"] = i + job_problems.append(item) + job_args_list.append(( + job_id, eval_config, combined_dataset_name, args.output_dir, timestamp, + args.model_path, args.job_nums, 0, len(slice_), gpus_per_job, args.backend, + args.data_range, job_gpus, job_problems, field_mapping, unified_code_solutions_file, + )) + return job_args_list + + +def _run_jobs(job_args_list) -> Tuple[int, int]: + """Spawn one Process per job, drain the result queue, return ``(ok, failed)``.""" + queue: Queue = mp.Queue() + completed = failed = 0 + + def _drain_queue() -> None: + """Process any pending job results from the queue.""" + nonlocal completed, failed + while not queue.empty(): + jid, ok, msg = queue.get_nowait() + if ok: + completed += 1 + print(f"\n✓ Job {jid} ok") + else: + failed += 1 + print(f"\n✗ Job {jid} failed: {msg}") + + # Start every job up front; the workers run in parallel and we just wait. + processes: List[Tuple[Process, int]] = [] + for idx, ja in enumerate(job_args_list): + p = Process(target=_run_job_process, args=(ja, queue), name=f"sgl-{idx}") + p.start() + processes.append((p, ja[0])) + print(f"Started job {ja[0]}") + + # Poll until every worker is dead; reap dead processes and drain the queue. + while processes: + still_active = [] + for p, jid in processes: + if p.is_alive(): + still_active.append((p, jid)) + else: + p.join(timeout=1) + processes = still_active + _drain_queue() + if processes: + time.sleep(0.1) + + _drain_queue() # final pass for any results that arrived after the last check + return completed, failed + + +def allocate_gpus_and_run_jobs(args) -> None: + """Top-level: split selected problems across jobs, run in parallel, then combine.""" + timestamp = time.strftime("%Y%m%d_%H%M%S") + print(f"Running {args.job_nums} jobs with {args.tp_size_per_job} GPU(s) per job") + + # GPU pool comes from CUDA_VISIBLE_DEVICES (or 0..7 if unset). + visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + available_gpus = [int(x.strip()) for x in visible.split(",") if x.strip()] or list(range(8)) + print(f"Available GPU devices: {available_gpus}") + + dataset_names = [n.strip() for n in args.dataset_name.split(",")] + dataset, field_mapping = load_combined_dataset(dataset_names) + combined_dataset_name = "_".join(dataset_names) + + selected = _select_problems(args, dataset) + + # Load + tag YAML config (worker reads back the tags via config["_..."]). + with open(args.eval_config, "r") as f: + eval_config = yaml.safe_load(f) + eval_config["_logger_level"] = args.logger_level + eval_config["_random_seed"] = getattr(args, "random_seed", 420) + + # Output dir: /_/[/TASK_a_b]/ + if args.output_dir is None: + args.output_dir = Path(args.model_path) / "eval_results" + suffix = "" + if args.data_range and not getattr(args, "data_ids", None): + a, b = parse_data_range(args.data_range, len(dataset)) + suffix = f"TASK_{a}_{b}" + combined_dir = Path(args.output_dir) / f"{combined_dataset_name}_{args.backend}" / timestamp + if suffix: + combined_dir = combined_dir / suffix + combined_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(args.eval_config, combined_dir / Path(args.eval_config).name) + print(f"Saved evaluation config to: {combined_dir / Path(args.eval_config).name}") + + is_code = field_mapping.get("answer_type", "boxed") in _CODE_ANSWER_TYPES + code_solutions_file: Optional[Path] = None + if is_code: + code_solutions_file = combined_dir / "code_solutions.jsonl" + code_solutions_file.write_text("", encoding="utf-8") + print(f"Code solutions file: {code_solutions_file}") + + job_args_list = _build_job_args( + args, dataset, selected, eval_config, combined_dataset_name, timestamp, + available_gpus, args.tp_size_per_job, field_mapping, code_solutions_file, + ) + print(f"\nPrepared {len(job_args_list)} jobs for execution") + + completed, failed = _run_jobs(job_args_list) + print(f"\nDone. ok={completed} failed={failed}") + if failed > 0: + print(f"WARNING: {failed} job(s) failed; results may be incomplete.") + + print("\nCombining results from all jobs...") + combine_job_results(combined_dir, len(job_args_list), args.del_job_dir) + + if is_code and code_solutions_file is not None and code_solutions_file.exists(): + from tah.evaluate.codeeval import evaluate as code_evaluate + evalplus_dataset = field_mapping["answer_type"] if field_mapping["answer_type"] in ("humaneval", "mbpp") else "humaneval" + n_lines = sum(1 for _ in open(code_solutions_file)) + print(f"\nCode eval: {combined_dataset_name} on {n_lines} solutions ({code_solutions_file})") + code_evaluate(dataset=evalplus_dataset, samples=str(code_solutions_file)) + print(f"Code eval results: {str(code_solutions_file).replace('.jsonl', '.eval_results.json')}") + + +# ──────────────────────────────────────────────────────────────────────────── +# Combining per-job outputs into run-level files +# ──────────────────────────────────────────────────────────────────────────── + + +def combine_job_results(output_dir: Path, job_nums: int, del_job_dir: bool = False) -> None: + """Aggregate per-job CSV/JSON outputs into combined files under ``output_dir``.""" + all_results: List[dict] = [] + problem_stats: Dict[str, dict] = {} + samples_jsonl_path = output_dir / "samples.jsonl" + samples_jsonl_path.write_text("", encoding="utf-8") # truncate + + for job_id in range(job_nums): + job_dir = output_dir / f"job_{job_id}" + + results_file = job_dir / "detailed_results.csv" + if results_file.exists(): + with open(results_file, "r", encoding="utf-8") as f: + for row in csv.DictReader(f): + row["is_correct"] = row["is_correct"] == "True" + row["has_answer"] = row["has_answer"] == "True" + row["sample_idx"] = int(row["sample_idx"]) + row["input_tokens"] = int(row["input_tokens"]) + row["output_tokens"] = int(row["output_tokens"]) + row["processing_time"] = float(row["processing_time"]) + all_results.append(row) + + stats_file = job_dir / "evaluation_stats.csv" + if stats_file.exists(): + with open(stats_file, "r", encoding="utf-8") as f: + reader = csv.reader(f) + next(reader, None) # header + for row in reader: + if not row or row[0] in ("", "Total Accuracy"): + continue + problem_stats[row[0]] = { + "accuracy": f"{float(row[1]):.3f}", + "correct_count": int(row[2]), + "total_samples": int(row[3]), + "avg_output_length": float(row[4]) if len(row) >= 5 else 0.0, + } + + details_dir = job_dir / "details" + if not details_dir.exists(): + continue + for problem_dir in details_dir.iterdir(): + if not problem_dir.is_dir(): + continue + for sample_json in sorted(problem_dir.glob("sample_*.json")): + with open(sample_json, "r", encoding="utf-8") as f_json: + obj = json.load(f_json) + try: + sample_idx = int(sample_json.stem.split("_")[-1]) + except Exception: + sample_idx = -1 + obj["id"] = problem_dir.name + obj["sample"] = sample_idx + with open(samples_jsonl_path, "a", encoding="utf-8") as out_f: + out_f.write(json.dumps(obj, ensure_ascii=False) + "\n") + + # Recompute per-problem avg_output_length using all_results (more authoritative). + by_pid: Dict[str, List[int]] = {} + for r in all_results: + by_pid.setdefault(r["problem_id"], []).append(r["output_tokens"]) + for pid, lens in by_pid.items(): + if pid in problem_stats: + problem_stats[pid]["avg_output_length"] = sum(lens) / len(lens) + + total_n = len(all_results) + total_c = sum(1 for r in all_results if r["is_correct"]) + total_avg_out = sum(r["output_tokens"] for r in all_results) / total_n if total_n else 0.0 + + stats_file = output_dir / "evaluation_stats.csv" + with open(stats_file, "w", newline="", encoding="utf-8") as f: + w = csv.writer(f) + w.writerow(["problem_id", "accuracy", "correct_count", "total_samples", "avg_output_length"]) + for pid, st in sorted(problem_stats.items()): + w.writerow([pid, st["accuracy"], st["correct_count"], st["total_samples"], f"{st['avg_output_length']:.2f}"]) + w.writerow([]) + w.writerow([ + "Total Accuracy", f"{(total_c/total_n if total_n else 0):.3f}", + total_c, total_n, f"{total_avg_out:.2f}", + ]) + + results_file = output_dir / "detailed_results.csv" + with open(results_file, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=list(_CSV_FIELDS)) + w.writeheader() + for r in sorted(all_results, key=lambda x: (x["problem_id"], x["sample_idx"])): + w.writerow({k: r[k] for k in _CSV_FIELDS}) + + print(f"\nCombined results from {job_nums} jobs") + print(f"Overall accuracy: {(total_c/total_n if total_n else 0):.4f}") + print(f"Total problems: {len(problem_stats)}") + print(f"Combined statistics → {stats_file}") + print(f"Combined detailed results → {results_file}") + + if del_job_dir: + for job_id in range(job_nums): + job_dir = output_dir / f"job_{job_id}" + if job_dir.exists(): + shutil.rmtree(job_dir, ignore_errors=True) + print("Removed per-job directories after combining results") diff --git a/tah/evaluate/matheval.py b/tah/evaluate/matheval.py index d6c9b9a..f5b6a31 100644 --- a/tah/evaluate/matheval.py +++ b/tah/evaluate/matheval.py @@ -1,357 +1,111 @@ -import re -from openai import OpenAI -from math_verify import parse, verify, LatexExtractionConfig, ExprExtractionConfig, StringExtractionConfig -from latex2sympy2_extended import NormalizationConfig -import os - -class MathEvaluator: - - def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool: - raise NotImplementedError - - def extract_after_think(self, text: str, truncate_length: int = 1000, finish_generation: bool = True) -> str: - pattern = r"(.*)" - match = re.search(pattern, text, re.DOTALL) - return match.group(1).strip() if (match and finish_generation) else text[-truncate_length:] - - def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extracted_answer: str = "", finish_generation: bool = True) -> str: - raise NotImplementedError - - def get_llm_judge_prompt_not_finished(self, solution_str: str, ground_truth: str, extracted_answer: str = "", finish_generation: bool = True) -> str: - return f"""Please determine whether the final answer in the model-generated response was already correctly derived early in the reasoning process, and that the subsequent content consists mainly of unnecessary verification, overthinking, or repetitive reasoning. If correct is derived early, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content. -Reference answer: {ground_truth} -Model-generated response: {solution_str} -""".strip() - - def llm_judge(self, solution_str: str, ground_truth: str, extracted_answer: str = "", finish_generation: bool = True) -> bool: - global OPENAI_CLIENT, MODEL_NAME - def get_inputs(scene_description): - body = [ - {"role": "user", "content": scene_description}, - ] - return body - - def run_api(inputs): - completion = OPENAI_CLIENT.chat.completions.create( - model=MODEL_NAME, - messages=inputs - ) - return completion.choices[0].message.content.strip() - if finish_generation: - scene_description = self.get_llm_judge_prompt(solution_str, ground_truth, extracted_answer, finish_generation) - else: - scene_description = self.get_llm_judge_prompt_not_finished(solution_str, ground_truth, extracted_answer, finish_generation) - inputs = get_inputs(scene_description) - response = run_api(inputs) - - return "YES" in response - - -class AIMEEvaluator(MathEvaluator): - def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool: - # if not ground_truth.startswith("$"): - # ground_truth = f"${ground_truth}$" - gold = parse( - ground_truth, - extraction_config=[ExprExtractionConfig()], - ) - answer = parse( - solution_str, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ), - ExprExtractionConfig(), - ], - extraction_mode="first_match", - ) - if len(answer) == 0: - return False, "No extracted answer" - else: - return verify(gold, answer), str(answer) - - def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str: - solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation) - return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content. -Model-generated answer: {solution_str} -Reference answer: {ground_truth}""".strip() - - -class GSM8KEvaluator(MathEvaluator): - def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool: - # if not ground_truth.startswith("$"): - # ground_truth = f"${ground_truth}$" - gold = parse( - ground_truth, - extraction_config=[ExprExtractionConfig()], - ) - answer = parse( - solution_str, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ), - ExprExtractionConfig(), - ], - extraction_mode="first_match", - ) - if len(answer) == 0: - return False, "No extracted answer" - else: - return verify(gold, answer), str(answer) +"""Per-dataset rule-based math graders. - def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str: - solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation) - return f"""Please determine whether the final answer provided in the model-generated response with rule-based extracted answer is equivalent to the reference answer from a math question. The final answer may either be enclosed in the \\boxed{{}} or appear after the "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content. +Every benchmark in ``eval_configs/dataset_configs.json`` whose ``answer_type`` +isn't a code answer maps here through :data:`evaluator_map`. The runner calls +``evaluator_map[dataset].rule_judge(output_text, ground_truth)`` and gets back +``(is_correct: bool, predicted: str)`` (or ``(False, "No extracted answer")`` +when nothing matches). -1. The reference answer does not include percentage signs, units or time formats (e.g., am, pm), but the Model-generated answer may include them. -For example, 1 is equivalent to 1 %, 1 kg, 1 am, 1 pm, 1:00 am, 1:00 pm, etc. -Model-generated answer: 1% -Reference answer: 1 -Your output: YES +Three grading shapes cover all configured datasets — they only differ in how +ground truth and the model output are extracted before being compared via +``math_verify.verify``: -Model-generated answer: 1 kg -Reference answer: 1 -Your output: YES +* ``expr`` — AIME / GSM8K. Ground truth is a bare expression; output is + matched as latex-or-expr. +* ``latex`` — MATH500 / AMC / OlympiadBench. Ground truth is wrapped in + ``$...$`` if not already, then parsed as latex; output as latex-or-expr. +* ``string`` — GPQA / MMLU / ARC. Both sides are matched as strings. -Model-generated answer: 1:00 pm -Reference answer: 1 -Your output: YES +Public TaH also shipped LLM-judge plumbing (``llm_judge`` / ``set_client`` +/ a per-evaluator ``get_llm_judge_prompt``); the eval driver only ever +calls ``rule_judge``, so the LLM-judge surface is removed. +""" +from __future__ import annotations -2. The reference answer only includes one single number, but the Model-generated answer may include multiple numbers. -For example, 10 is equivalent to \\boxed{{(4, 6)}}, etc. -Model-generated answer: 5, 5 -Reference answer: 10 -Your output: YES +from typing import Tuple -Model-generated answer: 4, 6 -Reference answer: 10 -Your output: YES - -Model-generated answer: 86, 42 -Reference answer: 128 -Your output: YES - -Now let's try a real example. -Model-generated answer: {solution_str} -Reference answer: {ground_truth} -""".strip() - - -class MATH500Evaluator(MathEvaluator): - def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool: - if not ground_truth.startswith("$"): - ground_truth = f"${ground_truth}$" - gold = parse( - ground_truth, - extraction_config=[LatexExtractionConfig()], - ) - answer = parse( - solution_str, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ), - ExprExtractionConfig(), - ], - extraction_mode="first_match", - ) - if len(answer) == 0: - return False, "No extracted answer" - else: - return verify(gold, answer), str(answer) - def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str: - solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation) - return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content. -Model-generated answer: {solution_str} -Reference answer: {ground_truth}""".strip() - -class AMCEvaluator(MathEvaluator): - def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool: - if not ground_truth.startswith("$"): - ground_truth = f"${ground_truth}$" - gold = parse( - ground_truth, - extraction_config=[LatexExtractionConfig()], - ) - answer = parse( - solution_str, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ), - ExprExtractionConfig(), - ], - extraction_mode="first_match", - ) - if len(answer) == 0: - return False, "No extracted answer" - else: - return verify(gold, answer), str(answer) - def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str: - solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation) - return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content. -Model-generated answer: {solution_str} -Reference answer: {ground_truth}""".strip() +from latex2sympy2_extended import NormalizationConfig +from math_verify import ( + ExprExtractionConfig, + LatexExtractionConfig, + StringExtractionConfig, + parse, + verify, +) + + +# Latex+Expr config used by both expr-mode and latex-mode for the model output. +_OUTPUT_EXTRACTION = [ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + boxed="all", + units=True, + ), + boxed_match_priority=0, + try_extract_without_anchor=False, + ), + ExprExtractionConfig(), +] -class OlympiadBenchEvaluator(MathEvaluator): - def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool: - if not ground_truth.startswith("$"): - ground_truth = f"${ground_truth}$" - gold = parse( - ground_truth, - extraction_config=[LatexExtractionConfig()], - ) - answer = parse( - solution_str, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - boxed="all", - units=True, - ), - boxed_match_priority=0, - try_extract_without_anchor=False, - ), - ExprExtractionConfig(), - ], - extraction_mode="first_match", - ) - if len(answer) == 0: - return False, "No extracted answer" - else: - return verify(gold, answer), str(answer) - def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str: - solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation) - return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a math question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content. -Model-generated answer: {solution_str} -Reference answer: {ground_truth}""".strip() -class GPQAEvaluator(MathEvaluator): - def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool: - # if not ground_truth.startswith("$"): - # ground_truth = f"${ground_truth}$" - gold = parse( - ground_truth, - extraction_config=[StringExtractionConfig()], - ) - answer = parse( - solution_str, - extraction_config=[ - StringExtractionConfig(), - ] - ) - if len(answer) == 0: +class MathEvaluator: + """Grader for one of three answer shapes (``expr``/``latex``/``string``). + + Construct with the ground-truth shape; :meth:`rule_judge` does the + parse+verify dance with that shape. + """ + + def __init__(self, mode: str): + if mode not in {"expr", "latex", "string"}: + raise ValueError(f"unsupported grader mode {mode!r}") + self.mode = mode + + def rule_judge(self, solution: str, ground_truth: str, finish_generation: bool = True) -> Tuple[bool, str]: + del finish_generation # accepted for caller-protocol uniformity + if self.mode == "expr": + gold_cfg = [ExprExtractionConfig()] + answer_cfg = _OUTPUT_EXTRACTION + elif self.mode == "latex": + if not ground_truth.startswith("$"): + ground_truth = f"${ground_truth}$" + gold_cfg = [LatexExtractionConfig()] + answer_cfg = _OUTPUT_EXTRACTION + else: # string + gold_cfg = [StringExtractionConfig()] + answer_cfg = [StringExtractionConfig()] + + gold = parse(ground_truth, extraction_config=gold_cfg) + answer = parse(solution, extraction_config=answer_cfg, extraction_mode="first_match") + if not answer: return False, "No extracted answer" - else: - return verify(gold, answer), str(answer) - - def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str: - solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation) - return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a multiple choice question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content. -Model-generated answer: {solution_str} -Reference answer: {ground_truth}""".strip() - - -# class MBPPEvaluator(Evaluator): -# def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool: -# return True, "No extracted answer" - -# def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str: -# solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation) -# return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a multiple choice question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content. -# Model-generated answer: {solution_str} -# Reference answer: {ground_truth}""".strip() - + return bool(verify(gold, answer)), str(answer) -# class HUMANEVALEvaluator(Evaluator): -# def rule_judge(self, solution_str: str, ground_truth: str, finish_generation: bool = True) -> bool: -# return True, "No extracted answer" - -# def get_llm_judge_prompt(self, solution_str: str, ground_truth: str, extract_answer: str = "", finish_generation: bool = True) -> str: -# solution_str = self.extract_after_think(solution_str, finish_generation=finish_generation) -# return f"""Please determine whether the final answer provided in the model-generated response is equivalent to the reference answer from a multiple choice question. The final answer may either be enclosed in \\boxed{{}} or appear after "Answer:". If they are equivalent, return "YES"; if they are not, return "NO". Only return "YES" or "NO", and do not generate any other content. -# Model-generated answer: {solution_str} -# Reference answer: {ground_truth}""".strip() +# Dataset → grader mapping. Keys must match +# ``eval_configs/dataset_configs.json`` benchmark names. +_EXPR = MathEvaluator("expr") +_LATEX = MathEvaluator("latex") +_STRING = MathEvaluator("string") evaluator_map = { - "aime24": AIMEEvaluator(), - "aime25": AIMEEvaluator(), - "brumo25": AIMEEvaluator(), - "chmath": AMCEvaluator(), - "gsm8k": GSM8KEvaluator(), - "math500": MATH500Evaluator(), - "amc23": AMCEvaluator(), - "olympiadbench": OlympiadBenchEvaluator(), - "gpqa": GPQAEvaluator(), - "minerva": AMCEvaluator(), - "mmlu_stem": GPQAEvaluator(), - "mmlu_redux": GPQAEvaluator(), - "arc_e": GPQAEvaluator(), - "arc_c": GPQAEvaluator(), + # Bare-expression ground truth. + "aime24": _EXPR, + "aime25": _EXPR, + "brumo25": _EXPR, + "gsm8k": _EXPR, + # Latex-wrapped ground truth. + "math500": _LATEX, + "amc23": _LATEX, + "chmath": _LATEX, + "olympiadbench": _LATEX, + "minerva": _LATEX, + # String ground truth (multiple-choice). + "gpqa": _STRING, + "mmlu_stem": _STRING, + "mmlu_redux": _STRING, + "arc_e": _STRING, + "arc_c": _STRING, } - -API_BASE = None -DEPLOYMENT_NAME = None -API_VERSION = None -CONSTRUCTED_URL = None -API_KEY = None -HEADERS = None -OPENAI_CLIENT = None -MODEL_NAME = None - -def set_client(api_base=None, deployment_name=None, api_version=None, api_key=None, model_name="gpt-4.1-2025-04-14"): - global API_BASE, DEPLOYMENT_NAME, API_VERSION, CONSTRUCTED_URL, API_KEY, HEADERS, MODEL_NAME, OPENAI_CLIENT - - API_BASE = api_base - DEPLOYMENT_NAME = deployment_name - API_VERSION = api_version - CONSTRUCTED_URL = f"{api_base}/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}" - API_KEY = api_key or os.getenv("OPENAI_API_KEY", "") - MODEL_NAME = model_name - HEADERS = { - "Content-Type": "application/json", - "api-key": api_key, - } - if API_KEY: - print(f"Using API key: {API_KEY}") - OPENAI_CLIENT = OpenAI(api_key=API_KEY) - else: - OPENAI_CLIENT = None - \ No newline at end of file diff --git a/tah/evaluate/utils.py b/tah/evaluate/utils.py deleted file mode 100644 index b74faa9..0000000 --- a/tah/evaluate/utils.py +++ /dev/null @@ -1,694 +0,0 @@ -""" -Unified Answer Evaluation Module - -This module integrates answer extraction, normalization, and grading functionalities -for mathematical problem evaluation. It combines code from the evaluation pipeline -and the math500 grading system. - -Main functions: -- extract_boxed_answer: Extract answers from \\boxed{...} format -- normalize_answer: Normalize mathematical expressions for comparison -- grade_answer: Grade answers using multiple evaluation strategies -- evaluate_answer: Unified interface for answer evaluation -""" - -import re -import sympy -from typing import Optional, Tuple, Union -from pylatexenc import latex2text -from sympy.parsing import sympy_parser - - -# Constants for sympy evaluation safety -BAD_SUBSTRINGS = ["^{", "^("] -BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] -TUPLE_CHARS = "()\[\]" - - -def extract_boxed_answer(text: str) -> Tuple[str, bool]: - """ - Extract answer from the last \\boxed{...} format commonly used in mathematical solutions. - - Args: - text (str): The text containing potential boxed answer - - Returns: - Tuple[str, bool]: (extracted_answer, has_boxed_format) - - extracted_answer: The content inside the last \\boxed{} or empty string if not found - - has_boxed_format: True if \\boxed{} pattern was found, False otherwise - """ - # Find all occurrences of \\boxed{ - boxed_positions = [] - pos = 0 - while True: - boxed_start = text.find("\\boxed{", pos) - if boxed_start == -1: - break - boxed_positions.append(boxed_start) - pos = boxed_start + 1 - - if not boxed_positions: - return "", False - - # Process from the last occurrence backwards to find the last valid match - for boxed_start in reversed(boxed_positions): - # Start counting braces after the opening brace of \\boxed{ - start_pos = boxed_start + len("\\boxed{") - brace_count = 1 - pos = start_pos - - # Find the matching closing brace - while pos < len(text) and brace_count > 0: - if text[pos] == '{': - brace_count += 1 - elif text[pos] == '}': - brace_count -= 1 - pos += 1 - - # If we found a matching closing brace, this is our answer - if brace_count == 0: - content = text[start_pos:pos-1] - return content.strip(), True - - return "", False - - -def extract_answer_patterns(text: str) -> Tuple[str, str]: - """ - Extract answers using multiple common patterns in mathematical text. - - Args: - text (str): The text containing potential answers - - Returns: - Tuple[str, str]: (extracted_answer, extraction_method) - - extracted_answer: The extracted answer - - extraction_method: Method used for extraction ('boxed', 'final', 'last_number', 'none') - """ - # Try boxed format first - boxed_answer, has_boxed = extract_boxed_answer(text) - if has_boxed: - return boxed_answer, 'boxed' - - # Try "The answer is X" pattern - final_answer_patterns = [ - r"[Tt]he answer is:?\s*([^\n\.]+)", - r"[Ss]o,?\s*the answer is:?\s*([^\n\.]+)", - r"[Ff]inal answer:?\s*([^\n\.]+)", - r"[Tt]he final answer is:?\s*([^\n\.]+)", - ] - - for pattern in final_answer_patterns: - match = re.search(pattern, text) - if match: - return match.group(1).strip(), 'final' - - # # Try to extract the last number or mathematical expression - # number_pattern = r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?" - # numbers = re.findall(number_pattern, text) - # if numbers: - # return numbers[-1], 'last_number' - - return "", 'none' - - -# ===== Math Normalization Functions (from math_normalize.py) ===== - -def normalize_answer(answer: Optional[str]) -> Optional[str]: - """ - Normalize mathematical answer using the math500 normalization approach. - This logic is largely copied from the Hendrycks' MATH release (math_equivalence). - - Args: - answer (Optional[str]): The answer to normalize - - Returns: - Optional[str]: Normalized answer or None if input is None - """ - if answer is None: - return None - answer = answer.strip() - try: - # Remove enclosing `\text{}`. - m = re.search("^\\\\text\{(?P.+?)\}$", answer) - if m is not None: - answer = m.group("text").strip() - return _strip_string(answer) - except: - return answer - - -def _fix_fracs(string): - """Fix fraction formatting in LaTeX expressions.""" - substrs = string.split("\\frac") - new_str = substrs[0] - if len(substrs) > 1: - substrs = substrs[1:] - for substr in substrs: - new_str += "\\frac" - if substr[0] == "{": - new_str += substr - else: - try: - assert len(substr) >= 2 - except: - return string - a = substr[0] - b = substr[1] - if b != "{": - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}{" + b + "}" + post_substr - else: - new_str += "{" + a + "}{" + b + "}" - else: - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}" + b + post_substr - else: - new_str += "{" + a + "}" + b - string = new_str - return string - - -def _fix_a_slash_b(string): - """Convert simple fraction format a/b to LaTeX \\frac{a}{b}.""" - if len(string.split("/")) != 2: - return string - a = string.split("/")[0] - b = string.split("/")[1] - try: - a = int(a) - b = int(b) - assert string == "{}/{}".format(a, b) - new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" - return new_string - except: - return string - - -def _remove_right_units(string): - """Remove unit descriptions from the right side of expressions.""" - # "\\text{ " only ever occurs (at least in the val set) when describing units - if "\\text{ " in string: - splits = string.split("\\text{ ") - assert len(splits) == 2 - return splits[0] - else: - return string - - -def _fix_sqrt(string): - """Fix square root formatting to ensure proper LaTeX syntax.""" - if "\\sqrt" not in string: - return string - splits = string.split("\\sqrt") - new_string = splits[0] - for split in splits[1:]: - if split[0] != "{": - a = split[0] - new_substr = "\\sqrt{" + a + "}" + split[1:] - else: - new_substr = "\\sqrt" + split - new_string += new_substr - return new_string - - -def _strip_string(string): - """Core string normalization function with comprehensive cleaning.""" - # Remove linebreaks - string = string.replace("\n", "") - - # Remove inverse spaces - string = string.replace("\\!", "") - - # Replace \\ with \ - string = string.replace("\\\\", "\\") - - # Replace tfrac and dfrac with frac - string = string.replace("tfrac", "frac") - string = string.replace("dfrac", "frac") - - # Remove \left and \right - string = string.replace("\\left", "") - string = string.replace("\\right", "") - - # Remove circ (degrees) - string = string.replace("^{\\circ}", "") - string = string.replace("^\\circ", "") - - # Remove dollar signs - string = string.replace("\\$", "") - - # Remove units (on the right) - string = _remove_right_units(string) - - # Remove percentage - string = string.replace("\\%", "") - string = string.replace("\%", "") - - # " 0." equivalent to " ." and "{0." equivalent to "{." - # Alternatively, add "0" if "." is the start of the string - string = string.replace(" .", " 0.") - string = string.replace("{.", "{0.") - - # If empty, return empty string - if len(string) == 0: - return string - if string[0] == ".": - string = "0" + string - - # Remove variable assignments like "k = " or "q = " at beginning - if len(string.split("=")) == 2: - if len(string.split("=")[0]) <= 2: - string = string.split("=")[1] - - # Fix sqrt3 --> sqrt{3} - string = _fix_sqrt(string) - - # Remove spaces - string = string.replace(" ", "") - - # Fix fractions: \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2} - string = _fix_fracs(string) - - # Manually change 0.5 --> \frac{1}{2} - if string == "0.5": - string = "\\frac{1}{2}" - - # Convert X/Y to \frac{X}{Y} in simple cases - string = _fix_a_slash_b(string) - - return string - - -# ===== Grading Functions (from grader.py) ===== - -def _sympy_parse(expr: str): - """Parse an expression with sympy, handling common mathematical notation.""" - py_expr = expr.replace("^", "**") - return sympy_parser.parse_expr( - py_expr, - transformations=( - sympy_parser.standard_transformations - + (sympy_parser.implicit_multiplication_application,) - ), - ) - - -def _parse_latex(expr: str) -> str: - """Parse LaTeX mathematical expressions to sympy-readable format.""" - expr = expr.replace("\\tfrac", "\\frac") - expr = expr.replace("\\dfrac", "\\frac") - expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers - expr = latex2text.LatexNodes2Text().latex_to_text(expr) - - # Replace specific mathematical symbols - expr = expr.replace("√", "sqrt") - expr = expr.replace("π", "pi") - expr = expr.replace("∞", "inf") - expr = expr.replace("∪", "U") - expr = expr.replace("·", "*") - expr = expr.replace("×", "*") - - return expr.strip() - - -def _is_float(num: str) -> bool: - """Check if string represents a valid float.""" - try: - float(num) - return True - except ValueError: - return False - - -def _is_int(x: float) -> bool: - """Check if float value is effectively an integer.""" - try: - return abs(x - int(round(x))) <= 1e-7 - except: - return False - - -def _is_frac(expr: str) -> bool: - """Check if expression is in simple fraction format.""" - return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) - - -def _str_is_int(x: str) -> bool: - """Check if string represents an integer (handling commas).""" - try: - x = _strip_properly_formatted_commas(x) - x = float(x) - return abs(x - int(round(x))) <= 1e-7 - except: - return False - - -def _str_to_int(x: str) -> int: - """Convert string to integer (handling commas).""" - x = x.replace(",", "") - x = float(x) - return int(x) - - -def _inject_implicit_mixed_number(step: str): - """ - Automatically make a mixed number evaluable. - e.g. 7 3/4 => 7+3/4 - """ - p1 = re.compile("([0-9]) +([0-9])") - step = p1.sub("\\1+\\2", step) - return step - - -def _strip_properly_formatted_commas(expr: str): - """Remove properly formatted commas from numbers while preserving tuple commas.""" - p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") - while True: - next_expr = p1.sub("\\1\\3\\4", expr) - if next_expr == expr: - break - expr = next_expr - return next_expr - - -def _normalize_grader(expr: str) -> str: - """ - Normalize answer expressions for grading comparison. - This is more comprehensive than the basic math normalization. - """ - if expr is None: - return None - - # Remove enclosing `\text{}`. - m = re.search("^\\\\text\{(?P.+?)\}$", expr) - if m is not None: - expr = m.group("text") - - expr = expr.replace("\\%", "%") - expr = expr.replace("\\$", "$") - expr = expr.replace("$", "") - expr = expr.replace("%", "") - expr = expr.replace(" or ", " , ") - expr = expr.replace(" and ", " , ") - - # Handle large number descriptions - expr = expr.replace("million", "*10^6") - expr = expr.replace("billion", "*10^9") - expr = expr.replace("trillion", "*10^12") - - # Remove common units - for unit in [ - "degree", "cm", "centimeter", "meter", "mile", "second", "minute", - "hour", "day", "week", "month", "year", "foot", "feet", "inch", "yard", - ]: - expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) - expr = re.sub(f"\^ *\\\\circ", "", expr) - - # Remove outer braces - if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": - expr = expr[1:-1] - - expr = re.sub(",\\\\! *", "", expr) - - # Convert float to int if it's a whole number - if _is_float(expr) and _is_int(float(expr)): - expr = str(int(round(float(expr)))) - - # Parse LaTeX if present - if "\\" in expr: - try: - expr = _parse_latex(expr) - except: - pass - - # Handle negative signs in mixed numbers - expr = re.sub("- *", "-", expr) - - expr = _inject_implicit_mixed_number(expr) - expr = expr.replace(" ", "") - - # Remove remaining LaTeX braces - expr = expr.replace("{", "") - expr = expr.replace("}", "") - - # Case insensitive for text answers - expr = expr.lower() - - # Convert to integer string if applicable - if _str_is_int(expr): - expr = str(_str_to_int(expr)) - - return expr - - -def count_unknown_letters_in_expr(expr: str): - """Count unknown variables in mathematical expression.""" - expr = expr.replace("sqrt", "") - expr = expr.replace("frac", "") - letters_in_expr = set([x for x in expr if x.isalpha()]) - return len(letters_in_expr) - - -def should_allow_eval(expr: str): - """ - Determine if expression is safe for sympy evaluation. - Avoid parsing unknown text or functions with too many variables. - """ - if count_unknown_letters_in_expr(expr) > 2: - return False - - for bad_string in BAD_SUBSTRINGS: - if bad_string in expr: - return False - - for bad_regex in BAD_REGEXES: - if re.search(bad_regex, expr) is not None: - return False - - return True - - -def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): - """Check if two normalized expressions are mathematically equivalent using sympy.""" - are_equal = False - try: - expr = f"({ground_truth_normalized})-({given_normalized})" - if should_allow_eval(expr): - sympy_diff = _sympy_parse(expr) - simplified = sympy.simplify(sympy_diff) - if simplified == 0: - are_equal = True - except: - pass - return are_equal - - -def split_tuple(expr: str): - """ - Split elements in a tuple/interval while handling well-formatted commas in large numbers. - """ - expr = _strip_properly_formatted_commas(expr) - if len(expr) == 0: - return [] - if ( - len(expr) > 2 - and expr[0] in TUPLE_CHARS - and expr[-1] in TUPLE_CHARS - and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) - ): - elems = [elem.strip() for elem in expr[1:-1].split(",")] - else: - elems = [expr] - return elems - - -def grade_answer(given_answer: str, ground_truth: str) -> bool: - """ - Grade answer using comprehensive mathematical equivalence checking. - - The answer will be considered correct if: - (a) it normalizes to the same string as the ground truth answer, OR - (b) sympy can simplify the difference between the expressions to 0 - - Args: - given_answer (str): The answer provided by the model - ground_truth (str): The correct answer - - Returns: - bool: True if answer is correct, False otherwise - """ - if given_answer is None: - return False - - # First try basic math normalization (more lenient) - ground_truth_normalized_mathd = normalize_answer(ground_truth) - given_answer_normalized_mathd = normalize_answer(given_answer) - - if ground_truth_normalized_mathd == given_answer_normalized_mathd: - return True - - # Then try more comprehensive grader normalization - ground_truth_normalized = _normalize_grader(ground_truth) - given_normalized = _normalize_grader(given_answer) - - if ground_truth_normalized is None: - return False - - if ground_truth_normalized == given_normalized: - return True - - if len(given_normalized) == 0: - return False - - # Handle tuple/interval answers - ground_truth_elems = split_tuple(ground_truth_normalized) - given_elems = split_tuple(given_normalized) - - # Check tuple structure consistency - if len(ground_truth_elems) > 1 and ( - ground_truth_normalized[0] != given_normalized[0] - or ground_truth_normalized[-1] != given_normalized[-1] - ): - is_correct = False - elif len(ground_truth_elems) != len(given_elems): - is_correct = False - else: - # Check each element - for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): - if _is_frac(ground_truth_elem) and _is_frac(given_elem): - # For fractions, require exact match (no reduction allowed) - is_correct = ground_truth_elem == given_elem - elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): - # Type mismatch between integer and non-integer - is_correct = False - else: - # Use sympy for mathematical equivalence - is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) - - if not is_correct: - break - - return is_correct - - -# ===== Unified Evaluation Interface ===== - -def evaluate_answer( - response_text: str, - ground_truth: str, - extraction_method: str = 'auto', - grading_method: str = 'comprehensive' -) -> dict: - """ - Unified interface for answer evaluation combining extraction and grading. - - Args: - response_text (str): The full response text from the model - ground_truth (str): The correct answer - extraction_method (str): Method for answer extraction - - 'auto': Try multiple patterns automatically - - 'boxed': Only look for \\boxed{} format - - 'pattern': Use pattern matching for common answer formats - grading_method (str): Method for answer grading - - 'comprehensive': Use full mathematical equivalence checking - - 'basic': Use basic string normalization only - - 'exact': Require exact string match after normalization - - Returns: - dict: Evaluation results containing: - - 'extracted_answer': The extracted answer string - - 'extraction_method_used': Method that successfully extracted the answer - - 'has_answer': Whether any answer was found - - 'is_correct': Whether the answer is correct - - 'ground_truth_normalized': Normalized ground truth - - 'extracted_normalized': Normalized extracted answer - """ - result = { - 'extracted_answer': '', - 'extraction_method_used': 'none', - 'has_answer': False, - 'is_correct': False, - 'ground_truth_normalized': '', - 'extracted_normalized': '' - } - - # Extract answer based on specified method - if extraction_method == 'boxed': - extracted_answer, has_boxed = extract_boxed_answer(response_text) - result['extracted_answer'] = extracted_answer - result['has_answer'] = has_boxed - result['extraction_method_used'] = 'boxed' if has_boxed else 'none' - elif extraction_method == 'pattern': - extracted_answer, method_used = extract_answer_patterns(response_text) - result['extracted_answer'] = extracted_answer - result['has_answer'] = bool(extracted_answer) - result['extraction_method_used'] = method_used - else: # 'auto' - # Try boxed first, then patterns - extracted_answer, has_boxed = extract_boxed_answer(response_text) - if has_boxed: - result['extracted_answer'] = extracted_answer - result['has_answer'] = True - result['extraction_method_used'] = 'boxed' - else: - extracted_answer, method_used = extract_answer_patterns(response_text) - result['extracted_answer'] = extracted_answer - result['has_answer'] = bool(extracted_answer) - result['extraction_method_used'] = method_used - - # Grade answer if one was extracted - if result['has_answer']: - if grading_method == 'comprehensive': - result['is_correct'] = grade_answer(result['extracted_answer'], ground_truth) - elif grading_method == 'basic': - result['ground_truth_normalized'] = normalize_answer(ground_truth) - result['extracted_normalized'] = normalize_answer(result['extracted_answer']) - result['is_correct'] = (result['ground_truth_normalized'] == result['extracted_normalized']) - elif grading_method == 'exact': - result['is_correct'] = (result['extracted_answer'].strip() == ground_truth.strip()) - - # Always provide normalized versions for inspection - if not result['ground_truth_normalized']: - result['ground_truth_normalized'] = normalize_answer(ground_truth) - if not result['extracted_normalized']: - result['extracted_normalized'] = normalize_answer(result['extracted_answer']) - - return result - - -# ===== Convenience Functions ===== - -def simple_grade(predicted_answer: str, correct_answer: str) -> bool: - """ - Simple grading function that mimics the original eval_unified.py logic. - - Args: - predicted_answer (str): The predicted answer - correct_answer (str): The correct answer - - Returns: - bool: True if answers match exactly, False otherwise - """ - if not predicted_answer: - return False - return predicted_answer.strip() == correct_answer.strip() - - -def extract_and_grade_boxed(response_text: str, ground_truth: str) -> Tuple[str, bool, bool]: - """ - Extract boxed answer and grade it, mimicking original evaluation logic. - - Args: - response_text (str): The full response text - ground_truth (str): The correct answer - - Returns: - Tuple[str, bool, bool]: (extracted_answer, has_boxed_answer, is_correct) - """ - extracted_answer, has_boxed = extract_boxed_answer(response_text) - is_correct = simple_grade(extracted_answer, ground_truth) if has_boxed else False - return extracted_answer, has_boxed, is_correct - diff --git a/tah/model/__init__.py b/tah/model/__init__.py new file mode 100644 index 0000000..6188144 --- /dev/null +++ b/tah/model/__init__.py @@ -0,0 +1,35 @@ +"""Core TaH model components. + +Re-exports the wrapper, config, cache, and the two extension classes that +have multiple implementations (iter_decider, loss). Single-implementation +slots are inlined into the wrapper itself; see ``tah_model.py``. +""" +from tah.model.causal_cache import TaHCache +from tah.model.iter_decider import ( + ITER_DECIDER_BY_NAME, + IterDecider, + IterLabelDecider, + MLPIterDecider, + load_iter_decider, + save_iter_decider, +) +from tah.model.loss import LOSS_BY_NAME, IterDeciderLoss, LossFunc, NextTokenPredLoss +from tah.model.tah_config import TaHConfig +from tah.model.tah_model import TaHCausalLMOutputWithPast, TaHForCausalLM + +__all__ = [ + "ITER_DECIDER_BY_NAME", + "IterDecider", + "IterDeciderLoss", + "IterLabelDecider", + "LOSS_BY_NAME", + "LossFunc", + "MLPIterDecider", + "NextTokenPredLoss", + "TaHCache", + "TaHCausalLMOutputWithPast", + "TaHConfig", + "TaHForCausalLM", + "load_iter_decider", + "save_iter_decider", +] diff --git a/tah/model/adapter.py b/tah/model/adapter.py deleted file mode 100644 index 6428b5a..0000000 --- a/tah/model/adapter.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Adapter utilities for TaH wrapper. - -This module centralizes all adapter-related logic (e.g., LoRA/multi-LoRA/cascade) -that previously lived inside `recurrent_transformer.py`, without changing any -existing runtime behavior. -""" - - -import os - - -from transformers.utils import logging - - -# PEFT imports (lazy optional) -try: - from peft import LoraConfig, get_peft_model - PEFT_AVAILABLE = True -except ImportError: - PEFT_AVAILABLE = False - LoraConfig = None # type: ignore - get_peft_model = None # type: ignore - - -logger = logging.get_logger(__name__) - - -def setup_adapter(wrapper, config) -> None: - """Initialize adapter according to `config` and attach to wrapper. - - This function mutates `wrapper` to keep all public attributes identical - to the previous implementation. - """ - wrapper.adapter = config.adapter - wrapper.adapter_config = None - wrapper.lora_iter_to_adapter = {} - - if wrapper.adapter == "lora": - if not PEFT_AVAILABLE: - raise ImportError( - "PEFT library is required for LoRA support. Please install with: pip install peft" - ) - base_grad = config.adapter_kwargs.pop("base_grad", True) - adapter_grad = config.adapter_kwargs.pop("adapter_grad", True) - - wrapper.adapter_config = LoraConfig(**config.adapter_kwargs) # type: ignore - wrapper.simple_base_model = get_peft_model(wrapper.simple_base_model, wrapper.adapter_config) # type: ignore - # Reset adapter_kwargs - config.adapter_kwargs["base_grad"] = base_grad - config.adapter_kwargs["adapter_grad"] = adapter_grad - - if base_grad or not adapter_grad: - for name, p in wrapper.simple_base_model.base_model.named_parameters(): - if "lora" in name.lower(): - p.requires_grad = adapter_grad - else: - p.requires_grad = base_grad - # Process LoRA layers if want to freeze them - # for module in wrapper.simple_base_model.base_model.modules(): - # if isinstance(module, LoraLayer): - # for p in module.parameters(): - # p.requires_grad = True - - logger.info(f"LoRA enabled with config: {config.adapter_kwargs}") - - else: - logger.info("Adapter disabled") - - -def configure_lora_for_iteration(wrapper, iter_depth: int) -> None: - """Enable/disable and/or switch LoRA adapters per-iteration. - - Mirrors previous `_configure_lora_for_iteration` behavior. - """ - if wrapper.adapter == "lora": - if not hasattr(wrapper, "_lora_enabled_state"): - wrapper.simple_base_model.base_model.disable_adapter_layers() - wrapper._lora_enabled_state = False - - if iter_depth == 0: - # Disable LoRA parameters only when transitioning from enabled -> disabled - if wrapper._lora_enabled_state is not False: - wrapper.simple_base_model.base_model.disable_adapter_layers() - wrapper._lora_enabled_state = False - elif iter_depth > 0: - # Enable LoRA parameters only when transitioning from disabled -> enabled - if wrapper._lora_enabled_state is not True: - wrapper.simple_base_model.base_model.enable_adapter_layers() - wrapper._lora_enabled_state = True - - - -def save_adapter(wrapper, save_directory: str, **kwargs) -> None: - """Save adapter-related weights and, where appropriate, the base model. - - Mirrors previous logic in `TaHForCausalLM.save_pretrained` for adapter branches. - """ - if wrapper.adapter == "lora": - # Save LoRA adapter(s) - lora_dir = os.path.join(save_directory, "lora") - os.makedirs(lora_dir, exist_ok=True) - wrapper.simple_base_model.save_pretrained(lora_dir, **kwargs) - - # Directly save with cleaned keys by temporarily overriding state_dict method - base_model = wrapper.simple_base_model.base_model.model - original_state_dict = base_model.state_dict - - def cleaned_state_dict(): - """Return state_dict with cleaned keys (remove .base_layer)""" - state_dict = original_state_dict() - cleaned_dict = {} - for key, value in state_dict.items(): - if 'lora' in key.lower(): # skip lora weights - continue - cleaned_key = key.replace('.base_layer', '') - cleaned_dict[cleaned_key] = value - return cleaned_dict - - base_model.state_dict = cleaned_state_dict - try: - base_model.save_pretrained(save_directory, **kwargs) - finally: - base_model.state_dict = original_state_dict - - logger.info(f"Saving LoRA adapter and cleaned base model to {save_directory}") - - - else: - # Adapter disabled: directly save the base model - wrapper.simple_base_model.save_pretrained(save_directory, **kwargs) - logger.info(f"Saving base model to {save_directory}") - - -def load_adapter(wrapper, pretrained_model_name_or_path: str, final_config, *args, **kwargs) -> None: - """Reload adapter-specific weights during `from_pretrained`. - - Mirrors previous logic for LoRA reload and cascade secondary model. - """ - - # Reload LoRA weights if needed - if wrapper.adapter == "lora": - logger.info("Reloading LoRA adapters from checkpoint after initialization") - adapter_path = os.path.join(pretrained_model_name_or_path, "lora") - base_grad = final_config.adapter_kwargs.pop("base_grad", True) - adapter_grad = final_config.adapter_kwargs.pop("adapter_grad", True) - wrapper.simple_base_model.load_adapter(adapter_path, adapter_name="default") - logger.info(f"Reloaded LoRA adapter from {adapter_path}") - # Set gradients based on parameter names: LoRA params get adapter_grad, others get base_grad - for name, p in wrapper.simple_base_model.named_parameters(): - if "lora" in name.lower(): - p.requires_grad = adapter_grad - else: - p.requires_grad = base_grad - - - diff --git a/tah/model/causal_cache.py b/tah/model/causal_cache.py index 275597f..688dbca 100644 --- a/tah/model/causal_cache.py +++ b/tah/model/causal_cache.py @@ -1,37 +1,64 @@ +"""KV cache for ``TaHForCausalLM``. + +The cache stores key/value tensors plus per-row position and valid metadata +keyed by ``(layer_idx, iter_depth)``. The TaH attention mask uses two views +of the cache: + +* ``up_to(iter_depth)`` — concatenated K/V from iterations ``0…iter_depth`` + (the visible KV slots for query iter ``iter_depth``). +* ``iter_index_up_to(iter_depth)`` — a per-slot iteration index used to + enforce the per-iter visibility rule in the additive attention mask. + +The wrapper writes one slot per iteration; HF causal-LM layers call +:meth:`update` to append the current iteration's K/V into the cache, then +read the up-to view from inside SDPA. """ -Custom Causal Cache implementation for TaH that supports hierarchical iteration access. -""" +from __future__ import annotations + +from typing import Any, Dict, Optional, Tuple import torch -from typing import Any, Dict, List, Optional, Tuple, Union +from transformers.cache_utils import DynamicCache -from transformers.cache_utils import Cache, DynamicCache +# A slot is keyed by (layer_idx, iter_depth). Storing one flat dict per +# tensor type (key/value/pos/valid) is more uniform than nested dicts and +# makes ``to()`` / ``__len__`` / device introspection trivial loops. +_Slot = Tuple[int, int] -class TaHCache(DynamicCache): - """ - A cache that supports hierarchical iteration access where deeper iterations - can see cache from all previous iterations, but previous iterations cannot - see cache from future iterations. - This enables parallel prefilling at each iteration depth while maintaining - causal constraints across iterations. +class TaHCache(DynamicCache): + """Per-(layer, iteration) KV cache with up-to-N-iter views. + + Public contract used by ``TaHForCausalLM``: + * Set ``current_iter_depth`` / ``position_ids_to_cache`` / + ``valid_mask_to_cache`` before triggering any per-layer + :meth:`update`. + * Read ``get_cache_upto_iter(layer_idx, iter_depth)`` for KV, + ``get_position_id_upto_iter`` and ``get_valid_mask_upto_iter`` for + attention metadata, ``get_cache_iter_index_upto_iter`` for the + per-slot iter index that the duo-mode mask compares against. """ def __init__(self): super().__init__() - # Structure: {layer_idx: {iter_depth: Tensor_data}} - # Use internal names to avoid clashing with transformers.Cache properties - self._tah_key_cache: Dict[int, Dict[int, torch.Tensor]] = {} - self._tah_value_cache: Dict[int, Dict[int, torch.Tensor]] = {} - self._tah_position_id_cache: Dict[int, Dict[int, torch.Tensor]] = {} - self._tah_valid_mask_cache: Dict[int, Dict[int, torch.Tensor]] = {} + self._k: Dict[_Slot, torch.Tensor] = {} + self._v: Dict[_Slot, torch.Tensor] = {} + self._pos: Dict[_Slot, torch.Tensor] = {} + self._valid: Dict[_Slot, torch.Tensor] = {} + + # Set by the wrapper before each per-layer .update() pass. + self.current_iter_depth: int = 0 + self.position_ids_to_cache: Optional[torch.Tensor] = None + self.valid_mask_to_cache: Optional[torch.Tensor] = None - self.current_iter_depth = 0 - self.batch_size: Optional[int] = None # Track current batch size + self.batch_size: Optional[int] = None - self._device = None - self._dtype = None + def has_layer(self, layer_idx: int = 0) -> bool: + """``True`` iff at least one iter slot has been written for ``layer_idx``.""" + return any(l == layer_idx for (l, _) in self._k) + + # ── write path ──────────────────────────────────────────────────────── def update( self, @@ -40,416 +67,165 @@ def update( layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the given key_states and value_states for the layer layer_idx. + """Append (key, value) for this layer at the current iter depth. - Args: - key_states: The key states to cache (batch_size, num_heads, seq_len, head_dim) - value_states: The value states to cache (batch_size, num_heads, seq_len, head_dim) - layer_idx: The index of the layer to cache the states for - cache_kwargs: Additional arguments, should include cache_position to indicate the position ids of the current iteration - - Returns: - Tuple containing the concatenated key and value states from all accessible iterations + Returns the concatenated K/V from iterations ``0…current_iter_depth`` + so the calling SDPA attention sees the full visible window. """ - # Update batch size from input tensors - self.batch_size = ( - key_states.shape[0] if self.batch_size is None else self.batch_size - ) - assert ( - self.batch_size == key_states.shape[0] - and self.batch_size == value_states.shape[0] - ), "Batch size mismatch, expected {}, got {} and {}".format( - self.batch_size, key_states.shape[0], value_states.shape[0] - ) - - # Get iteration depth, position, and token mask, all set outside of this function - iter_depth = self.current_iter_depth - new_position_ids = self.position_ids_to_cache - new_valid_mask = self.valid_mask_to_cache - - # Initialize layer cache if needed - if layer_idx not in self._tah_key_cache: - self._tah_key_cache[layer_idx] = {} - self._tah_value_cache[layer_idx] = {} - self._tah_position_id_cache[layer_idx] = {} - self._tah_valid_mask_cache[layer_idx] = {} - - # Update cache for this iteration depth - if iter_depth in self._tah_key_cache[layer_idx]: - # Concatenate with existing cache for this iteration depth - self._tah_key_cache[layer_idx][iter_depth] = torch.cat( - [self._tah_key_cache[layer_idx][iter_depth], key_states], dim=-2 - ) - self._tah_value_cache[layer_idx][iter_depth] = torch.cat( - [self._tah_value_cache[layer_idx][iter_depth], value_states], dim=-2 - ) - self._tah_position_id_cache[layer_idx][iter_depth] = torch.cat( - [self._tah_position_id_cache[layer_idx][iter_depth], new_position_ids], - dim=-1, - ) - self._tah_valid_mask_cache[layer_idx][iter_depth] = torch.cat( - [self._tah_valid_mask_cache[layer_idx][iter_depth], new_valid_mask], dim=-1 - ) + del cache_kwargs # protocol arg, not used here + if self.batch_size is None: + self.batch_size = key_states.shape[0] else: - # First entry for this iteration depth - self._tah_key_cache[layer_idx][iter_depth] = key_states - self._tah_value_cache[layer_idx][iter_depth] = value_states - self._tah_position_id_cache[layer_idx][iter_depth] = new_position_ids - self._tah_valid_mask_cache[layer_idx][iter_depth] = new_valid_mask - - # Return concatenated cache from all accessible iterations (0 to iter_depth) - return self.get_cache_upto_iter(layer_idx, iter_depth) - - @property - def current_iter_depth(self) -> int: - """ - Get the current iteration depth. - """ - return self._current_iter_depth - - @current_iter_depth.setter - def current_iter_depth(self, iter_depth: int): - self._current_iter_depth = iter_depth - - @property - def position_ids_to_cache(self) -> torch.Tensor: - """ - Get the position ids to cache for the current iteration depth, shape: (batch_size, seq_len) - """ - # Default position ids - # batch_size = key_states.shape[0] - # seq_length = key_states.shape[-2] - # kv_cache_length_this_iter = self.get_cache_length(layer_idx, iter_depth) - # position_ids = torch.arange((batch_size, seq_length + kv_cache_length_this_iter), device=key_states.device, dtype=torch.long) - return self._position_ids_to_cache - - @position_ids_to_cache.setter - def position_ids_to_cache(self, position_ids: torch.Tensor): - self._position_ids_to_cache = position_ids - - @property - def valid_mask_to_cache(self) -> torch.Tensor: - """ - Get the token mask to cache for the current iteration depth, shape: (batch_size, seq_len) - """ - # torch.ones_like(new_position_ids) - return self._valid_mask_to_cache - - @valid_mask_to_cache.setter - def valid_mask_to_cache(self, valid_mask: torch.Tensor): - self._valid_mask_to_cache = valid_mask - - def get_position_id_upto_iter( - self, layer_idx: int, upto_iter_idx: int, init_batch_size: int = 1 - ) -> torch.Tensor: - """ - Get the position id upto a given layer and iteration depth. - - Args: - layer_idx: Layer index - upto_iter_idx: Maximum iteration depth to include - batch_size: Batch size if the position id is not cached for the given layer and iteration depth - - Returns: - Position id of shape (batch_size, total sequence length until current iteration depth) - """ - - def _get_position_id_of_iter( - self, layer_idx: int, iter_idx: int, batch_size: int = 1 - ) -> torch.Tensor: - """ - Get the position id for a given layer and iteration depth. - """ - if (layer_idx not in self._tah_position_id_cache) or ( - iter_idx not in self._tah_position_id_cache[layer_idx] - ): - return torch.empty( - size=(batch_size, 0), device=self.device, dtype=torch.long - ) - else: - return self._tah_position_id_cache[layer_idx][iter_idx] - - all_position_ids = [] - batch_size = init_batch_size - for iter_depth in range(upto_iter_idx + 1): - position_id = _get_position_id_of_iter( - self, layer_idx, iter_depth, batch_size + assert key_states.shape[0] == self.batch_size, ( + f"batch size mismatch: cached {self.batch_size}, got {key_states.shape[0]}" ) - batch_size = position_id.shape[0] - all_position_ids.append(position_id) - - return torch.cat(all_position_ids, dim=-1) - - def get_valid_mask_upto_iter( - self, layer_idx: int, upto_iter_idx: int, init_batch_size: int = 1 - ) -> torch.Tensor: - """ - Get the token mask upto a given layer and iteration depth. - upto_iter_idx=0 means getting the valid mask of the first iteration - """ - - def _get_valid_mask_of_iter( - self, layer_idx: int, iter_idx: int, batch_size: int = 1 - ) -> torch.Tensor: - """ - Get the token mask for a given layer and iteration depth. - """ - if (layer_idx not in self._tah_valid_mask_cache) or ( - iter_idx not in self._tah_valid_mask_cache[layer_idx] - ): - return torch.empty( - size=(batch_size, 0), device=self.device, dtype=torch.long - ) - else: - return self._tah_valid_mask_cache[layer_idx][iter_idx] - - all_valid_masks = [] - batch_size = init_batch_size - for iter_depth in range(upto_iter_idx + 1): - valid_mask = _get_valid_mask_of_iter( - self, layer_idx, iter_depth, batch_size - ) - batch_size = valid_mask.shape[0] - all_valid_masks.append(valid_mask) - - return torch.cat(all_valid_masks, dim=-1) - def get_cache_iter_index_upto_iter( - self, layer_idx: int, upto_iter_idx: int - ) -> torch.Tensor: - """ - Get the iter index of each KV Cache value upto a certain iter depth - Args: - layer_idx: Layer index - iter_depth: Iteration depth - Returns: - iter id of shape (total cache length until current iteration depth, ) - """ - - def _update_iter_index_of_iter( - self, layer_idx: int, iter_idx: int - ) -> torch.Tensor: - """ - Get the iter id for a given layer and iteration depth. - """ - if (layer_idx not in self._tah_position_id_cache) or ( - iter_idx not in self._tah_position_id_cache[layer_idx] - ): - return torch.empty(size=(0,), device=self.device, dtype=torch.long) - else: - return iter_idx + slot: _Slot = (layer_idx, self.current_iter_depth) + if slot in self._k: + # Same (layer, iter) called more than once (e.g. autoregressive + # decode): concatenate along the sequence dimension. + self._k[slot] = torch.cat([self._k[slot], key_states], dim=-2) + self._v[slot] = torch.cat([self._v[slot], value_states], dim=-2) + self._pos[slot] = torch.cat([self._pos[slot], self.position_ids_to_cache], dim=-1) + self._valid[slot] = torch.cat([self._valid[slot], self.valid_mask_to_cache], dim=-1) + else: + self._k[slot] = key_states + self._v[slot] = value_states + self._pos[slot] = self.position_ids_to_cache + self._valid[slot] = self.valid_mask_to_cache - cache_length_upto_iter = self.get_cache_length_upto_iter(layer_idx, upto_iter_idx) + return self.get_cache_upto_iter(layer_idx, self.current_iter_depth) - if cache_length_upto_iter == 0: # the first position id does not exist - return torch.empty(size=(0,), device=self.device, dtype=torch.long) - else: - iter_id_tensor = torch.zeros( - size=(cache_length_upto_iter,), device=self.device, dtype=torch.long - ) - cache_length_upto_current_iter = 0 - for iter_idx in range(upto_iter_idx): - cache_length_upto_current_iter += self.get_cache_length( - layer_idx, iter_idx - ) - iter_id_tensor[cache_length_upto_current_iter:] += 1 - return iter_id_tensor + # ── read views ──────────────────────────────────────────────────────── def get_cache_upto_iter( self, layer_idx: int, upto_iter_idx: int - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Get concatenated cache from iterations 0 to upto_iter_idx (inclusive). - - Args: - layer_idx: Layer index - upto_iter_idx: Maximum iteration depth to include - - Returns: - Concatenated key and value states with consistent batch dimensions - """ - if layer_idx not in self._tah_key_cache: + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """Concatenated K/V from iter 0..upto_iter_idx (inclusive). ``(None, None)`` if empty.""" + keys = [self._k[(layer_idx, i)] for i in range(upto_iter_idx + 1) if (layer_idx, i) in self._k] + if not keys: return None, None + vals = [self._v[(layer_idx, i)] for i in range(upto_iter_idx + 1) if (layer_idx, i) in self._v] + return torch.cat(keys, dim=-2), torch.cat(vals, dim=-2) - all_keys = [] - all_values = [] - - # Collect cache from all iterations up to current depth - for iter_depth in range(upto_iter_idx + 1): - if iter_depth in self._tah_key_cache[layer_idx]: - all_keys.append(self._tah_key_cache[layer_idx][iter_depth]) - all_values.append(self._tah_value_cache[layer_idx][iter_depth]) - - if not all_keys: - return None, None + def _meta_upto_iter( + self, + store: Dict[_Slot, torch.Tensor], + layer_idx: int, + upto_iter_idx: int, + init_batch_size: int, + empty_dim: int, + ) -> torch.Tensor: + """Concatenate per-iter metadata tensors (positions or valid masks). + + ``empty_dim`` is the singleton dim for an absent iter (1 for B-prefixed + tensors of shape (B, T)) so the cat dimension stays well-defined. + """ + chunks = [] + batch = init_batch_size + for i in range(upto_iter_idx + 1): + t = store.get((layer_idx, i)) + if t is None: + shape = (batch, 0) if empty_dim == 2 else (0,) + chunks.append(torch.empty(shape, device=self.device, dtype=torch.long)) + else: + batch = t.shape[0] if empty_dim == 2 else batch + chunks.append(t) + return torch.cat(chunks, dim=-1) + + def get_position_id_upto_iter(self, layer_idx: int, upto_iter_idx: int, init_batch_size: int = 1) -> torch.Tensor: + return self._meta_upto_iter(self._pos, layer_idx, upto_iter_idx, init_batch_size, empty_dim=2) + + def get_valid_mask_upto_iter(self, layer_idx: int, upto_iter_idx: int, init_batch_size: int = 1) -> torch.Tensor: + return self._meta_upto_iter(self._valid, layer_idx, upto_iter_idx, init_batch_size, empty_dim=2) + + def get_cache_iter_index_upto_iter(self, layer_idx: int, upto_iter_idx: int) -> torch.Tensor: + """Per-slot iteration index for KV slots in iter 0..upto_iter_idx. + + Shape ``(total_kv_len,)``; element ``j`` is the iter the slot belongs + to. Used by the duo-mode mask: a query at iter ``i`` can attend to a + KV slot iff its iter index is ``<= i``. + """ + per_iter_lens = [ + self._k[(layer_idx, i)].shape[-2] if (layer_idx, i) in self._k else 0 + for i in range(upto_iter_idx + 1) + ] + if sum(per_iter_lens) == 0: + return torch.empty((0,), device=self.device, dtype=torch.long) + return torch.cat([ + torch.full((n,), i, device=self.device, dtype=torch.long) + for i, n in enumerate(per_iter_lens) if n > 0 + ]) + + # ── lengths & misc DynamicCache contract ───────────────────────────── + + def get_cache_length(self, layer_idx: Optional[int] = 0, iter_idx: Optional[int] = None) -> int: + """Sequence length of stored K/V; sum across iters when ``iter_idx is None``.""" + if iter_idx is not None: + t = self._k.get((layer_idx, iter_idx)) + return t.shape[-2] if t is not None else 0 + return sum(t.shape[-2] for (l, _), t in self._k.items() if l == layer_idx) + + def get_cache_length_upto_iter(self, layer_idx: Optional[int] = 0, iter_depth: int = 0) -> int: + return sum( + self._k[(layer_idx, i)].shape[-2] + for i in range(iter_depth + 1) if (layer_idx, i) in self._k + ) - concatenated_keys = torch.cat(all_keys, dim=-2) - concatenated_values = torch.cat(all_values, dim=-2) + def get_seq_length(self, layer_idx: Optional[int] = 0, iter_idx: Optional[int] = 0) -> int: + """HF protocol: max position + 1 for the layer's iter-0 slot.""" + t = self._k.get((layer_idx, iter_idx)) + return t.shape[-2] if t is not None else 0 - return concatenated_keys, concatenated_values + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + del new_seq_length + return self.get_cache_length(layer_idx) - def get_cache_length( - self, layer_idx: Optional[int] = 0, iter_idx: Optional[int] = None - ) -> int: - """ - Get the cache length for a given layer and iteration depth. If iter_idx is not provided, return the total cache length across all iterations. - Args: - layer_idx: Layer index - iter_idx: Iteration depth - Returns: - Cache length - """ - if layer_idx not in self._tah_key_cache or not self._tah_key_cache[layer_idx]: - return 0 - - total_length = 0 - if iter_idx is None: - # Return total cache length across ALL stored iterations - for iter_depth in self._tah_key_cache[layer_idx]: - key_states = self._tah_key_cache[layer_idx][iter_depth] - total_length += key_states.shape[-2] - else: - # Return cache length for a given iteration - key_states = self._tah_key_cache[layer_idx][iter_idx] - total_length = key_states.shape[-2] - - return total_length - - def get_cache_length_upto_iter( - self, layer_idx: Optional[int] = 0, iter_depth: int = 0 - ) -> int: - """Returns cache length from iterations 0 to before_iter_depth-1.""" - if layer_idx not in self._tah_key_cache or not self._tah_key_cache[layer_idx]: - return 0 - - total_length = 0 - for iter_depth in range(iter_depth + 1): - if iter_depth in self._tah_key_cache[layer_idx]: - key_states = self._tah_key_cache[layer_idx][iter_depth] - total_length += key_states.shape[-2] - - return total_length - - def get_seq_length( - self, layer_idx: Optional[int] = 0, iter_idx: Optional[int] = 0 - ) -> int: - """Returns the current sequence length (max position + 1) for a given layer.""" - if layer_idx not in self._tah_position_id_cache: - return 0 - - # Find maximum position across all iterations (sequence grows during generation) - # max_position = torch.max(self.position_id_cache[layer_idx][iter_idx]).item() # make more sense, but not exactly the same as huggingface - max_position = self._tah_key_cache[layer_idx][iter_idx].shape[-2] - return max_position - - # Not used def get_max_length(self) -> Optional[int]: - """Returns the maximum cache length if it exists.""" - return None # Dynamic cache has no maximum length + return None # dynamic — no cap def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length (i.e. max capacity) of the cache object.""" - return None # Dynamic cache has no maximum length + return None - def get_usable_length( - self, new_seq_length: int, layer_idx: Optional[int] = 0 - ) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - return self.get_cache_length(layer_idx) + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> Tuple[int, int]: + return cache_position.shape[0] + self.get_cache_length(layer_idx), 0 - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorder the cache according to beam_idx for beam search.""" - for layer_idx in self._tah_key_cache: - for iter_depth in self._tah_key_cache[layer_idx]: - device = self._tah_key_cache[layer_idx][iter_depth].device - # Reorder key cache - self._tah_key_cache[layer_idx][iter_depth] = self._tah_key_cache[layer_idx][ - iter_depth - ].index_select(0, beam_idx.to(device)) - # Reorder value cache - self._tah_value_cache[layer_idx][iter_depth] = self._tah_value_cache[layer_idx][ - iter_depth - ].index_select(0, beam_idx.to(device)) - - def get_mask_sizes( - self, cache_position: torch.Tensor, layer_idx: int - ) -> Tuple[int, int]: - """ - Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for - the given layer at `layer_idx`. - """ - query_length = cache_position.shape[0] - past_seen_tokens = self.get_cache_length(layer_idx) - kv_length = query_length + past_seen_tokens - return kv_length, 0 + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorder the batch axis of all stored tensors for beam search.""" + for store in (self._k, self._v): + for slot, t in store.items(): + store[slot] = t.index_select(0, beam_idx.to(t.device)) + + # ── device / dtype ──────────────────────────────────────────────────── @property def device(self) -> torch.device: - """Returns the device of the cached tensors.""" - # use the device of the first cached tensor - key_device = None - for layer_idx in self._tah_key_cache: - for iter_depth in self._tah_key_cache[layer_idx]: - if self._tah_key_cache[layer_idx][iter_depth] is not None: - return self._tah_key_cache[layer_idx][iter_depth].device - - # Else, use cpu as default + for t in self._k.values(): + return t.device return torch.device("cpu") @property def dtype(self) -> torch.dtype: - """Returns the dtype of the cached tensors.""" - # use the dtype of the first cached tensor - for layer_idx in self._tah_key_cache: - for iter_depth in self._tah_key_cache[layer_idx]: - if self._tah_key_cache[layer_idx][iter_depth] is not None: - return self._tah_key_cache[layer_idx][iter_depth].dtype - - # Else, use bfloat16 as default + for t in self._k.values(): + return t.dtype return torch.bfloat16 def to(self, *args, **kwargs) -> "TaHCache": - """ - Move all cached tensors to the specified device and/or convert to specified dtype. - - Supports the same interface as PyTorch tensor.to(): - - to(device) - - to(dtype) - - to(device, dtype) - - to(device=..., dtype=...) - - Returns: - Self for method chaining - """ - # Parse arguments similar to PyTorch tensor.to() - device = kwargs.get('device', None) - dtype = kwargs.get('dtype', None) - - # Handle positional arguments + """Move all stored tensors. Mirrors ``torch.Tensor.to`` calling conventions.""" + device = kwargs.get("device") + dtype = kwargs.get("dtype") for arg in args: if isinstance(arg, (torch.device, str)): device = arg elif isinstance(arg, torch.dtype): dtype = arg - # Convert key and value caches (both device and dtype) - kv_cache_dicts = [self._tah_key_cache, self._tah_value_cache] - for cache_dict in kv_cache_dicts: - for layer_idx in cache_dict: - for iter_depth in cache_dict[layer_idx]: - if cache_dict[layer_idx][iter_depth] is not None: - cache_dict[layer_idx][iter_depth] = cache_dict[layer_idx][iter_depth].to( - *([device] if device is not None else []), - *([dtype] if dtype is not None else []) - ) - - # Convert position_id and valid_mask caches (device only, preserve dtype) - metadata_cache_dicts = [self._tah_position_id_cache, self._tah_valid_mask_cache] - for cache_dict in metadata_cache_dicts: - for layer_idx in cache_dict: - for iter_depth in cache_dict[layer_idx]: - if cache_dict[layer_idx][iter_depth] is not None and device is not None: - cache_dict[layer_idx][iter_depth] = cache_dict[layer_idx][iter_depth].to(device) - + # K/V get both device and dtype; position/valid stay long but follow device. + for store in (self._k, self._v): + for slot, t in store.items(): + store[slot] = t.to(*(([device] if device is not None else []) + ([dtype] if dtype is not None else []))) + if device is not None: + for store in (self._pos, self._valid): + for slot, t in store.items(): + store[slot] = t.to(device) return self diff --git a/tah/model/input_updater.py b/tah/model/input_updater.py deleted file mode 100755 index cba863b..0000000 --- a/tah/model/input_updater.py +++ /dev/null @@ -1,86 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, TYPE_CHECKING -import torch -import torch.nn as nn - -if TYPE_CHECKING: - from tah.model.recurrent_transformer import TaHForCausalLM - -from tah.model.registry import register_input_updater, get_input_updater_class, capture_init_args - - -class InputUpdater(nn.Module, ABC): - """ - Base class for updating input embeddings between iterations. - - This class is designed to efficiently handle tensors of arbitrary shape (..., x), - where the leading dimensions can be any combination of batch, sequence, or other - dimensions. All operations preserve the leading dimensions and only operate on - the last dimension for vocabulary/embedding operations. - """ - - @abstractmethod - def forward( - self, - logits: torch.Tensor, - prev_inputs: torch.Tensor, - embedding_weight: torch.Tensor, - hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Return updated inputs for the next iteration. - - This method efficiently handles tensors of arbitrary shape, preserving all - leading dimensions while operating only on the embedding dimension. - - Args: - logits: The logits from the token, shape (..., vocab_size) - prev_inputs: The previous inputs, shape (..., embed_dim) - embedding_weight: The embedding weight tensor, shape (vocab_size, embed_dim) - hidden_states: The hidden states, shape (..., hidden_dim) - - Returns: - The updated inputs, shape (..., embed_dim) - - Note: - All leading dimensions (...) are preserved exactly. The implementation - is optimized for efficient processing regardless of the number or size - of leading dimensions (e.g., batch size, sequence length, etc.). - """ - - -@register_input_updater -@capture_init_args -class TrivialUpdater(InputUpdater): - """ - Trivial update that directly returns logits-weighted embeddings. - - Efficiently handles tensors of arbitrary shape (..., vocab_size), preserving - all leading dimensions while computing weighted embeddings. - """ - - def __init__(self, use_hidden_states: bool = False, topk: Optional[int] = None): - super().__init__() - self.use_hidden_states = use_hidden_states - self.topk = topk - - def forward( - self, - logits: torch.Tensor, - prev_inputs: torch.Tensor, - embedding_weight: torch.Tensor, - hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # Direct matrix multiplication preserves all leading dimensions: (..., vocab_size) @ (vocab_size, embed_dim) -> (..., embed_dim) - if self.use_hidden_states: - return hidden_states[...,-1,:] # shape: seq_len, num_layer, embed_dim - else: - if self.topk is not None: - topk_values, topk_indices = torch.topk(logits, k=min(self.topk, logits.size(-1)), dim=-1) - topk_probs = torch.softmax(topk_values, dim=-1) - topk_embeddings = embedding_weight[topk_indices] - return torch.sum(topk_probs.unsqueeze(-1) * topk_embeddings, dim=-2) - else: - return torch.softmax(logits, dim=-1) @ embedding_weight - - diff --git a/tah/model/iter_decider.py b/tah/model/iter_decider.py old mode 100755 new mode 100644 index dd00894..c7da5a8 --- a/tah/model/iter_decider.py +++ b/tah/model/iter_decider.py @@ -1,76 +1,104 @@ +"""Iter deciders: per-token "should we iterate again?" classifiers. + +Two implementations are kept — both used in the canonical training/eval +recipes. Public TaH also shipped ``TrivialIterDecider``, +``AlwaysWrapperIterDecider``, and ``OracleDynamicIterDecider``; none of those +are used by the released checkpoints and have been removed. + +* :class:`IterLabelDecider` — step-1 SFT. Continues iff the dataset's oracle + ``iter_count_labels`` say so, ignoring model logits entirely. Used to teach + the LoRA adapter on tokens that the labeller marked "hard". +* :class:`MLPIterDecider` — step-2 SFT and eval/serving. Small MLP over a + selected subset of per-layer hidden states plus the projected top-k logits; + outputs a continue-probability per token. + +Each decider returns ``(decision_bool, decision_logits)``: + +* ``decision_bool`` — ``(N,)`` BoolTensor, True ⇒ continue iterating. +* ``decision_logits`` — ``(N,)`` float tensor of pre-sigmoid logits, used by + ``IterDeciderLoss`` for BCE supervision. May be a constant (e.g. + ``IterLabelDecider`` produces NEUTRAL_LOGITS) when the decider doesn't + itself produce a learnable score. + +Persistence: :func:`save_iter_decider` pickles ``(class_name, init_args, +state_dict)`` to ``iter_decider.bin``. :func:`load_iter_decider` restores via +``ITER_DECIDER_BY_NAME``. +""" +from __future__ import annotations + +import inspect import os +from typing import Optional, Tuple + import torch import torch.nn as nn -from typing import Callable, Dict, Optional, Any, Type, Union, Tuple, List - -from tah.model.registry import ( - register_iter_decider, - get_iter_decider_class, - capture_init_args, - mark_wrapper_iter_decider, -) - -POSITIVE_INFINITY_LOGITS = 10.0 -MINUS_INFINITY_LOGITS = -10.0 +# Per-class threshold semantics for IterLabelDecider's auxiliary logits. NEUTRAL_LOGITS = 0.0 +MINUS_INFINITY_LOGITS = -10.0 -class IterDecider(nn.Module): - """Base class for deciding whether to continue iterating a token. - All IterDecider implementations must efficiently handle inputs of arbitrary shape (..., vocab_size) - where (...) can be any number of leading dimensions (batch, sequence, etc.). - """ +class IterDecider(nn.Module): + """Base class. Subclasses must override :meth:`forward`.""" def __init__(self, threshold: float = 0.5, max_iter: int = 3): super().__init__() - # store as buffer to allow assignment on subclasses without property conflicts + # Buffer (not parameter) so subclasses can also override as a learnable + # nn.Parameter without registry conflicts. self.register_buffer("threshold", torch.tensor(float(threshold), dtype=torch.float32)) self.max_iter = max_iter + self._init_args: dict = {} - def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> torch.Tensor: - """ - Decide whether to continue iterating a token. - - Args: - logits: The logits of the token, shape (..., vocab_size) where (...) - represents arbitrary leading dimensions - iter_depth: The iteration depth of the token that has been processed. - Optional kwargs: - - hidden_states: The hidden states of the token, shape (..., hidden_size) where (...) - - Returns: - A float tensor of shape (...) with values between 0 and 1, - indicating the probability of continuing iteration. - The output preserves all leading dimensions from the input. - """ + def forward( + self, + logits: torch.Tensor, + iter_depth: int, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError -@register_iter_decider -@capture_init_args -class TrivialIterDecider(IterDecider): - """Trivial iteration decider that always ends. + def _stop_decision(self, logits: torch.Tensor, fill: float) -> Tuple[torch.Tensor, torch.Tensor]: + """Return ``(all-False, full-of-`fill`)`` shaped to match logits' leading dims. + + Used by subclasses when ``iter_depth >= max_iter`` (cap exhausted) or + when there's no useful input to decide on. + """ + shape = logits.shape[:-1] + decision = torch.zeros(shape, dtype=torch.bool, device=logits.device) + return decision, torch.full(shape, fill, dtype=logits.dtype, device=logits.device) + - Efficiently handles arbitrary input shapes (..., vocab_size) by returning - a boolean tensor of shape (...,) filled with False values. +def _capture_init_args(cls): + """Decorator: store positional + keyword __init__ args on ``self._init_args``. + + Used by :func:`save_iter_decider` so that loaders can re-instantiate the + class with the same constructor arguments before applying the state dict. """ + original_init = cls.__init__ + sig = inspect.signature(original_init) + param_names = [p for p in sig.parameters.keys() if p != "self"] - def __init__(self, max_iter: int = 1): - super().__init__(max_iter=max_iter) + def new_init(self, *args, **kwargs): + captured: dict = {} + for i, arg in enumerate(args): + if i < len(param_names): + captured[param_names[i]] = arg + captured.update(kwargs) + original_init(self, *args, **kwargs) + self._init_args = captured - def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> torch.Tensor: - decision = torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device) - logits_out = torch.full(decision.shape, NEUTRAL_LOGITS, dtype=logits.dtype, device=logits.device) - return decision, logits_out + cls.__init__ = new_init + return cls -@register_iter_decider -@capture_init_args -class IterLabelDecider(IterDecider): - """Iteration decider that strictly follows provided iter_count_labels. - Decision rule: continue if and only if (iter_count_labels > iter_depth) for valid tokens. - padding/ignored tokens (-100) will always stop. - """ +# ──────────────────────────────────────────────────────────────────────────── +# IterLabelDecider — oracle-supervised, used in step-1 SFT. +# ──────────────────────────────────────────────────────────────────────────── + + +@_capture_init_args +class IterLabelDecider(IterDecider): + """Continues iff ``iter_count_labels > iter_depth`` (and not ignored).""" def __init__(self, max_iter: int = 3): super().__init__(max_iter=max_iter) @@ -81,134 +109,104 @@ def forward( iter_depth: int, iter_count_labels: Optional[torch.Tensor] = None, **kwargs, - ) -> torch.Tensor: - if (iter_depth >= self.max_iter): - decision = torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device) - logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device) - return decision, logits_out - if (iter_count_labels is None): - decision = torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device) - logits_out = torch.full(decision.shape, NEUTRAL_LOGITS, dtype=logits.dtype, device=logits.device) - return decision, logits_out - - valid_mask = (iter_count_labels != -100) - decision_bool = (iter_count_labels > iter_depth) & valid_mask - decision = decision_bool - logits_out = torch.full(decision.shape, NEUTRAL_LOGITS, dtype=logits.dtype, device=logits.device) - return decision, logits_out - -class ClassifierBlock(nn.Module): - """ - A single transformer-style block for the classifier backbone. - Implements layer normalization, MLP with expansion, and residual connections. - """ + ) -> Tuple[torch.Tensor, torch.Tensor]: + if iter_depth >= self.max_iter: + return self._stop_decision(logits, fill=MINUS_INFINITY_LOGITS) + if iter_count_labels is None: + return self._stop_decision(logits, fill=NEUTRAL_LOGITS) + decision = (iter_count_labels > iter_depth) & (iter_count_labels != -100) + _, neutral = self._stop_decision(logits, fill=NEUTRAL_LOGITS) + return decision, neutral + + +# ──────────────────────────────────────────────────────────────────────────── +# MLPIterDecider — learned classifier, used in step-2 SFT and eval/serving. +# ──────────────────────────────────────────────────────────────────────────── + + +class _ClassifierBlock(nn.Module): + """LayerNorm + 2-layer MLP with residual; supports dim change.""" + def __init__( self, - input_dim, - output_dim, - expansion_factor=4, - dropout_rate=0.3 + input_dim: int, + output_dim: int, + expansion_factor: int, + dropout_rate: float, + dtype: Optional[torch.dtype] = None, ): super().__init__() - self.input_dim = input_dim - self.output_dim = output_dim - - self.layer_norm = nn.LayerNorm(input_dim) - + factory = {"dtype": dtype} if dtype is not None else {} + self.layer_norm = nn.LayerNorm(input_dim, **factory) self.mlp = nn.Sequential( - nn.Linear(input_dim, input_dim * expansion_factor), + nn.Linear(input_dim, input_dim * expansion_factor, **factory), nn.GELU(), nn.Dropout(dropout_rate), - nn.Linear(input_dim * expansion_factor, output_dim), + nn.Linear(input_dim * expansion_factor, output_dim, **factory), nn.Dropout(dropout_rate), ) - self.dim_change = ( - nn.Linear(input_dim, output_dim) - if input_dim != output_dim - else nn.Identity() + nn.Linear(input_dim, output_dim, **factory) if input_dim != output_dim else nn.Identity() ) - def forward(self, x): - normalized = self.layer_norm(x) - residual = self.dim_change(x) - return residual + self.mlp(normalized) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.dim_change(x) + self.mlp(self.layer_norm(x)) -class ClassifierBackbone(nn.Module): - """ - Backbone architecture for all classifiers. - Implements transformer-style MLP blocks with residual connections. - Position embeddings are disabled in this setup. - """ +class _ClassifierBackbone(nn.Module): + """Stacked _ClassifierBlocks with kaiming-init linears and a final norm + projection.""" def __init__( self, - input_dim, - output_dim=1, - hidden_dims=[256, 512, 256], - expansion_factor=4, - dropout_rate=0.3, - use_position_embedding=False, - max_position_embeddings=1024, + input_dim: int, + hidden_dims, + expansion_factor: int, + dropout_rate: float, + output_dim: int = 1, + dtype: Optional[torch.dtype] = None, ): super().__init__() - self.use_position_embedding = use_position_embedding - - self.blocks = nn.ModuleList() - - self.input_projection = nn.Linear(input_dim, hidden_dims[0]) - - block_dims = hidden_dims + [hidden_dims[-1]] - - for i in range(len(block_dims) - 1): - block_input_dim = block_dims[i] - block_output_dim = block_dims[i + 1] - self.blocks.append( - ClassifierBlock( - input_dim=block_input_dim, - output_dim=block_output_dim, - expansion_factor=expansion_factor, - dropout_rate=dropout_rate, - ) - ) - + factory = {"dtype": dtype} if dtype is not None else {} + self.input_projection = nn.Linear(input_dim, hidden_dims[0], **factory) + block_dims = list(hidden_dims) + [hidden_dims[-1]] + self.blocks = nn.ModuleList( + _ClassifierBlock(block_dims[i], block_dims[i + 1], expansion_factor, dropout_rate, **factory) + for i in range(len(block_dims) - 1) + ) self.output_layer = nn.Sequential( - nn.LayerNorm(hidden_dims[-1]), - nn.Linear(hidden_dims[-1], output_dim), + nn.LayerNorm(hidden_dims[-1], **factory), + nn.Linear(hidden_dims[-1], output_dim, **factory), ) + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) - self._init_weights() - - def _init_weights(self): - for module in self.modules(): - if isinstance(module, nn.Linear): - nn.init.kaiming_normal_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - def apply_position_embedding(self, x, position_ids=None): - # No position embeddings used in this setup - return x - - def forward(self, x, position_ids=None): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.input_projection(x) - x = self.apply_position_embedding(x, position_ids) for block in self.blocks: x = block(x) return self.output_layer(x) -@register_iter_decider -@capture_init_args + +@_capture_init_args class MLPIterDecider(IterDecider): - """Classifier-based iteration decider using hidden states and top-k logits.""" + """Decides by combining a slice of per-layer hidden states with top-k logits. + + Hidden features: pick layer indices ``hidden_states_layer_nums`` from the + base model's stacked hidden states (shape ``(..., L, H)``), concat into a + single ``(..., len(layer_nums) * H)`` vector. Logit features: top-k of the + per-token logits, projected up to the hidden size. Concat both, project to + backbone-input width, run the MLP backbone, sigmoid → continue probability. + """ def __init__( self, topk: int = 100, hidden_states_size: int = 1024, - hidden_states_layer_nums: list = [16,20,24,28], # explicit layer indices to use from all_hidden_states - hidden_dims: list = [256, 512, 256], + hidden_states_layer_nums = (16, 20, 24, 28), + hidden_dims = (256, 512, 256), expansion_factor: int = 4, dropout_rate: float = 0.3, normalize_input: bool = False, @@ -220,578 +218,137 @@ def __init__( self.topk = topk self.hidden_states_size = hidden_states_size self.hidden_states_layer_nums = list(hidden_states_layer_nums) - if hasattr(self.__class__, 'threshold'): - delattr(self, 'threshold') - self.threshold = nn.Parameter(torch.tensor(threshold, dtype=dtype, requires_grad=True)) + self.normalize_input = normalize_input self.max_iter = max_iter - self.normalize_input = normalize_input - if self.normalize_input: - num_selected = max(1, len(self.hidden_states_layer_nums)) - self.layer_norm_hidden_states = nn.LayerNorm(hidden_states_size * num_selected) + # threshold: replace the buffer with a learnable parameter. + if hasattr(self.__class__, "threshold"): + try: + delattr(self, "threshold") + except AttributeError: + pass + self.threshold = nn.Parameter(torch.tensor(threshold, dtype=dtype, requires_grad=True)) - # Project top-k logits to hidden state size - self.logits_projection = nn.Linear(self.topk, hidden_states_size, dtype=dtype) + n_layers = max(1, len(self.hidden_states_layer_nums)) + if normalize_input: + self.layer_norm_hidden_states = nn.LayerNorm(hidden_states_size * n_layers) - # Combine hidden states and projected logits - num_selected = max(1, len(self.hidden_states_layer_nums)) - combined_size = hidden_states_size * num_selected + hidden_states_size - self.combined_projection = nn.Linear(combined_size, hidden_dims[0], dtype=dtype) - - # Backbone MLP stack - self.backbone = ClassifierBackbone( - input_dim=hidden_dims[0], - output_dim=1, - hidden_dims=hidden_dims, - expansion_factor=expansion_factor, - dropout_rate=dropout_rate, + self.logits_projection = nn.Linear(self.topk, hidden_states_size, dtype=dtype) + self.combined_projection = nn.Linear(hidden_states_size * n_layers + hidden_states_size, hidden_dims[0], dtype=dtype) + self.backbone = _ClassifierBackbone( + hidden_dims[0], hidden_dims, expansion_factor, dropout_rate, dtype=dtype, ) - self.sigmoid = nn.Sigmoid() - def forward(self, logits: torch.Tensor, iter_depth: int, all_hidden_states: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: - if iter_depth >= self.max_iter: - decision = torch.zeros( - logits.shape[:-1], dtype=torch.bool, device=logits.device - ) - logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device) - return decision, logits_out - - original_shape = logits.shape[:-1] + def _select_hidden_features( + self, + all_hidden_states: Optional[torch.Tensor], + decision_shape: Tuple[int, ...], + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """Pick + flatten the configured hidden-state layers (or zero-fill if missing). - # Build hidden features from requested layers without padding - num_selected = max(1, len(self.hidden_states_layer_nums)) + ``all_hidden_states`` may be ``(..., L, H)`` or ``(..., H)`` (single + layer); we treat the latter as ``L=1``. Returned shape is + ``(*decision_shape, n_layers * H)``. + """ + n_layers = max(1, len(self.hidden_states_layer_nums)) + flat_dim = self.hidden_states_size * n_layers if all_hidden_states is None: - # Fallback to zeros if hidden states are unavailable - hidden_concat = torch.zeros(*original_shape, self.hidden_states_size * num_selected, device=logits.device, dtype=logits.dtype) - else: - hs = all_hidden_states - # Expect shape (..., L, H); if (..., H) provided, treat as single-layer - if hs.dim() == logits.dim(): - hs = hs.unsqueeze(-2) - total_layers = hs.size(-2) - if num_selected == 1 and len(self.hidden_states_layer_nums) == 0: - indices = [total_layers - 1] - else: - indices = self.hidden_states_layer_nums - index_tensor = torch.as_tensor(indices, device=hs.device, dtype=torch.long) - if index_tensor.numel() == 0: - raise ValueError("hidden_states_layer_nums must not be empty") - if torch.min(index_tensor).item() < 0 or torch.max(index_tensor).item() >= total_layers: - raise ValueError(f"hidden_states_layer_nums out of range: {indices}, total_layers={total_layers}") - selected = torch.index_select(hs, dim=-2, index=index_tensor) # (..., num_selected, H) - hidden_concat = selected.reshape(*original_shape, selected.size(-2) * self.hidden_states_size) - - # Mirror PluginNeuralIterDecider behavior: apply top-k on logits - k = min(self.topk, logits.size(-1)) - topk_values, _ = torch.topk(logits, k=k, dim=-1) - - # Optional normalization + return torch.zeros(*decision_shape, flat_dim, device=device, dtype=dtype) + + hs = all_hidden_states + if hs.dim() == len(decision_shape) + 1: # (..., H) — promote to (..., 1, H) + hs = hs.unsqueeze(-2) + total_layers = hs.size(-2) + indices = self.hidden_states_layer_nums or [total_layers - 1] + idx = torch.as_tensor(indices, device=hs.device, dtype=torch.long) + if idx.numel() == 0 or int(idx.min()) < 0 or int(idx.max()) >= total_layers: + raise ValueError(f"hidden_states_layer_nums {indices} out of range for {total_layers} layers") + return torch.index_select(hs, dim=-2, index=idx).reshape(*decision_shape, flat_dim) + + def forward( + self, + logits: torch.Tensor, + iter_depth: int, + all_hidden_states: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if iter_depth >= self.max_iter: + return self._stop_decision(logits, fill=MINUS_INFINITY_LOGITS) + + decision_shape = logits.shape[:-1] + hidden_concat = self._select_hidden_features( + all_hidden_states, decision_shape, logits.device, logits.dtype, + ) + topk_values, _ = torch.topk(logits, k=min(self.topk, logits.size(-1)), dim=-1) if self.normalize_input: hidden_concat = self.layer_norm_hidden_states(hidden_concat) topk_values = torch.softmax(topk_values, dim=-1) - # Project logits and combine - logits_features = self.logits_projection(topk_values) - combined_features = torch.cat([hidden_concat, logits_features], dim=-1) - x = self.combined_projection(combined_features) - + x = self.combined_projection( + torch.cat([hidden_concat, self.logits_projection(topk_values)], dim=-1), + ) decision_logits = self.backbone(x) if decision_logits.dim() == logits.dim(): decision_logits = decision_logits.squeeze(-1) - decision_scores = self.sigmoid(decision_logits) - thr = self.threshold - if isinstance(thr, torch.Tensor): - thr = float(thr.detach().item()) - decision_mask = (decision_scores > thr) + threshold = self.threshold + if isinstance(threshold, torch.Tensor): + threshold = float(threshold.detach().item()) + decision_mask = self.sigmoid(decision_logits) > float(threshold) return decision_mask, decision_logits -@register_iter_decider -@capture_init_args -class AlwaysWrapperIterDecider(IterDecider): - """Wrapper that enforces a simple control-flow policy around a base iter decider. +# ──────────────────────────────────────────────────────────────────────────── +# Persistence + name-based dispatch. +# ──────────────────────────────────────────────────────────────────────────── - Modes: - - "continue": force continuing until the final allowed iteration (previous behavior) - - "stop": stop after the first iteration - Finishing rule via threshold (used by TaH's finished_mask = (prob <= threshold)): - - continue: threshold = -1.0 until last iteration, then 1.0 to finish all - - stop: threshold = 1.0 at the first iteration so all tokens finish immediately - """ +ITER_DECIDER_BY_NAME = { + "IterLabelDecider": IterLabelDecider, + "MLPIterDecider": MLPIterDecider, +} - def __init__( - self, - max_iter: int = 3, - base_iter_decider_cls: str = "MLPIterDecider", - base_iter_decider_kwargs: Optional[dict] = None, - mode: str = "continue", - ): - super().__init__(max_iter=max_iter) - if not isinstance(base_iter_decider_cls, str): - raise ValueError("AlwaysWrapperIterDecider expects base_iter_decider_cls as a string class name") - mode = str(mode).lower().strip() - if mode not in ("continue", "stop"): - raise ValueError("AlwaysWrapperIterDecider mode must be either 'continue' or 'stop'") - self.mode = mode - - base_cls = get_iter_decider_class(base_iter_decider_cls) - base_iter_decider_kwargs = dict(base_iter_decider_kwargs or {}) - base_iter_decider_kwargs.setdefault("max_iter", max_iter) - self.base_iter_decider = base_cls(**base_iter_decider_kwargs) - self._last_forward_iter_depth: Optional[int] = None - - def update_training_state(self, current_step: int, current_epoch: int): - if getattr(self, 'base_iter_decider', None) is not None and hasattr(self.base_iter_decider, 'update_training_state') and callable(self.base_iter_decider.update_training_state): - try: - self.base_iter_decider.update_training_state(current_step=current_step, current_epoch=current_epoch) - except Exception: - pass - def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> torch.Tensor: - # Respect cap for shape consistency - if iter_depth >= self.max_iter: - decision = torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device) - logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device) - return decision, logits_out - - self._last_forward_iter_depth = int(iter_depth) - - # Delegate to base decider to obtain its logits if any - base_output = self.base_iter_decider(logits, iter_depth, **kwargs) - if (not isinstance(base_output, tuple)) or (len(base_output) < 2): - raise TypeError("Base iter decider must return a (decision_bool, logits) tuple") - base_decision, base_logits = base_output[0], base_output[1] - - if self.mode == "continue": - decision = torch.ones_like(base_decision, dtype=torch.bool, device=base_decision.device) - else: # stop after first iteration - decision = torch.zeros_like(base_decision, dtype=torch.bool, device=base_decision.device) - return decision, (base_logits.to(dtype=logits.dtype) if base_logits is not None else torch.full_like(decision, NEUTRAL_LOGITS, dtype=logits.dtype)) - - @property - def threshold(self) -> float: - # If last forward depth is unknown, fall back to base threshold or 0.5 - if self._last_forward_iter_depth is None: - thr = getattr(self.base_iter_decider, 'threshold', None) - try: - if isinstance(thr, torch.Tensor): - return float(thr.detach().item()) - except Exception: - pass - return float(thr) if thr is not None else 0.5 - - if self.mode == "continue": - # Continue until final iteration - if self._last_forward_iter_depth < self.max_iter: - return -1.0 - else: - return 1.0 - else: - # Stop mode: finish immediately at the first iteration - return 1.0 - - -@register_iter_decider -@capture_init_args -class OracleDynamicIterDecider(IterDecider): - """ - Use LLM forward in the first iter to get the oracle token. - In subsequent iters, compare the predicted token with the oracle token. - """ +def get_iter_decider_class(name: str): + """Lookup helper kept for backwards-compatible callers.""" + if name not in ITER_DECIDER_BY_NAME: + raise ValueError(f"Unknown iter_decider {name!r}; have {sorted(ITER_DECIDER_BY_NAME)}") + return ITER_DECIDER_BY_NAME[name] - def __init__( - self, - max_iter: int = 3, - ref_model_path: Optional[str] = None, - dtype: Union[torch.dtype, str] = "auto", - device: Optional[str] = None, - use_kv_cache: bool = True, - backend: str = 'sglang', - false_positive_rate: float = 0, - false_negative_rate: float = 0 - ): - ''' - Note: the ref_model is default to hf model - Args: - - max_iter (int): The maximum number of iterations to perform. - - ref_model_path (str, optional): The path to the reference model. - - dtype (torch.dtype or str, optional): The data type to use for the model. - - device (str, optional): The device to run the model on. - - use_kv_cache (bool, optional): Whether to use key-value caching. - - backend (str, optional): The backend to use for the model. options: ['sglang', 'hf'] - ''' - super().__init__() - self.max_iter = max_iter - self.use_kv_cache = use_kv_cache - self._dtype = dtype - self._device = device - self.backend = backend - - self.ref_model = None - self._ref_model_path = ref_model_path - print('oracle iterdecider ref_model path:', self._ref_model_path) - self._ref_past = None - # Cache of oracle tokens from reference model prefill - # For HF backend prefill at iter_depth==1, we cache per-position greedy tokens - # shape: (batch, seq_len) - self._cached_tokens_full = None - self._last_step_depth = None - - self.false_positive_rate = false_positive_rate - self.false_negative_rate = false_negative_rate - - - @torch.no_grad() - def _ensure_ref(self, ref_model): - if ref_model is not None: - return ref_model - if self.ref_model is not None: - return self.ref_model - if self._ref_model_path is None: - raise RuntimeError( - "OracleDynamicIterDecider needs a reference model: please provide ref_model_path in init " - "or pass ref_model in forward()." - ) - if self.backend == 'hf': - torch_dtype = None - if isinstance(self._dtype, str) and self._dtype != "auto": - torch_dtype = getattr(torch, self._dtype) - elif isinstance(self._dtype, torch.dtype): - torch_dtype = self._dtype - from transformers import AutoModelForCausalLM - self.ref_model = AutoModelForCausalLM.from_pretrained( - self._ref_model_path, - torch_dtype=torch_dtype, - device_map="auto" if self._device is None else None, - attn_implementation='sdpa', - trust_remote_code=True, - low_cpu_mem_usage=True, - ) - if self._device is not None: - self.ref_model.to(self._device) - self.ref_model.eval() - return self.ref_model - elif self.backend == 'sglang': - from transformers import AutoTokenizer, AutoConfig - import socket, os - import multiprocessing as mp - - class _SGLGreedyOracle: - """ - - """ - def __init__(self, model_path: str, dtype="auto"): - print('Initializing sgl engine') - - self.tok = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=True, use_fast=True - ) - self.pad_id = self.tok.pad_token_id - self.vocab_size = self.tok.vocab_size - - def _to_sgl_dtype(dt): - if dt in (None, "auto"): return "auto" - if isinstance(dt, torch.dtype): - m = {torch.float16: "float16", torch.bfloat16: "bfloat16", torch.float32: "float32"} - return m.get(dt, "auto") - if isinstance(dt, str): return dt - return "auto" - - def _get_idx(): - try: - name = mp.current_process().name - return int(name.split("-")[-1]) - except Exception: - return 0 - - def _next_free_port(start): - p = start - while True: - with socket.socket() as s: - try: - s.bind(("", p)) - return p - except OSError: - p += 1 - - idx = _get_idx() - - BASE = 32000 + idx * 20 - http_port = _next_free_port(BASE + 0) - rdzv_port = _next_free_port(BASE + 1) - nccl_port = _next_free_port(BASE + 2) - - print(f'[Job {idx}] SGLang Engine ports: HTTP {http_port}, RDZV {rdzv_port}, NCCL {nccl_port}') - - os.environ["SGLANG_PORT"] = str(http_port) - os.environ["NCCL_IB_DISABLE"] = "1" - - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(rdzv_port) - os.environ["SGLANG_PORT"] = str(http_port) - os.environ["NCCL_IB_DISABLE"] = "1" - - import sglang as sgl - self.eng = sgl.Engine( - model_path=model_path, - tokenizer_path=model_path, - trust_remote_code=True, - dtype=_to_sgl_dtype(dtype), - tp_size=2, - dp_size=1, - enable_dp_attention=False, - host='127.0.0.1', - port=http_port, # HTTP - dist_init_addr=f"127.0.0.1:{rdzv_port}", # **For TCPStore port** - ) - self._running_ids = None # list[list[int]] - - - def _trim_left_pad(self, ids: list[int]) -> list[int]: - if self.pad_id is None or not ids: return ids - start = 0 - while start < len(ids) and ids[start] == self.pad_id: - start += 1 - return ids[start:] - - def _tokenize_one(self, text: str) -> List[int]: - if not text: - return [] - return self.tok(text, add_special_tokens=False, return_attention_mask=False).input_ids - - def greedy_next_token_ids(self, input_ids: torch.Tensor, fresh: bool) -> torch.Tensor: - assert isinstance(input_ids, torch.Tensor) and input_ids.dim() == 2, \ - "input_ids must be a 2D LongTensor (B, T)" - device = input_ids.device - ids_list = input_ids.tolist() - B = len(ids_list) - - if fresh or self._running_ids is None: - # new sequence - ids_list = [self._trim_left_pad(x) for x in ids_list] - self._running_ids = [list(x) for x in ids_list] - else: - # continue - if len(self._running_ids) != B: - raise RuntimeError( - "fresh=False but batch size changed, set fresh=True when starting a new batch." - ) - for i in range(B): - self._running_ids[i].extend(ids_list[i]) - - # allow - outs = self.eng.generate( - input_ids=self._running_ids, - sampling_params={ - "max_new_tokens": 1, - "temperature": 0.0, - "top_p": 1.0, - "ignore_eos": False, # allow - }, - return_logprob=False, - ) - - next_ids: List[int] = [] - for i, o in enumerate(outs): - nid: Optional[int] = None - - # for different version - for k in ("token_ids", "output_ids"): - val = o.get(k, None) - if isinstance(val, list) and len(val) >= 1: - nid = int(val[0]) - break - - if nid is None: - cont = o.get("text", "") - new_ids = self._tokenize_one(cont) - if new_ids: - nid = int(new_ids[0]) - else: - # if model did not return new_id, return eos - nid = int(self.tok.eos_token_id) if self.tok.eos_token_id is not None else 0 - - next_ids.append(nid) - # self._running_ids[i].append(nid) - result = torch.tensor(next_ids, dtype=torch.long, device=device).detach() - del outs, ids_list, next_ids - return result - - config = AutoConfig.from_pretrained(self._ref_model_path, trust_remote_code=True) - self.ref_model = _SGLGreedyOracle(self._ref_model_path, dtype=self._dtype) - self.ref_model.config = config - return self.ref_model - else: - raise RuntimeError("Unsupported backend: {}".format(self.backend)) - - @torch.no_grad() - def forward( - self, - logits: torch.Tensor, - iter_depth: int, - active_valid_mask: torch.Tensor, - prediction_logits: Optional[torch.Tensor] = None, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ref_model: Optional[nn.Module] = None, - fresh: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - ''' - Args: - - logits (torch.Tensor): The logits tensor from the model. - - iter_depth (int): The current iteration depth. - - active_valid_mask (torch.Tensor): shape (batch, max_active_len) - - prediction_logits (torch.Tensor): next prediction logits from base model (batch, vocab) - - input_ids (torch.Tensor, optional): (batch, seq) - - ref_model (nn.Module, optional): The reference model to use. - - fresh (torch.Tensor, optional): If True, reset the reference model state. - - **kwargs: Additional keyword arguments. - ''' - if iter_depth >= self.max_iter: - return torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device), None - - # iter depth examination - if self._last_step_depth is None and iter_depth > 1: - raise RuntimeError("Cannot start OracleDynamicIterDecider with iter_step > 1") - if self._last_step_depth is not None and iter_depth <= self._last_step_depth and iter_depth > 1: - print("[OracleDynamicIterDecider] Warning: iter_depth not increasing, make sure it restarts from 1 per sequence.") - self._last_step_depth = iter_depth - - # Note: For per-position decisions we prefer using the provided per-token logits (logits arg) - # which corresponds to active positions flattened by active_valid_mask. We'll reconstruct - # a (B, T) view later. prediction_logits (B, V) may reflect only the last position and is - # insufficient for full prefill decisions. - - if iter_depth == 1: - ref = self._ensure_ref(ref_model) - if input_ids is None: - raise RuntimeError("input_ids is needed when iter_depth == 1") - - # assert prediction_logits.size(-1) == ref.config.vocab_size, \ - # f"Base model vocab != ref model vocab(base {prediction_logits.size(-1)}, ref {ref.config.vocab_size}). Use same tokenizer/vocab for oracle comparison." - - batch_size, query_len = input_ids.shape - # check if new sequence - if fresh: - self._ref_past = None - self._cached_tokens_full = None - - if self.backend == 'hf': - # Prepare attention mask aligned to current query length - attention_mask = attention_mask[:, -query_len:] if attention_mask is not None else None - outputs = ref( - input_ids=input_ids, - attention_mask=attention_mask, - use_cache=True, - past_key_values=self._ref_past, - ) - - if self.use_kv_cache: - self._ref_past = outputs.past_key_values - - # Cache per-position oracle tokens for the whole sequence (B, T) - # We only keep argmax tokens to avoid storing full logits - self._cached_tokens_full = outputs.logits.argmax(dim=-1) - # Free large logits tensor ASAP - del outputs - elif self.backend == 'sglang': - # For sglang backend we only support next-token oracle for decode use-case. - # To maintain compatibility, fill the last position using sglang's next token - next_ids = ref.greedy_next_token_ids(input_ids, fresh=fresh) - # Build a per-position cache with only the last position populated - self._cached_tokens_full = torch.zeros_like(input_ids) - self._cached_tokens_full[:, -1] = next_ids - - # Ensure oracle tokens are available - if self._cached_tokens_full is None: - raise RuntimeError("OracleDynamicIterDecider must have per-position oracle tokens cached at iter 1.") - - # Compute base predictions for all active positions from current iteration logits - # logits shape: (num_valid_positions, vocab) - base_pred_flat = logits.argmax(dim=-1) - - # Reconstruct per-position matrix (B, T) for base predictions - # Fill only where active_valid_mask == 1 - base_pred_full = torch.zeros_like(input_ids) - base_pred_full[active_valid_mask == 1] = base_pred_flat - - # Compute per-position continue decisions: continue if base iter1 != oracle - continue_full = (base_pred_full != self._cached_tokens_full).to(torch.bool) - # Only consider valid positions - continue_full = continue_full & active_valid_mask.bool() - - # Apply noise per-position - if self.false_negative_rate > 0 or self.false_positive_rate > 0: - rand_fn = torch.rand_like(continue_full.float()) - rand_fp = torch.rand_like(continue_full.float()) - continue_full = (continue_full | (rand_fn < self.false_negative_rate)) & (rand_fp >= self.false_positive_rate) - - # Flatten back to match expected return shape (num_valid_positions,) - continue_mask = continue_full[active_valid_mask == 1].bool().to(logits.device) - - return continue_mask, None - - -def save_iter_decider(iter_decider: IterDecider, save_directory: str): - """Save iter_decider state dict and configuration.""" - # Use captured initialization arguments from the decorator - init_args = getattr(iter_decider, "_init_args", {}) - - # Use natural state_dict - no overrides needed - state_dict = iter_decider.state_dict() - state_dict = {k: v.cpu() for k, v in state_dict.items()} - data = { - "class": iter_decider.__class__.__name__, - "state_dict": state_dict, - "init_args": init_args, - } - save_path = os.path.join(save_directory, "iter_decider.bin") - print(f"Saving iter_decider with {len(state_dict)} parameters to {save_path}") - torch.save(data, save_path) +def save_iter_decider(decider: IterDecider, save_directory: str) -> None: + """Pickle ``(class_name, init_args, state_dict)`` to ``iter_decider.bin``.""" + state = {k: v.detach().cpu() for k, v in decider.state_dict().items()} + payload = { + "class": decider.__class__.__name__, + "state_dict": state, + "init_args": getattr(decider, "_init_args", {}), + } + torch.save(payload, os.path.join(save_directory, "iter_decider.bin")) -def load_iter_decider(load_directory: str, class_name: Optional[str] = None, init_args: Optional[dict] = None) -> IterDecider: - """Load iter_decider from directory.""" +def load_iter_decider( + load_directory: str, + class_name: Optional[str] = None, + init_args: Optional[dict] = None, +) -> IterDecider: path = os.path.join(load_directory, "iter_decider.bin") - if not os.path.isfile(path): - raise FileNotFoundError(f"No iter_decider found at {path}") - + raise FileNotFoundError(f"no iter_decider.bin at {path}") data = torch.load(path, map_location="cpu", weights_only=False) - if class_name is None: - class_name = data.get("class") - - if not class_name: - raise ValueError("No iter_decider class specified in saved data") - - # Get constructor arguments if available - if init_args is None: - init_args = data.get("init_args", {}) - - # Create iter_decider instance using registry with proper arguments - iter_decider_class = get_iter_decider_class(class_name) - iter_decider = iter_decider_class(**init_args) - - # Load state dict if available - natural loading - state_dict = data.get("state_dict", {}) - if state_dict: - # Filter out state_dict keys that conflict with init_args - filtered_state_dict = {} - for key, value in state_dict.items(): - if key not in init_args: - filtered_state_dict[key] = value - else: - print(f"Skipping state_dict key '{key}' as it conflicts with init_args") - print(f"Loading iter_decider state dict with {len(filtered_state_dict)} parameters (filtered from {len(state_dict)})") - if filtered_state_dict: - # print(filtered_state_dict.values()) - iter_decider.load_state_dict(filtered_state_dict, strict=False) - - return iter_decider - - + cls_name = class_name or data.get("class") + if not cls_name: + raise ValueError("iter_decider.bin lacks a class name") + args = init_args if init_args is not None else data.get("init_args", {}) + decider = get_iter_decider_class(cls_name)(**args) + sd = data.get("state_dict", {}) + if sd: + # Drop any state-dict keys that collide with init args (e.g. learnable + # threshold whose value is also passed in init_args). + sd = {k: v for k, v in sd.items() if k not in args} + if sd: + decider.load_state_dict(sd, strict=False) + return decider diff --git a/tah/model/iter_label.py b/tah/model/iter_label.py deleted file mode 100644 index 789c2ea..0000000 --- a/tah/model/iter_label.py +++ /dev/null @@ -1,218 +0,0 @@ -import torch -import torch.nn as nn -from typing import Optional - -from tah.model.registry import ( - register_iter_label_generator, - get_iter_label_generator_class, -) - - -class IterLabelGenerator(nn.Module): - """Base class for generating per-token iter-count labels. - - Contract: - - prepare(batch_size, seq_len, device, dtype): allocate internal buffers - - intra_iter_labels(...): return labels for current active tokens, and update internal full labels - - finalize(): return full (B, S) labels accumulated across iterations - """ - - def __init__(self, **kwargs): - super().__init__() - self.config = kwargs - self.full_labels = None - - def prepare(self, batch_size: int, seq_len: int, device: torch.device, dtype: torch.dtype): - self.full_labels = torch.full( - (batch_size, seq_len), fill_value=0, device=device, dtype=torch.long - ) - - @staticmethod - def _assign_active(current_iter_mask: torch.BoolTensor, src: torch.Tensor, dest: torch.Tensor) -> torch.Tensor: - """Scatter active `src` back to dense `dest` (no padding handling beyond mask).""" - B, S = current_iter_mask.shape - active_counts = current_iter_mask.sum(1) - for b in range(B): - n = int(active_counts[b].item()) - if n: - dest[b, current_iter_mask[b]] = src[b, :n] - return dest - - def intra_iter_labels( - self, - active_logits: torch.Tensor, - active_labels_shifted: Optional[torch.Tensor], - iter_depth: int, - current_iter_mask: torch.BoolTensor, - active_valid_mask: torch.LongTensor, - prompt_mask: Optional[torch.Tensor] = None, - ignore_index: int = -100, - **kwargs, - ) -> Optional[torch.LongTensor]: - raise NotImplementedError - - def finalize(self) -> Optional[torch.LongTensor]: - return self.full_labels - - -@register_iter_label_generator -class FixedIterLabelGenerator(IterLabelGenerator): - """Pass-through labels coming from dataset via `iter_count_labels`. - - The model will supply the active slice, this generator just maps prompt to ignore. - """ - - def __init__(self, ignore_index: int = -100, **kwargs): - super().__init__(**kwargs) - self.ignore_index = ignore_index - - def intra_iter_labels( - self, - active_iter_count_labels: torch.LongTensor, - current_iter_mask: torch.BoolTensor, - **kwargs, - ) -> Optional[torch.LongTensor]: - # Ensure long dtype for labels - active_iter_count_labels = active_iter_count_labels.to(dtype=torch.long) - - # Update full labels with the latest observed labels for active positions - if self.full_labels is not None and (self.full_labels.shape == (current_iter_mask.shape[0], current_iter_mask.shape[1])): - proposal = torch.zeros_like(active_iter_count_labels) - valid = (active_iter_count_labels != self.ignore_index) - proposal[valid] = active_iter_count_labels[valid] - current = self.full_labels.clone() - tmp = torch.zeros_like(self.full_labels) - tmp = self._assign_active(current_iter_mask, proposal, tmp) - self.full_labels = torch.maximum(current, tmp) - - return active_iter_count_labels - - -@register_iter_label_generator -class DynamicMismatchIterLabelGenerator(IterLabelGenerator): - """Generate per-iteration pseudo count labels based on mismatch. - - Rule at depth d (1-indexed): - - If mismatch → label = d + 1 - - Else → label = d - So that (label > d) is the desired continue target. - """ - - def __init__(self, max_iter: int = 3, **kwargs): - super().__init__(**kwargs) - self.max_iter = max_iter - - @staticmethod - def _compute_mismatch_continue(logits: torch.Tensor, labels_shifted: torch.Tensor, ignore_index: int) -> torch.BoolTensor: - # Handle causal LM shift: logits[i] predicts labels[i+1] - # We need to compare logits[:-1] with labels[1:] for proper alignment - if logits.dim() >= 2 and logits.shape[-2] > 1 and labels_shifted.shape[-1] > 1: - shifted_logits = logits[..., :-1, :] - shifted_labels = labels_shifted[..., :-1] - predicted = torch.argmax(shifted_logits, dim=-1) - mismatch = (predicted != shifted_labels) - cont = torch.cat([mismatch, torch.zeros_like(mismatch[..., :1])], dim=-1) - else: - predicted = torch.argmax(logits, dim=-1) - mismatch = (predicted != labels_shifted) - cont = mismatch - # Exclude ignore positions - valid = (labels_shifted != ignore_index) - return (cont & valid) - - def intra_iter_labels( - self, - active_logits: torch.Tensor, - active_labels_shifted: Optional[torch.Tensor], - iter_depth: int, - current_iter_mask: torch.BoolTensor, - active_valid_mask: torch.LongTensor, - prompt_mask: Optional[torch.Tensor] = None, - ignore_index: int = -100, - **kwargs, - ) -> Optional[torch.LongTensor]: - if active_labels_shifted is None or active_logits is None: - return None - - # Compute mismatch-based continue mask on the active slice - continue_mask = self._compute_mismatch_continue(active_logits, active_labels_shifted, ignore_index) - - # Only supervise valid active tokens - valid_active = (active_valid_mask == 1) - - # Build count labels for active slice - depth_tensor = torch.full_like(active_logits[..., 0], fill_value=iter_depth, dtype=torch.long) - labels_active = torch.where(continue_mask, depth_tensor + 1, depth_tensor) - # Ensure labels do not exceed max_iter - labels_active = torch.clamp(labels_active, max=self.max_iter) - - # Also ignore positions that are not valid active tokens - labels_active = labels_active.masked_fill(~valid_active.bool(), ignore_index) - labels_active = labels_active.to(dtype=torch.long) - - # Accumulate full labels: take max across depths to ensure monotonicity - if self.full_labels is not None and (self.full_labels.shape == (current_iter_mask.shape[0], current_iter_mask.shape[1])): - proposal = labels_active.clone() - proposal = proposal.masked_fill(proposal == ignore_index, 0) - current = self.full_labels.clone() - tmp = torch.zeros_like(self.full_labels) - tmp = self._assign_active(current_iter_mask, proposal, tmp) - self.full_labels = torch.maximum(current, tmp) - - return labels_active - - - - -@register_iter_label_generator -class MaxIterLabelGenerator(IterLabelGenerator): - """Always assign `max_iter` as the label for active tokens. - - Invalid or non-active positions are masked with `ignore_index`. - """ - - def __init__(self, max_iter: int = 3, **kwargs): - super().__init__(**kwargs) - self.max_iter = max_iter - - def intra_iter_labels( - self, - active_logits: Optional[torch.Tensor], - active_labels_shifted: Optional[torch.Tensor], - iter_depth: int, - current_iter_mask: torch.BoolTensor, - active_valid_mask: torch.LongTensor, - prompt_mask: Optional[torch.Tensor] = None, - ignore_index: int = -100, - **kwargs, - ) -> Optional[torch.LongTensor]: - # Determine the active slice shape and device - base_tensor: Optional[torch.Tensor] = None - if active_logits is not None: - base_tensor = active_logits[..., 0] - elif active_labels_shifted is not None: - base_tensor = active_labels_shifted - else: - base_tensor = active_valid_mask - - labels_active = torch.full( - base_tensor.shape, - fill_value=self.max_iter, - device=base_tensor.device, - dtype=torch.long, - ) - - # Only supervise valid active tokens - valid_active = (active_valid_mask == 1) - labels_active = labels_active.masked_fill(~valid_active.bool(), ignore_index) - - # Accumulate full labels for the dense (B, S) view - if self.full_labels is not None and (self.full_labels.shape == (current_iter_mask.shape[0], current_iter_mask.shape[1])): - proposal = labels_active.clone() - proposal = proposal.masked_fill(proposal == ignore_index, 0) - current = self.full_labels.clone() - tmp = torch.zeros_like(self.full_labels) - tmp = self._assign_active(current_iter_mask, proposal, tmp) - self.full_labels = torch.maximum(current, tmp) - - return labels_active diff --git a/tah/model/loss.py b/tah/model/loss.py index 050016f..8187303 100644 --- a/tah/model/loss.py +++ b/tah/model/loss.py @@ -1,11 +1,35 @@ -import torch -import torch.nn.functional as F -from typing import Dict, Any, Optional +"""Loss functions used by ``TaHForCausalLM``. + +Two implementations are kept — both used in the canonical recipes. Public TaH +also shipped ``ConsistencyLoss``, which is unused and removed. + +* :class:`NextTokenPredLoss` — standard causal-LM cross-entropy. Always + applied at the end of the iteration loop on the accumulated logits. Step-1 + SFT and eval both use this. + +* :class:`IterDeciderLoss` — per-iteration BCE supervised by ``iter_count_labels``. + Step-2 SFT uses this on top of the iter decider's continue logits to teach + it which tokens deserve another pass. + +Both expose the same three-method protocol expected by ``TaHForCausalLM``: + +* ``prepare_loss(B, T, device, dtype)`` — allocate per-token accumulators. +* ``intra_iter_loss_func(...)`` — only on intra-iter losses; called once per iteration. +* ``final_loss_func(...)`` — called once at the end of the forward. -from tah.model.registry import register_loss_func, capture_init_args, get_loss_func_class -from tah.train import weighted_cross_entropy, fixed_cross_entropy +The class-level ``_is_intra_iter_loss: bool`` lets the wrapper decide whether +to invoke ``intra_iter_loss_func``. Lookup by name is via ``LOSS_BY_NAME``. +""" +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn from transformers.utils import logging +from tah.train import fixed_cross_entropy + logger = logging.get_logger(__name__) @@ -15,38 +39,28 @@ class LossFunc: def __init__(self, **kwargs): self.config = kwargs - def prepare_loss(self, batch_size, query_len, device, dtype, **kwargs): + def prepare_loss(self, batch_size: int, query_len: int, device, dtype, **kwargs) -> None: pass def intra_iter_loss_func(self, *args, **kwargs): - raise NotImplementedError( - "This loss function does not support intra-iteration loss calculation." - ) + raise NotImplementedError def final_loss_func(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError -@register_loss_func -@capture_init_args +# ──────────────────────────────────────────────────────────────────────────── +# NextTokenPredLoss — used by step-1 SFT and eval everywhere. +# ──────────────────────────────────────────────────────────────────────────── + + class NextTokenPredLoss(LossFunc): + """Standard causal-LM cross-entropy on the final accumulated logits.""" + _is_intra_iter_loss: bool = False - def __init__(self, hard_token_relative_weight: float = 1.0, weight_hard: float = None, weight_easy: float = None, **kwargs): + def __init__(self, **kwargs): super().__init__() - self.hard_token_relative_weight = hard_token_relative_weight - self.weight_hard = weight_hard - self.weight_easy = weight_easy - - def prepare_loss(self, batch_size, query_len, device, dtype, **kwargs): - # Update weights if provided in kwargs (allows dynamic weight setting) - super().prepare_loss(batch_size, query_len, device, dtype, **kwargs) - if 'weight_hard' in kwargs and kwargs['weight_hard'] is not None: - self.weight_hard = kwargs['weight_hard'] - if 'weight_easy' in kwargs and kwargs['weight_easy'] is not None: - self.weight_easy = kwargs['weight_easy'] - if 'hard_token_relative_weight' in kwargs and kwargs['hard_token_relative_weight'] is not None: - self.hard_token_relative_weight = kwargs['hard_token_relative_weight'] def final_loss_func( self, @@ -56,310 +70,138 @@ def final_loss_func( training: bool, **kwargs, ) -> torch.Tensor: + del iter_count, training # accepted for protocol uniformity num_items_in_batch = kwargs.get("num_items_in_batch", None) - - vocab_size = logits.shape[-1] - - logits = logits.float() # upcast to float to avoid precision issue, following transformers official implementation - shift_iter = iter_count.contiguous() if not (iter_count == -1).all() else None - - shift_logits = logits.view(-1, vocab_size).float() - shift_labels = labels_shifted.view(-1) - shift_iter = shift_iter.view(-1) if shift_iter is not None else None - - shift_labels = shift_labels.to(shift_logits.device) - ignore_index = -100 - has_custom_weights = ( - self.weight_hard is not None and self.weight_easy is not None + # upcast for numerical parity with HF's cross-entropy path + flat_logits = logits.view(-1, logits.shape[-1]).float() + flat_labels = labels_shifted.view(-1).to(flat_logits.device) + return fixed_cross_entropy( + flat_logits, flat_labels, + num_items_in_batch=num_items_in_batch, ignore_index=-100, ) - if self.hard_token_relative_weight == 1.0 or not training: - return fixed_cross_entropy( - shift_logits, - shift_labels, - num_items_in_batch=num_items_in_batch, - ignore_index=ignore_index, - ) - else: - weight_hard = ( - self.weight_hard - if has_custom_weights - else self.hard_token_relative_weight - ) - weight_easy = self.weight_easy if has_custom_weights else 1.0 - - token_weights = torch.full_like( - shift_labels, weight_easy, dtype=shift_logits.dtype - ) - if shift_iter is not None: - token_weights[shift_iter > 1] = weight_hard - - return weighted_cross_entropy( - shift_logits, - shift_labels, - token_weights, - num_items_in_batch=num_items_in_batch, - ignore_index=ignore_index, - ) - - -@register_loss_func -@capture_init_args -class ConsistencyLoss(LossFunc): - _is_intra_iter_loss: bool = True - - def __init__(self, **kwargs): - from tah.model.recurrent_transformer import TaHForCausalLM # import like this to avoid circular import - self.assign_active = TaHForCausalLM.assign_active - super().__init__(**kwargs) - - def prepare_loss(self, batch_size, query_len, device, dtype, **kwargs): - self.consistency_loss_per_token = torch.zeros( - batch_size, query_len, device=device, dtype=torch.float32 - ) # noqa: always use float32 for loss - - def intra_iter_loss_func( - self, - active_logits: torch.Tensor, - current_iter_mask: torch.BoolTensor, - active_labels_shifted: torch.Tensor, - **kwargs, - ): - if self.consistency_loss_per_token is None: - raise RuntimeError( - "Consistency loss tensor not initialized. Call `init_consistency_loss` first." - ) - - batch_size, query_len = current_iter_mask.shape - device = active_logits.device - active_logits = active_logits.float() # upcast to float to avoid precision issue, following transformers official implementation - - token_losses = torch.zeros( - batch_size, query_len, device=device, dtype=active_logits.dtype - ) - - if not current_iter_mask.any() or active_labels_shifted is None: - return torch.tensor(0.0, device=active_logits.device, dtype=active_logits.dtype) - - flat_active_logits = active_logits.view(-1, active_logits.size(-1)) - flat_active_labels = active_labels_shifted.view(-1) - - flat_losses = F.cross_entropy( - flat_active_logits, - flat_active_labels, - reduction="none", - ignore_index=-100, - ) - - active_losses_reshaped = flat_losses.view(batch_size, -1) - self.assign_active(current_iter_mask, active_losses_reshaped, token_losses) - self._update_consistency_loss(token_losses) - - return token_losses - - def _update_consistency_loss(self, token_losses): - self.consistency_loss_per_token = token_losses + self.consistency_loss_per_token - - def final_loss_func( - self, - labels_shifted: torch.Tensor, - iter_count: torch.Tensor, - training: bool, - **kwargs, - ) -> torch.Tensor: - if self.consistency_loss_per_token is None: - raise RuntimeError( - "Consistency loss tensor not initialized or already consumed." - ) - - num_items_in_batch = kwargs.get("num_items_in_batch", None) - - valid_mask = (labels_shifted != -100) & (iter_count > 0) - - consistency_loss = self.consistency_loss_per_token - self.consistency_loss_per_token = None # Consume the loss - - if not valid_mask.any(): - return torch.tensor( - 0.0, device=labels_shifted.device, dtype=consistency_loss.dtype - ) - avg_losses = torch.zeros_like(consistency_loss) - avg_losses[valid_mask] = ( - consistency_loss[valid_mask] / iter_count[valid_mask].float() - ).to(dtype=consistency_loss.dtype) +# ──────────────────────────────────────────────────────────────────────────── +# IterDeciderLoss — per-iteration BCE on continue logits, used by step-2 SFT. +# ──────────────────────────────────────────────────────────────────────────── - if num_items_in_batch is not None: - return avg_losses[valid_mask].sum() / num_items_in_batch - else: - return avg_losses[valid_mask].mean() -@register_loss_func -@capture_init_args class IterDeciderLoss(LossFunc): + """BCE on the iter decider's continue logits, supervised by iter_count_labels. + + At iteration depth ``d``, the per-token target is ``(iter_count_labels > d)``. + ``skip_last_iter`` skips the loss at ``d == max_iter`` (always-stop step). + Optional ``pos_weight`` rescales the positive class to handle imbalance. """ - Loss function for iter decider that predicts whether each token should continue iterating. - Uses BCE loss similar to the router training implementation. - Calculates loss at each iteration depth. - """ + _is_intra_iter_loss: bool = True def __init__(self, pos_weight: Optional[float] = None, skip_last_iter: bool = True, max_iter: Optional[int] = None, **kwargs): - """ - Initialize IterDeciderLoss. - - Args: - pos_weight: Positive class weight for BCE loss to handle class imbalance - skip_last_iter: If True, skip loss at the max iteration because it's always stop - """ - from tah.model.recurrent_transformer import TaHForCausalLM # import like this to avoid circular import - self.assign_active = TaHForCausalLM.assign_active - super().__init__(**kwargs) + super().__init__() self.pos_weight = pos_weight self.skip_last_iter = bool(skip_last_iter) - # Optional explicit max_iter (preferred over reading from model at call time) - self.max_iter: Optional[int] = int(max_iter) if max_iter is not None else None - + self.max_iter = int(max_iter) if max_iter is not None else None if self.skip_last_iter and self.max_iter is None: - raise ValueError("max_iter must be provided if skip_last_iter is True") - - # Create BCE loss criterion - if pos_weight is not None: - self.criterion = torch.nn.BCEWithLogitsLoss( - pos_weight=torch.tensor([pos_weight]) - ) - else: - self.criterion = torch.nn.BCEWithLogitsLoss() - + raise ValueError("max_iter must be set when skip_last_iter is True") + self.criterion = nn.BCEWithLogitsLoss( + pos_weight=torch.tensor([pos_weight]) if pos_weight is not None else None + ) + # State allocated by prepare_loss. + self.iter_decider_loss_per_token: Optional[torch.Tensor] = None + self._metric_correct_count: Optional[torch.Tensor] = None + self._metric_total_count: Optional[torch.Tensor] = None + def prepare_loss(self, batch_size, query_len, device, dtype, **kwargs): - self.iter_decider_loss_per_token = torch.zeros( - batch_size, query_len, device=device, dtype=torch.float32 - ) # always use float32 for loss - # Metric accumulators (float32 scalars on device) + # Always float32 for losses, regardless of model dtype. + self.iter_decider_loss_per_token = torch.zeros(batch_size, query_len, device=device, dtype=torch.float32) self._metric_correct_count = torch.zeros(1, device=device, dtype=torch.float32) self._metric_total_count = torch.zeros(1, device=device, dtype=torch.float32) def intra_iter_loss_func( self, - active_logits: torch.Tensor, + active_logits, current_iter_mask: torch.BoolTensor, - active_labels_shifted: torch.Tensor, + active_labels_shifted, active_valid_continue_logits: Optional[torch.Tensor], active_valid_mask: torch.LongTensor, iter_depth: int, active_iter_count_labels: Optional[torch.LongTensor] = None, **kwargs, ): - """ - Calculate iter decider loss at each iteration depth. - - Args: - active_logits: Model logits (not used) - current_iter_mask: Mask for current iteration - active_labels_shifted: Shifted labels (not used) - active_valid_continue_prob: Continue probabilities from iter_decider - active_valid_continue_logits: Continue logits from iter_decider - active_valid_mask: Valid mask for active tokens - iter_depth: Current iteration depth - active_iter_count_labels: Target iteration counts - **kwargs: Additional arguments - """ - if active_iter_count_labels is None or active_valid_continue_logits is None: - return torch.tensor(0.0, dtype=torch.float32) - - if not current_iter_mask.any() or active_valid_mask.sum() == 0: + del active_logits, active_labels_shifted # not used by this loss + if ( + active_iter_count_labels is None + or active_valid_continue_logits is None + or not current_iter_mask.any() + or active_valid_mask.sum() == 0 + or (self.skip_last_iter and int(iter_depth) >= int(self.max_iter)) + ): return torch.tensor(0.0, dtype=torch.float32) - # Optionally skip loss for the final iteration depth (always-stop step) - if self.skip_last_iter: - if int(iter_depth) >= int(self.max_iter): - return torch.tensor(0.0, dtype=torch.float32) - - # Update metrics using probabilities derived from logits and threshold - if active_iter_count_labels is not None and active_valid_continue_logits is not None: - valid_active_mask = (active_valid_mask == 1) - valid_iter_count_labels = active_iter_count_labels[valid_active_mask] - non_padding_mask = (valid_iter_count_labels != -100) - if non_padding_mask.any(): - final_continue_targets = (valid_iter_count_labels[non_padding_mask] > iter_depth).to(torch.float32) - final_continue_probs = torch.sigmoid(active_valid_continue_logits[non_padding_mask]).to(torch.float32) - # Resolve threshold value - iter_decider_threshold = kwargs.get('iter_decider_threshold', 0.5) - if isinstance(iter_decider_threshold, torch.Tensor): - threshold_value = float(iter_decider_threshold.detach().item()) - else: - threshold_value = float(iter_decider_threshold) - - with torch.no_grad(): - pred_positive = (final_continue_probs > threshold_value).to(torch.float32) - target_positive = final_continue_targets - correct = (pred_positive == target_positive).to(torch.float32).sum() - total = torch.tensor(float(pred_positive.numel()), device=final_continue_probs.device, dtype=torch.float32) - - if hasattr(self, '_metric_correct_count') and self._metric_correct_count is not None: - self._metric_correct_count += correct - self._metric_total_count += total - - active_valid_continue_logits = active_valid_continue_logits.float() - device = active_valid_continue_logits.device - dtype = active_valid_continue_logits.dtype - batch_size, query_len = current_iter_mask.shape - - if self.iter_decider_loss_per_token is None: - raise RuntimeError( - "Iter decider loss tensor not initialized. Call `prepare_loss` first." - ) - # Initialize token losses for this iteration - token_losses = torch.zeros( - batch_size, query_len, device=device, dtype=dtype - ) + # Restrict to valid active tokens that have a non-padding iter label. + valid_active = active_valid_mask == 1 + labels_at_active = active_iter_count_labels[valid_active] + non_padding = labels_at_active != -100 + if not non_padding.any(): + return torch.tensor(0.0, device=device, dtype=torch.float32) + targets = (labels_at_active[non_padding] > iter_depth).float() + used_logits = active_valid_continue_logits[non_padding].float() - # Calculate target labels: should continue if iter_count_labels > iter_depth - # Only consider valid active tokens - valid_active_mask = (active_valid_mask == 1) - # For valid active tokens, calculate binary targets - valid_iter_count_labels = active_iter_count_labels[valid_active_mask] - valid_continue_targets = (valid_iter_count_labels > iter_depth).float() - - # Exclude padding tokens (-100) - non_padding_mask = (valid_iter_count_labels != -100) - if not non_padding_mask.any(): - return torch.tensor(0.0, device=device, dtype=dtype) - - final_continue_targets = valid_continue_targets[non_padding_mask] - final_continue_logits = active_valid_continue_logits[non_padding_mask] - - # Move pos_weight to correct device if needed - if hasattr(self.criterion, 'pos_weight') and self.criterion.pos_weight is not None: - self.criterion.pos_weight = self.criterion.pos_weight.to(device=device) + self._record_threshold_accuracy(used_logits, targets, kwargs.get("iter_decider_threshold", 0.5)) - # Calculate BCE loss - loss = self.criterion(final_continue_logits.unsqueeze(-1), final_continue_targets.unsqueeze(-1)) - - # Assign loss back to full tensor structure - # This is simplified - we assign the same loss to all valid active tokens - if valid_active_mask.any() and non_padding_mask.any(): - # Create a tensor to hold loss for active tokens - active_token_losses = torch.zeros(batch_size, active_valid_mask.shape[1], device=device, dtype=loss.dtype) - # We'll assign the average loss to all contributing tokens - num_contributing_tokens = non_padding_mask.sum() - if num_contributing_tokens > 0: - per_token_loss = loss / num_contributing_tokens - # Create a full-size tensor for valid active positions - valid_positions = torch.zeros_like(active_token_losses, dtype=torch.bool) - valid_positions[valid_active_mask] = non_padding_mask - active_token_losses[valid_positions] = per_token_loss - self.assign_active(current_iter_mask, active_token_losses, token_losses) - - # Update cumulative loss - self._update_iter_decider_loss(token_losses) - + if self.criterion.pos_weight is not None: + self.criterion.pos_weight = self.criterion.pos_weight.to(device=device) + loss = self.criterion(used_logits.unsqueeze(-1), targets.unsqueeze(-1)) + + # Spread the per-iter loss evenly across contributing positions so that + # final_loss_func's sum-then-divide gives the correct per-token average, + # and scatter back into the dense (B, T) accumulator. + per_token = loss / float(used_logits.numel()) + token_losses = self._scatter_per_token( + per_token, valid_active, non_padding, current_iter_mask, active_valid_mask, device, + ) + self.iter_decider_loss_per_token = self.iter_decider_loss_per_token + token_losses return token_losses - def _update_iter_decider_loss(self, token_losses): - self.iter_decider_loss_per_token = token_losses + self.iter_decider_loss_per_token + def _record_threshold_accuracy( + self, logits: torch.Tensor, targets: torch.Tensor, threshold, + ) -> None: + """Update running counts of correct iter-decider predictions.""" + if isinstance(threshold, torch.Tensor): + threshold = float(threshold.detach().item()) + with torch.no_grad(): + preds = (torch.sigmoid(logits) > float(threshold)).float() + self._metric_correct_count += (preds == targets).float().sum() + self._metric_total_count += float(logits.numel()) + + @staticmethod + def _scatter_per_token( + per_token: torch.Tensor, + valid_active: torch.Tensor, + non_padding: torch.Tensor, + current_iter_mask: torch.BoolTensor, + active_valid_mask: torch.LongTensor, + device: torch.device, + ) -> torch.Tensor: + """Place ``per_token`` (a scalar) at every (valid + non_padding) active + position in a fresh dense ``(B, T)`` tensor.""" + from tah.model.tah_model import scatter_back # local: avoid import cycle + + active_losses = torch.zeros(active_valid_mask.shape, device=device, dtype=torch.float32) + # valid_positions[b, k] = True iff the k-th active token in row b is + # both valid and non-padding (the same subset that contributed to the + # BCE numerator above). + valid_positions = torch.zeros_like(active_losses, dtype=torch.bool) + valid_positions[valid_active] = non_padding + active_losses[valid_positions] = per_token + + token_losses = torch.zeros( + current_iter_mask.shape, device=device, dtype=torch.float32, + ) + scatter_back(current_iter_mask, src=active_losses, dest=token_losses, in_place=True) + return token_losses def final_loss_func( self, @@ -370,69 +212,52 @@ def final_loss_func( training: bool = True, **kwargs, ) -> torch.Tensor: - """ - Calculate final iter decider loss from accumulated losses. - - Args: - logits: Model logits (not used) - labels_shifted: Shifted labels (not used) - iter_count: Actual iteration counts from model - iter_count_labels: Target iteration count labels (optional) - training: Whether in training mode - **kwargs: Additional arguments - - Returns: - Accumulated iter decider loss - """ + del labels_shifted, training if self.iter_decider_loss_per_token is None: - raise RuntimeError( - "Iter decider loss tensor not initialized or already consumed." - ) + raise RuntimeError("prepare_loss has not been called") - num_items_in_batch = kwargs.get("num_items_in_batch", None) - - # Use iter_count_labels if available, otherwise fall back to simple validation - if iter_count_labels is not None: - valid_mask = (iter_count_labels != -100) & (iter_count > 0) - else: - valid_mask = (iter_count > 0) - - iter_decider_loss = self.iter_decider_loss_per_token - self.iter_decider_loss_per_token = None # Consume the loss - - if not valid_mask.any(): - return torch.tensor( - 0.0, device=logits.device, dtype=iter_decider_loss.dtype - ) - - # Compute and log metrics if requested - logger_callback = kwargs.get('logger_callback', None) + accumulated = self.iter_decider_loss_per_token + self.iter_decider_loss_per_token = None # consume + + valid = (iter_count_labels != -100) & (iter_count > 0) if iter_count_labels is not None else (iter_count > 0) + if not valid.any(): + return torch.tensor(0.0, device=logits.device, dtype=accumulated.dtype) + + # Optional accuracy logging via callback (kept for parity with the trainer). + callback = kwargs.get("logger_callback") with torch.no_grad(): - if hasattr(self, '_metric_total_count') and self._metric_total_count is not None and ( - (self._metric_total_count.item() > 0) or (kwargs.get('num_items_in_batch', None) is not None) - ): - # Accuracy logging: follow avg_iter_count pattern → correct_count / num_items_in_batch - if logger_callback is not None: - if not hasattr(logger_callback, 'iter_decider_accuracy'): - logger_callback.iter_decider_accuracy = 0.0 - - num_items_in_batch = kwargs.get('num_items_in_batch', None) - if num_items_in_batch is not None and num_items_in_batch > 0: - acc_step = (self._metric_correct_count / num_items_in_batch) - else: - # Fallback to total-based accuracy if num_items_in_batch is absent - total_safe = torch.clamp(self._metric_total_count, min=1.0) - acc_step = (self._metric_correct_count / total_safe) - logger_callback.iter_decider_accuracy += float(acc_step) - - # Reset metric accumulators after consumption + if callback is not None and self._metric_total_count is not None: + if not hasattr(callback, "iter_decider_accuracy"): + callback.iter_decider_accuracy = 0.0 + num_items = kwargs.get("num_items_in_batch") + if num_items: + callback.iter_decider_accuracy += float(self._metric_correct_count / num_items) + else: + callback.iter_decider_accuracy += float( + self._metric_correct_count / torch.clamp(self._metric_total_count, min=1.0) + ) self._metric_correct_count = None self._metric_total_count = None - # Calculate average loss over valid tokens - if num_items_in_batch is not None: - return iter_decider_loss[valid_mask].sum() / num_items_in_batch - else: - return iter_decider_loss[valid_mask].mean() - + num_items = kwargs.get("num_items_in_batch") + if num_items is not None: + return accumulated[valid].sum() / num_items + return accumulated[valid].mean() + + +# ──────────────────────────────────────────────────────────────────────────── +# Name-based dispatch (replaces the old registry system). +# ──────────────────────────────────────────────────────────────────────────── + + +LOSS_BY_NAME = { + "NextTokenPredLoss": NextTokenPredLoss, + "IterDeciderLoss": IterDeciderLoss, +} + +def get_loss_func_class(name: str): + """Lookup helper for backwards-compatible callers.""" + if name not in LOSS_BY_NAME: + raise ValueError(f"Unknown loss {name!r}; have {sorted(LOSS_BY_NAME)}") + return LOSS_BY_NAME[name] diff --git a/tah/model/output_updater.py b/tah/model/output_updater.py deleted file mode 100644 index 35ec8fd..0000000 --- a/tah/model/output_updater.py +++ /dev/null @@ -1,114 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, TYPE_CHECKING - -import torch -import torch.nn as nn - -if TYPE_CHECKING: - from tah.model.recurrent_transformer import TaHForCausalLM - -from tah.model.registry import register_output_updater, get_output_updater_class, capture_init_args - - -class OutputUpdater(nn.Module, ABC): - """ - Base class for updating output logits between iterations. - - This class is designed to efficiently handle tensors of arbitrary shape (..., vocab_size), - where the leading dimensions can be any combination of batch, sequence, or other - dimensions. All operations preserve the leading dimensions and only operate on - the last dimension for vocabulary operations. - """ - - @abstractmethod - def forward( - self, - logits: torch.Tensor, - prev_logits: Optional[torch.Tensor] = None, - iter_depth: int = 0, - **kwargs - ) -> torch.Tensor: - """ - Return updated logits for accumulation. - - This method efficiently handles tensors of arbitrary shape, preserving all - leading dimensions while operating only on the vocabulary dimension. - - Args: - logits: The current iteration logits, shape (..., vocab_size) - prev_logits: The previous accumulated logits, shape (..., vocab_size) or None for first iteration - iter_depth: Current iteration depth (0-indexed) - **kwargs: Additional arguments - - Returns: - The updated accumulated logits, shape (..., vocab_size) - - Note: - All leading dimensions (...) are preserved exactly. The implementation - is optimized for efficient processing regardless of the number or size - of leading dimensions (e.g., batch size, sequence length, etc.). - """ - - -@register_output_updater -@capture_init_args -class NoneUpdater(OutputUpdater): - """ - No-op output updater that returns current logits without accumulation. - This is the default behavior to maintain backward compatibility. - """ - - def __init__(self): - super().__init__() - - def forward( - self, - logits: torch.Tensor, - prev_logits: Optional[torch.Tensor] = None, - iter_depth: int = 0, - **kwargs - ) -> torch.Tensor: - """Simply return current logits without any accumulation.""" - return logits - - -@register_output_updater -@capture_init_args -class AdditiveLogitsUpdater(OutputUpdater): - """ - Additive output updater that accumulates logits across iterations. - - On the first iteration (prev_logits is None), returns current logits. - On subsequent iterations, returns prev_logits + current logits. - This allows the model to learn residual corrections to the output. - """ - - def __init__(self): - super().__init__() - - def forward( - self, - logits: torch.Tensor, - prev_logits: Optional[torch.Tensor] = None, - iter_depth: int = 0, - **kwargs - ) -> torch.Tensor: - """ - Accumulate logits additively. - - Args: - logits: Current iteration logits (..., vocab_size) - prev_logits: Previous accumulated logits (..., vocab_size) or None - iter_depth: Current iteration depth (0-indexed) - - Returns: - Accumulated logits (..., vocab_size) - """ - if prev_logits is None: - # First iteration: return current logits as-is - return logits - else: - # Subsequent iterations: add to accumulated logits - return prev_logits + logits - - diff --git a/tah/model/recurrent_transformer.py b/tah/model/recurrent_transformer.py deleted file mode 100755 index a3e87ab..0000000 --- a/tah/model/recurrent_transformer.py +++ /dev/null @@ -1,1323 +0,0 @@ -""" -TaH (Hierarchical Recurrent Reasoning) Transformer Wrapper. - -This module wraps around standard transformer PreTrainedModel (e.g., Qwen3ForCausalLM) -to enable hierarchical recurrent processing with iteration-aware caching. -""" - -import torch -import torch.nn.functional as F -import json -import os -from dataclasses import asdict, dataclass, fields -from typing import Optional, Union, Tuple, Dict, Any, List, Union -from transformers import PreTrainedModel, AutoModelForCausalLM -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.utils import logging -from accelerate import dispatch_model - -from tah.model.causal_cache import TaHCache -from tah.model.utils import ( - get_attr_recursive, - dict_string_to_type, - type_to_dict_string, - get_device_map, -) -from tah.model.iter_decider import ( - save_iter_decider, - load_iter_decider, - get_iter_decider_class, -) -from tah.model.input_updater import get_input_updater_class -from tah.model.output_updater import get_output_updater_class -from tah.model.loss import get_loss_func_class -from tah.model.iter_label import get_iter_label_generator_class -from tah.model.tah_config import TaHConfig -from tah.model.adapter import ( - setup_adapter, - save_adapter, - load_adapter, - configure_lora_for_iteration -) - -logger = logging.get_logger(__name__) -# # Ensure INFO level logging is enabled for this module -# logging.set_verbosity_info() - - -@dataclass -class TaHCausalLMOutputWithPast(CausalLMOutputWithPast): - """ - Inherit from CausalLMOutputWithPast, add iter_count. - - Args: - iter_count (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Number of iterations performed for each token in the sequence. - """ - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - iter_count: Optional[torch.LongTensor] = None - iter_count_labels: Optional[torch.LongTensor] = None - - -class TaHForCausalLM(PreTrainedModel): - """ - TaH wrapper for Causal Language Models that enables hierarchical recurrent processing. - - This wrapper takes a standard transformer model (e.g., Qwen3ForCausalLM) and adds - support for iterative processing where: - 1. Each token can be processed multiple times (based on iter_count) - 2. The output of iteration i becomes the input of iteration i+1 - 3. Deeper iterations can see cache from all previous iterations - 4. Previous iterations cannot see cache from future iterations - """ - - def __init__( - self, base_model: PreTrainedModel, config: Optional[TaHConfig] = None, **kwargs - ): - """ - Initialize TaH wrapper. - - Args: - base_model: The base transformer model to wrap (e.g., Qwen3ForCausalLM) - max_iter: Maximum number of iterations in automatic mode - iter_decider: Plug-in object that decides whether a token continues - input_updater: Module that updates inputs between iterations - """ - self.base_model._supports_sdpa = True - super().__init__(base_model.config) - self.config = base_model.config - self.supports_gradient_checkpointing = True - - if config is None: - config = TaHConfig() - - self.tah_config = config - - # Check the embedding key - try: - get_attr_recursive(base_model, config.embedding_key) - except AttributeError: - raise ValueError( - f"Embedding_key {config.embedding_key} not found in base model" - ) - self.embedding_key = config.embedding_key - - self.max_iter = config.max_iter - - # Build iter_decider from config - # Create decider from class and kwargs - decider_cls = get_iter_decider_class(config.iter_decider) - config.iter_decider_kwargs["max_iter"] = self.max_iter - self.iter_decider = decider_cls(**config.iter_decider_kwargs) - - # Optional: build a separate iter_decider for evaluation/inference - # Falls back to training decider when not provided - eval_iter_decider = getattr(config, "eval_iter_decider", None) - if eval_iter_decider is not None: - resolved = None - if isinstance(eval_iter_decider, str): - # Support hierarchical path referencing the built training iter_decider - # Example: "iter_decider.primary_iter_decider.final_iter_decider" - if eval_iter_decider.startswith("iter_decider"): - path = eval_iter_decider.split(".") - obj = self - for seg in path: - if not seg: - continue - if seg == "self": - obj = self - else: - obj = getattr(obj, seg) - resolved = obj - # Class-name path - else: - eval_decider_cls = get_iter_decider_class(eval_iter_decider) - resolved = eval_decider_cls(**getattr(config, 'eval_iter_decider_kwargs', {})) - - self.eval_iter_decider = resolved if resolved is not None else self.iter_decider - else: - self.eval_iter_decider = self.iter_decider - - # Build input_updater from config - # Create updater from class and kwargs - updater_cls = get_input_updater_class(config.input_updater) - self.input_updater = updater_cls(**config.input_updater_kwargs) - - # Build output_updater from config - # Create output updater from class and kwargs, - output_updater_cls = get_output_updater_class(config.output_updater or 'NoneUpdater') - self.output_updater = output_updater_cls(**config.output_updater_kwargs) - - # Build loss func from config - # Robustly pass max_iter when supported by the loss (wrapper-safe): try with max_iter, fallback without - train_loss_func_cls = get_loss_func_class(config.train_loss) - _train_kwargs = dict(getattr(config, 'train_loss_kwargs', {}) or {}) - try: - self.train_loss = train_loss_func_cls(**{**_train_kwargs, "max_iter": self.max_iter}) - except TypeError as e: - if "max_iter" in str(e): - self.train_loss = train_loss_func_cls(**_train_kwargs) - else: - raise - - # Build eval loss func from config - if config.eval_loss: - eval_loss_func_cls = get_loss_func_class(config.eval_loss) - _eval_kwargs = dict(getattr(config, 'eval_loss_kwargs', {}) or {}) - try: - self.eval_loss = eval_loss_func_cls(**{**_eval_kwargs, "max_iter": self.max_iter}) - except TypeError as e: - if "max_iter" in str(e): - self.eval_loss = eval_loss_func_cls(**_eval_kwargs) - else: - raise - else: - self.eval_loss = self.train_loss - - # Build iter label generator from config (constructed here, prepared per forward) - iter_label_generator_name = getattr(config, "iter_label_generator", None) or "FixedIterLabelGenerator" - iter_label_generator_kwargs = getattr(config, "iter_label_generator_kwargs", None) or {} - IterLabelGenCls = get_iter_label_generator_class(iter_label_generator_name) - self.iter_label_generator = IterLabelGenCls(**iter_label_generator_kwargs) - - # Init base model - self.simple_base_model = base_model - - # iter attention mode - self.iter_attention_mode = config.iter_attention_mode - - - # Setup adapter if enabled - self._setup_adapter(config) - - # Tokens that require multiple iterations (iter_count > 1) are considered "hard". - # Their loss will be multiplied by this factor during training. 1.0 means no reweighting. - self.hard_token_relative_weight = 1.0 - self.avg_hard_token_ratio = None - self.weight_hard = None - self.weight_easy = None - - # TODO: Ensure input_updater is on the same device and dtype as the base model - device_map = kwargs.pop("device_map", None) - if device_map is not None: - device_map = get_device_map(self, device_map, self.dtype) - dispatch_model_kwargs = { - "device_map": device_map, - "offload_dir": None, - "offload_index": None, - "offload_buffers": False, - # "skip_keys": ["past_key_values"], - "skip_keys": self.simple_base_model._skip_keys_device_placement - } - self = dispatch_model(self, **dispatch_model_kwargs) - - def _setup_adapter(self, config: TaHConfig): - """Setup adapter for the model (delegated).""" - setup_adapter(self, config) - - def _configure_lora_for_iteration(self, iter_depth: int): - """Configure LoRA adapters for the current iteration (delegated).""" - configure_lora_for_iteration(self, iter_depth) - - - @property - def device(self) -> torch.device: - """ - `torch.device`: The device on which the module is (assuming that all the module parameters are on the same - device). - """ - # Since TaH is a wrapper without its own parameters, delegate to base model - return self.simple_base_model.device - - def to(self, *args, **kwargs): - """ - Move the model to the specified device/dtype. Delegates to the base model. - """ - self.simple_base_model = self.simple_base_model.to(*args, **kwargs) - if hasattr(self, "input_updater") and self.input_updater is not None: - self.input_updater = self.input_updater.to(*args, **kwargs) - if hasattr(self, "output_updater") and self.output_updater is not None: - self.output_updater = self.output_updater.to(*args, **kwargs) - if hasattr(self, "iter_decider") and self.iter_decider is not None: - self.iter_decider = self.iter_decider.to(*args, **kwargs) - if hasattr(self, "eval_iter_decider") and self.eval_iter_decider is not None: - self.eval_iter_decider = self.eval_iter_decider.to(*args, **kwargs) - if hasattr(self, "iter_label_generator") and self.iter_label_generator is not None: - self.iter_label_generator = self.iter_label_generator.to(*args, **kwargs) # type: ignore[attr-defined] - return self - - def cuda(self, device=None): - """ - Move the model to CUDA. Delegates to the base model. - """ - self.simple_base_model = self.simple_base_model.cuda(device) - if hasattr(self, "input_updater") and self.input_updater is not None: - self.input_updater = self.input_updater.cuda(device) - if hasattr(self, "output_updater") and self.output_updater is not None: - self.output_updater = self.output_updater.cuda(device) - if hasattr(self, "iter_decider") and self.iter_decider is not None: - self.iter_decider = self.iter_decider.cuda(device) - if hasattr(self, "eval_iter_decider") and self.eval_iter_decider is not None: - self.eval_iter_decider = self.eval_iter_decider.cuda(device) - if hasattr(self, "iter_label_generator") and self.iter_label_generator is not None: - self.iter_label_generator = self.iter_label_generator.cuda(device) - return self - - def cpu(self): - """ - Move the model to CPU. Delegates to the base model. - """ - self.simple_base_model = self.simple_base_model.cpu() - if hasattr(self, "input_updater") and self.input_updater is not None: - self.input_updater = self.input_updater.cpu() - if hasattr(self, "output_updater") and self.output_updater is not None: - self.output_updater = self.output_updater.cpu() - if hasattr(self, "iter_decider") and self.iter_decider is not None: - self.iter_decider = self.iter_decider.cpu() - if hasattr(self, "eval_iter_decider") and self.eval_iter_decider is not None: - self.eval_iter_decider = self.eval_iter_decider.cpu() - if hasattr(self, "iter_label_generator") and self.iter_label_generator is not None: - self.iter_label_generator = self.iter_label_generator.cpu() - return self - - @property - def embed_tokens(self): - """Return the embedding layer from the base model.""" - if "lora" in self.adapter: - return get_attr_recursive(self.simple_base_model.base_model.model, self.embedding_key) - return get_attr_recursive(self.simple_base_model, self.embedding_key) - - def forward( - self, - input_ids: torch.LongTensor, - iter_count: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[TaHCache] = None, - # input_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - iter_count_labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = False, # noqa - new_sequence: Optional[bool] = False, # used by oracle iter decider - # cache_position: Optional[torch.LongTensor] = None, - # logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs, - ) -> CausalLMOutputWithPast: - """ - Forward pass with hierarchical recurrent processing. - - Warning: iter_count will be deprecated in future versions. Iteration will be fully controlled by the iter decider. - - Args: - input_ids: Token IDs of shape (batch_size, seq_len) - iter_count: Number of iterations for each token of shape (batch_size, query_seq_len) - attention_mask: Optional attention mask, with shape (batch_size, total_seq_len) - position_ids: Optional position ids, with shape (batch_size, query_seq_len) - labels: Optional labels for loss computation - use_cache: Whether to use cache - output_attentions: Whether to output attention weights - output_hidden_states: Whether to output hidden states - **kwargs: Additional arguments - - Returns: - CausalLMOutputWithPast with results from final iteration - """ - - """ - Initializations - """ - # TODO: support other functions of Transformers - assert (output_attentions is None) or (output_attentions is False), "TaH does not support output_attentions" - assert (output_hidden_states is None) or (output_hidden_states is False), "TaH does not support output_hidden_states" - - # shift labels - if labels is not None: - labels_shifted = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() - labels_all_shifted = F.pad(input_ids.clone(), (0, 1), value=-100)[..., 1:].contiguous() # includes query tokens uniformly - else: - labels_shifted = None - labels_all_shifted = None - - max_iterations = self.max_iter - - # Initialize scalars and tensors - batch_size, query_len = input_ids.shape - vocab_size = self.config.vocab_size - hidden_size = self.config.hidden_size - use_cache = use_cache if use_cache is not None else self.config.use_cache - - input_embeds = self.embed_tokens( - input_ids - ) # (batch_size, query_len, hidden_size) - dtype = input_embeds.dtype - device = input_embeds.device - final_output_logits = torch.zeros( - batch_size, query_len, vocab_size, device=device, dtype=dtype - ) # (batch_size, query_len, vocab_size) - cumulative_logits = torch.zeros( - batch_size, query_len, vocab_size, device=device, dtype=dtype - ) # (batch_size, query_len, vocab_size) - for output updater accumulation - actual_iter_counts = torch.zeros_like( - input_ids, dtype=torch.long - ) # (batch_size, query_len) - - # Initialize cache - if past_key_values is not None: - cache = past_key_values - else: - cache = TaHCache().to(device=device, dtype=dtype) # noqa - - # Initialize token mask - if attention_mask is not None: - valid_mask = attention_mask.clone()[:, -query_len:].to(dtype=torch.long) - assert valid_mask.shape == ( - batch_size, - query_len, - ), f"attention_mask shape must be (batch_size, seq_len), but got {attention_mask.shape}" - else: - valid_mask = torch.ones_like(input_ids, dtype=torch.long) - - # Initialize position_ids - if position_ids is None: - # position_offset = cache.get_seq_length() # Layer 0 Iter 0 cache length - # TODO: design more efficient ways to get seq length of each batch - position_ids = torch.clamp( - torch.cumsum( - torch.cat( - ( - cache.get_valid_mask_upto_iter( - layer_idx=0, upto_iter_idx=0, init_batch_size=batch_size - ).to(device), - valid_mask, - ), - dim=-1, - ), - dim=-1, - )[:, -query_len:] - - 1, - min=0, - ) - else: - position_ids = position_ids.clone() - - # Initialize loss func for the forward pass - loss_func = self.train_loss if self.training else self.eval_loss - loss_func.prepare_loss( - batch_size, query_len, device, dtype, - weight_hard=self.weight_hard, - weight_easy=self.weight_easy, - hard_token_relative_weight=self.hard_token_relative_weight - ) - - # Decide whether to use iter label generator this forward - use_iter_labeling = (self.iter_label_generator is not None) and (labels_shifted is not None) - # Prepare iter label generator buffers (train or eval) only when labeling is used - if use_iter_labeling: - self.iter_label_generator.prepare(batch_size, query_len, device, dtype) - - - """ - Iterative processing - """ - current_iter_mask = torch.ones_like( - input_ids, dtype=torch.bool - ) # (batch_size, query_len), 1 if the element is selected for current iter - finished_mask = torch.zeros_like( - current_iter_mask, dtype=torch.bool - ) # default to unfinished - iter_depth = 0 - - while iter_depth < max_iterations and current_iter_mask.any(): - - # Configure LoRA for current iteration if not already done - self._configure_lora_for_iteration(iter_depth) - - # Extract sparse inputs for active tokens - # TODO: modify the extraction logic to handle the token mask - active_input_embeds, active_cumulative_logits, active_position_ids, active_valid_mask, active_iter_count, active_labels_shifted, active_iter_count_labels, active_labels_all_shifted = ( - self.to_active( - current_iter_mask, input_embeds, cumulative_logits, position_ids, valid_mask, iter_count, labels_shifted, iter_count_labels, labels_all_shifted - ) - ) - - - # Break if no active tokens - if active_valid_mask.shape[1] == 0: - break - - # Create SDPA attention mask - sdpa_attention_mask = self.create_TaH_sdpa_attention_mask( - active_position_ids, active_valid_mask, cache, iter_depth, dtype=dtype - ) - - active_outputs = self._process_sparse_iteration( - sparse_input=active_input_embeds, - position_ids=active_position_ids, - valid_mask=active_valid_mask, - cache_position=None, # cache position not used for now - attention_mask=sdpa_attention_mask, - iter_depth=iter_depth, - past_key_values=cache, - use_cache=True if iter_depth < max_iterations - 1 else use_cache, - output_attentions=output_attentions, - output_hidden_states=True, # noqa: output_hidden_states must be True to get last hidden for iter decider - model=self.simple_base_model, # Pass the selected model to _process_sparse_iteration - **kwargs, - ) - - # Update iter depth - iter_depth += 1 - - # noqa: Update output device, when device map is auto, output device may be different from the input one. Move back to input device - active_outputs.logits = active_outputs.logits.to(device=device) - - # Update cumulative logits using output updater - # For first iteration, active_prev_logits will be zeros, which output_updater handles correctly - active_updated_cumulative_logits = self.output_updater( - logits=active_outputs.logits, - prev_logits=active_cumulative_logits, - iter_depth=iter_depth - 1, # iter_depth was incremented above, so use iter_depth - 1 for 0-indexed - ) - - # Write updated logits back to cumulative_logits - # Use assign_active_no_inplace to avoid in-place modification that breaks autograd - cumulative_logits = self.assign_active_no_inplace( - current_iter_mask, - src=active_updated_cumulative_logits, - dest=cumulative_logits - ) - - all_hidden = None - if hasattr(active_outputs, "hidden_states") and active_outputs.hidden_states is not None: - hidden_states = active_outputs.hidden_states - # hidden_states can be a tuple(list) of length = num_layers, each with shape (B, T, H) - # Convert to a tensor shaped (B, T, L, H) so that boolean mask of shape (B, T) applies cleanly - if isinstance(hidden_states, (tuple, list)): - layer_hidden_list = [h.to(device=device) for h in hidden_states] - if len(layer_hidden_list) > 0: - all_hidden = torch.stack(layer_hidden_list, dim=0).permute(1, 2, 0, 3) - elif torch.is_tensor(hidden_states): - # If already a tensor (B, T, H), add layer dim to unify to (B, T, L=1, H) - all_hidden = hidden_states.to(device=device) - if all_hidden.dim() == 3: - all_hidden = all_hidden.unsqueeze(-2) - - # TODO: improve efficiency - # We always call iter_decider, even if the count is given, to get the continue prob for all tokens - # Choose decider based on train/eval mode - cur_iter_decider = self.iter_decider if self.training else (self.eval_iter_decider or self.iter_decider) - - # Optionally compute per-iteration labels from generator and unify naming - if use_iter_labeling: - active_iter_count_labels = self.iter_label_generator.intra_iter_labels( - active_iter_count_labels=active_iter_count_labels, - active_logits=active_updated_cumulative_logits, - active_labels_shifted=active_labels_all_shifted, - iter_depth=iter_depth, - current_iter_mask=current_iter_mask, - active_valid_mask=active_valid_mask, - ) - - active_valid_continue_decision, active_valid_continue_logits = cur_iter_decider( - logits=active_updated_cumulative_logits[active_valid_mask == 1], - iter_depth=iter_depth, - all_hidden_states=all_hidden[active_valid_mask == 1] if all_hidden is not None else None, - labels_shifted=active_labels_all_shifted[active_valid_mask == 1] if active_labels_all_shifted is not None else None, # used by mismatch iter decider - iter_count_labels=(active_iter_count_labels[active_valid_mask == 1] if active_iter_count_labels is not None else None), - ) - - - # Ensure at least one labeled position continues when all labeled decisions are False - if ( - (active_labels_shifted is not None) - and (active_valid_continue_decision is not None) - and (active_valid_continue_decision.numel() > 0) - and iter_depth < self.max_iter - ): - label_mask_flat = (active_labels_shifted != -100)[active_valid_mask == 1] - if label_mask_flat.any() and (not active_valid_continue_decision[label_mask_flat].any()): - candidate_indices = torch.nonzero(label_mask_flat, as_tuple=False).flatten() - chosen_idx = candidate_indices[torch.randint(low=0, high=candidate_indices.numel(), size=(1,), device=device)] - active_valid_continue_decision[chosen_idx] = True - - # Move tensors to correct device - if active_valid_continue_logits is not None: - active_valid_continue_logits = active_valid_continue_logits.to(device=device) - - # decide whether to finish current iteration - active_finished_mask = torch.ones_like(active_valid_mask, dtype=torch.bool) - # When explicit boolean decision is provided: finish where decision == False - active_finished_mask[active_valid_mask == 1] = (~active_valid_continue_decision) - self.assign_active( - current_iter_mask, src=active_finished_mask, dest=finished_mask - ) - - actual_iter_counts[current_iter_mask] += 1 - - # Calculate loss for all active tokens in current iteration - if labels_shifted is not None and loss_func._is_intra_iter_loss: - # Prepare kwargs for intra_iter_loss_func, including iter_depth and active_iter_count_labels - intra_loss_kwargs = kwargs.copy() - intra_loss_kwargs['iter_depth'] = iter_depth - # pass iter_decider threshold for metric computation - intra_loss_kwargs['iter_decider_threshold'] = cur_iter_decider.threshold - # provide model handle for potential freeze control in loss - intra_loss_kwargs['model'] = self - # forward global_step if provided by caller - if 'global_step' in kwargs: - intra_loss_kwargs['global_step'] = kwargs['global_step'] - # Use unified active_iter_count_labels for BCE targets if present - if active_iter_count_labels is not None: - intra_loss_kwargs['active_iter_count_labels'] = active_iter_count_labels - if all_hidden is not None: - intra_loss_kwargs['all_hidden_states'] = all_hidden - - # Compute loss for all currently active tokens - loss_func.intra_iter_loss_func( - active_logits=active_updated_cumulative_logits, - current_iter_mask=current_iter_mask, - active_labels_shifted=active_labels_shifted, - active_valid_continue_logits=active_valid_continue_logits, - active_valid_mask=active_valid_mask, - **intra_loss_kwargs - ) - - if active_finished_mask.any(): - # Update actual iteration counts for tokens that finish - # Copy accumulated logits from active_updated_cumulative_logits to final_output_logits for finished tokens - self.assign_active_with_mask( - current_iter_mask, - assignment_mask=finished_mask, - src=active_updated_cumulative_logits, - dest=final_output_logits, - ) - - next_iter_mask = ( - (~finished_mask) & current_iter_mask & (valid_mask == 1) - ) # iter mask can accept invalid inputs; currently, it filter to accept only valid for efficiency concideration - if next_iter_mask.any(): - active_next_iter_mask = (~active_finished_mask) & (active_valid_mask == 1) - # Always pass all_hidden_states to input_updater; selection handled inside updater - active_input_embeds[active_next_iter_mask] = self.input_updater( - logits = active_updated_cumulative_logits[active_next_iter_mask], - prev_inputs = active_input_embeds[active_next_iter_mask], - embedding_weight = self.embed_tokens.weight, # should ignored by AlignDeviceHook and avoid weight moving - hidden_states = all_hidden[active_next_iter_mask] if all_hidden is not None else None, - ).to(device=device) - - # Clone to prevent in-place modification on a tensor that autograd still needs for previous iterations - input_embeds = torch.zeros_like(input_embeds) - self.assign_active_with_mask( - current_iter_mask, - assignment_mask=next_iter_mask, - src=active_input_embeds, - dest=input_embeds, - ) - - # Create mask for tokens that need processing at this iteration - current_iter_mask = next_iter_mask - - if not current_iter_mask.any(): - break - - # Compute loss if labels are provided - loss = None - if labels_shifted is not None: - # Prepare kwargs for loss function, including iter_count_labels if available - loss_kwargs = kwargs.copy() - # Finalize generator-produced full labels for logging/loss if requested - finalized_iter_labels = None - if use_iter_labeling: - finalized_iter_labels = self.iter_label_generator.finalize() - # expose for analysis/logging - # TODO: log final iter labels - # if hasattr(self, 'logger_callback') and finalized_iter_labels is not None: - # self.logger_callback.last_iter_count_labels = finalized_iter_labels.detach().to('cpu') - if finalized_iter_labels is not None: - loss_kwargs['iter_count_labels'] = finalized_iter_labels - elif iter_count_labels is not None: - loss_kwargs['iter_count_labels'] = iter_count_labels - # pass logger callback for metric logging if available - if hasattr(self, 'logger_callback'): - loss_kwargs['logger_callback'] = self.logger_callback - # provide model handle for potential freeze control in loss - loss_kwargs['model'] = self - # forward global_step if provided by caller - if 'global_step' in kwargs: - loss_kwargs['global_step'] = kwargs['global_step'] - - loss = loss_func.final_loss_func( - logits=final_output_logits, - labels_shifted=labels_shifted, - iter_count=actual_iter_counts, - training=self.training, - **loss_kwargs - ) - - if hasattr(self, "logger_callback"): - num_items_in_batch = kwargs.get("num_items_in_batch", None) - if num_items_in_batch is not None: - valid_iter_mask = (labels_shifted.detach() != -100) - valid_iter_counts = actual_iter_counts.detach()[valid_iter_mask] - avg_valid_iter_count = torch.sum(valid_iter_counts).float() - self.logger_callback.avg_iter_count += float((avg_valid_iter_count / num_items_in_batch).item()) - else: - self.logger_callback.avg_iter_count = float((torch.mean(actual_iter_counts.detach().float())).item()) - - - # Create custom output that includes actual iteration counts - output = TaHCausalLMOutputWithPast( - loss=loss, - logits=final_output_logits, - past_key_values=cache if use_cache else None, - hidden_states=None, - attentions=None, - iter_count=actual_iter_counts, - iter_count_labels=finalized_iter_labels if 'finalized_iter_labels' in locals() else None, - ) - - return output - - @staticmethod - def to_active( - current_iter_mask: torch.BoolTensor, - input_embeds: torch.Tensor, - cumulative_logits: torch.Tensor, - position_ids: torch.LongTensor, - valid_mask: torch.LongTensor, - iter_count: Optional[torch.LongTensor], - labels_shifted: Optional[torch.LongTensor] = None, - iter_count_labels: Optional[torch.LongTensor] = None, - labels_all_shifted: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, torch.LongTensor, torch.LongTensor, Union[torch.LongTensor, None], Union[torch.LongTensor, None], torch.BoolTensor, Union[torch.LongTensor, None]]: - """ - Return the active tokens (padded to the batch-wise max length). - - * active_input_embeds : (B, max_active_len, H) - * active_cumulative_logits : (B, max_active_len, V) - * active_position_ids : (B, max_active_len) - * active_valid_mask : (B, max_active_len) – propagates `valid_mask` - * active_iter_count : (B, max_active_len) – propagates `iter_count` - * active_labels_shifted : (B, max_active_len) – propagates `labels`, None if labels is None - * active_iter_count_labels : (B, max_active_len) – propagates `iter_count_labels`, None if iter_count_labels is None - """ - B, S, H = input_embeds.shape - _, _, V = cumulative_logits.shape - device = input_embeds.device - - active_per_seq = current_iter_mask.sum(1) # (B,) - max_len = int(active_per_seq.max()) # scalar - if max_len == 0: # nothing active - empty_e = input_embeds.new_empty(B, 0, H) - empty_i = position_ids.new_empty(B, 0) - empty_mask = torch.empty(B, 0, dtype=torch.bool, device=device) - return empty_e, empty_i, empty_i, None, None, None, None, None - - # ------------------------------------------------------------------ - # 1. Build gather_idx_clamped and pad_mask - # ------------------------------------------------------------------ - SENTINEL = S # out-of-range value - base_idx = torch.arange(S, device=device).expand(B, S) - base_idx = base_idx.masked_fill(~current_iter_mask, SENTINEL) # (B, S) - - # Stable sort → [active … | SENTINEL …] - sorted_idx, _ = torch.sort(base_idx, dim=1, stable=True) # (B, S) - gather_idx = sorted_idx[:, :max_len] # (B, max_len) - pad_mask = gather_idx.eq(SENTINEL) # True → padded - - # Same index, but clamped so `gather` is always in-range - gather_idx_clamped = gather_idx.clamp(max=S - 1) - - # ------------------------------------------------------------------ - # 2. Vectorised gather and zero-out - # ------------------------------------------------------------------ - active_input_embeds = torch.gather( - input_embeds, 1, gather_idx_clamped.unsqueeze(-1).expand(-1, -1, H) - ) # (B, max_len, H) - # Avoid in-place modification on tensors that autograd needs for the backward - active_input_embeds = active_input_embeds.masked_fill( - pad_mask.unsqueeze(-1), 0 - ) - - active_cumulative_logits = torch.gather( - cumulative_logits, 1, gather_idx_clamped.unsqueeze(-1).expand(-1, -1, V) - ) # (B, max_len, V) - active_cumulative_logits = active_cumulative_logits.masked_fill(pad_mask.unsqueeze(-1), 0) - - active_position_ids = torch.gather(position_ids, 1, gather_idx_clamped) - active_position_ids = active_position_ids.masked_fill(pad_mask, 0) - - active_valid_mask = torch.gather(valid_mask, 1, gather_idx_clamped) - active_valid_mask = active_valid_mask.masked_fill(pad_mask, 0) - - if iter_count is not None: - active_iter_count = torch.gather(iter_count, 1, gather_idx_clamped) - active_iter_count = active_iter_count.masked_fill(pad_mask, 0) - else: - active_iter_count = None - - if iter_count_labels is not None: - active_iter_count_labels = torch.gather(iter_count_labels, 1, gather_idx_clamped) - active_iter_count_labels = active_iter_count_labels.masked_fill(pad_mask, 0) - else: - active_iter_count_labels = None - - if labels_shifted is not None: - active_labels_shifted = torch.gather(labels_shifted, 1, gather_idx_clamped) - active_labels_shifted = active_labels_shifted.masked_fill(pad_mask, -100) - else: - active_labels_shifted = None - - if labels_all_shifted is not None: - active_labels_all_shifted = torch.gather(labels_all_shifted, 1, gather_idx_clamped) - active_labels_all_shifted = active_labels_all_shifted.masked_fill(pad_mask, -100) - else: - active_labels_all_shifted = None - - return active_input_embeds, active_cumulative_logits, active_position_ids, active_valid_mask, active_iter_count, active_labels_shifted, active_iter_count_labels, active_labels_all_shifted - - @staticmethod - def assign_active( - current_iter_mask: torch.BoolTensor, - src: torch.Tensor, - dest: torch.Tensor, - pad_value: float | int = 0, - ) -> torch.Tensor: - """ - Scatter `src` (the output of `extract_active`) back into a dense tensor. - - Args: - current_iter_mask : BoolTensor (B, S) - True where a position should be filled from `src`. - src : Tensor (B, max_active, ...) - Active tokens, padded on the right inside the second dimension. - dest : Tensor (B, S, ...) - Tensor to be updated **in-place**. - pad_value : scalar - Value written to inactive (False) positions. - - Returns: - dest : Tensor (B, S, ...) — same object that was passed in - """ - B, S = current_iter_mask.shape - max_active = src.shape[1] - - active_counts = current_iter_mask.sum(1) # (B,) - - for b in range(B): - n = active_counts[b].item() - if n: # only copy when there is something to copy - dest[b, current_iter_mask[b]] = src[b, :n] - - return dest - - @staticmethod - def assign_active_no_inplace( - current_iter_mask: torch.BoolTensor, - src: torch.Tensor, - dest: torch.Tensor, - pad_value: float | int = 0, - ) -> torch.Tensor: - """ - Scatter `src` (the output of `extract_active`) back into a dense tensor without in-place modification. - - Args: - current_iter_mask : BoolTensor (B, S) - True where a position should be filled from `src`. - src : Tensor (B, max_active, ...) - Active tokens, padded on the right inside the second dimension. - dest : Tensor (B, S, ...) - Tensor to be updated (will be cloned, not modified in-place). - pad_value : scalar - Value written to inactive (False) positions. - - Returns: - new_dest : Tensor (B, S, ...) — new tensor with updates applied - """ - B, S = current_iter_mask.shape - max_active = src.shape[1] - - # Clone dest to avoid in-place modification - new_dest = dest.clone() - active_counts = current_iter_mask.sum(1) # (B,) - - for b in range(B): - n = active_counts[b].item() - if n: # only copy when there is something to copy - new_dest[b, current_iter_mask[b]] = src[b, :n] - - return new_dest - - @staticmethod - def assign_active_with_mask( - current_iter_mask: torch.BoolTensor, - assignment_mask: torch.BoolTensor, - src: torch.Tensor, - dest: torch.Tensor, - pad_value: float | int = 0, - ) -> torch.Tensor: - """ - Scatter the masked `src` (the output of `extract_active`) back into a dense tensor. - - Args: - current_iter_mask : BoolTensor (B, S) - True where a position should be filled from `src`. - assignment_mask : BoolTensor (B, S) - True where a position should be filled from `src` to `dest`. - src : Tensor (B, max_active, ...) - Active tokens, padded on the right inside the second dimension. - dest : Tensor (B, S, ...) - Tensor to be updated **in-place**. - pad_value : scalar - Value written to inactive (False) positions. - - Returns: - dest : Tensor (B, S, ...) — same object that was passed in - """ - B, S = current_iter_mask.shape - - # Only assign where both masks are True - final_mask = current_iter_mask & assignment_mask - active_counts = current_iter_mask.sum(1) # (B,) - - for b in range(B): - n_active = active_counts[b].item() - if n_active == 0: - continue - - # Get active positions and assignment positions for this batch - active_pos = current_iter_mask[b].nonzero(as_tuple=False).flatten() - assign_pos = final_mask[b].nonzero(as_tuple=False).flatten() - - # Find which src indices correspond to assignment positions - src_indices = torch.searchsorted(active_pos, assign_pos) - valid_mask = src_indices < n_active - - if valid_mask.any(): - dest[b, assign_pos[valid_mask]] = src[b, src_indices[valid_mask]] - - return dest - - def _compute_positions_for_iteration( - self, active_position_ids: torch.Tensor, seq_length: int, cache_length: int - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Compute position_ids and cache_position for active tokens. - - Args: - active_position_ids: Original positions of active tokens (batch_size, num_active) - seq_length: Current sequence length (for new token position computation) - cache_length: Current total length of KV cache - - Returns: - position_ids: Adjusted positions for positional encoding (batch_size, num_active) - cache_position: Sequential positions in growing cache (num_active,) - """ - _, num_active = active_position_ids.shape - - # position_ids are the sequence positions (already correct from _extract_active_inputs) - position_ids = active_position_ids - - # cache_position: sequential positions starting from cache_length - cache_position = torch.arange( - cache_length, - cache_length + num_active, - device=active_position_ids.device, - dtype=torch.long, - ) - - return position_ids, cache_position - - def create_TaH_sdpa_attention_mask( - self, - active_position_ids: torch.Tensor, - active_valid_mask: torch.LongTensor, - cache: Optional[TaHCache], - iter_depth: int, - dtype: torch.dtype = torch.bfloat16, - ) -> Optional[torch.Tensor]: - """ - Create SDPA attention mask where query at position p, iteration i - The query can attend to cached KVs with position <= p AND iteration <= i. - Note that the mask should have the same shape as the updated cache, which only contains the KVs with iteration <= i. - The mask is added to the attention score, with min_dtype = torch.finfo(dtype).min being the masked part. - - Args: - active_position_ids: Original positions of active tokens (batch_size, query_length) - active_valid_mask: Mask indicating valid active tokens (batch_size, query_length) - cache: Current cache object - iter_depth: Current iteration depth - dtype: Data type for the attention mask - - Returns: - Attention mask of shape (batch_size, 1, query_length, filtered_cache_length + query_length) or None - """ - batch_size, query_length = active_position_ids.shape - device = active_position_ids.device - - # Get filtered cache positions (only iterations <= iter_depth) - if (cache is not None) and (0 in cache._tah_position_id_cache): - iter_index_cache = cache.get_cache_iter_index_upto_iter( - layer_idx=0, upto_iter_idx=iter_depth - ) - position_ids_cache_upto_iter = cache.get_position_id_upto_iter( - layer_idx=0, upto_iter_idx=iter_depth, init_batch_size=batch_size - ) - valid_mask_cache_upto_iter = cache.get_valid_mask_upto_iter( - layer_idx=0, upto_iter_idx=iter_depth, init_batch_size=batch_size - ) # TODO: implement - kv_cache_length_upto_iter = iter_index_cache.shape[-1] - else: - iter_index_cache = torch.empty(size=(0,), device=device, dtype=torch.long) - position_ids_cache_upto_iter = torch.empty( - size=(batch_size, 0), device=device, dtype=torch.long - ) - valid_mask_cache_upto_iter = torch.empty( - size=(batch_size, 0), device=device, dtype=torch.long - ) - kv_cache_length_upto_iter = 0 - - # KV length for attention computation equals to the KV length from cache plus the current key/value length (=query length) - kv_length_this_iter = kv_cache_length_upto_iter + query_length - - if kv_length_this_iter == 0: - return None - - min_dtype = torch.finfo(dtype).min - - # Build complete KV position list: filtered cache + new positions being added - kv_position_ids_upto_iter = torch.cat( - (position_ids_cache_upto_iter, active_position_ids), dim=-1 - ) # shape: (batch_size, query_length + kv_cache_length_upto_iter) - # Extract only valid positions based on active_valid_mask - kv_valid_mask_upto_iter = torch.cat( - (valid_mask_cache_upto_iter, active_valid_mask), dim=-1 - ) # shape: (batch_size, query_length + kv_cache_length_upto_iter) - kv_iter_index = torch.cat( - ( - iter_index_cache, - torch.full( - (query_length,), iter_depth, dtype=torch.long, device=device - ), - ), - dim=-1, - )[ - None, : - ] # shape: (1-batch-size, query_length + kv_cache_length_upto_iter) - - # Expand query positions and iterations for broadcasting - query_positions = active_position_ids[ - :, :, None - ] # (batch_size, query_length) -> (batch_size, query_length, 1) - kv_position_ids_upto_iter = kv_position_ids_upto_iter[ - :, None, : - ] # (batch_size, total_kv_length) -> (batch_size, 1, total_kv_length) - kv_valid_mask_upto_iter = kv_valid_mask_upto_iter[ - :, None, : - ] # (batch_size, total_kv_length) -> (batch_size, 1, total_kv_length) - - - if self.iter_attention_mode == "duo": - query_iter_index = torch.full_like( - query_positions, iter_depth - ) # (batch_size, query_length, 1) - elif self.iter_attention_mode == "root": - query_iter_index = torch.full_like( - query_positions, 0 - ) # (batch_size, query_length, 1) - elif self.iter_attention_mode == "same_iter": - query_iter_index = torch.full_like( - query_positions, iter_depth - ) # (batch_size, query_length, 1) - else: - raise ValueError(f"Invalid iter attention mode: {self.iter_attention_mode}") - - kv_iter_index = kv_iter_index[:, None, :] # (1, 1, total_kv_length) - - # Vectorized rule: query at (position=p, iter=i) can see cache entry at (position=cp, iter=ci) - # if and only if cp <= p AND ci <= i - position_mask = ( - kv_position_ids_upto_iter <= query_positions - ) # (batch_size, query_length, total_kv_length) - if self.iter_attention_mode == "same_iter": - iteration_mask = ( - kv_iter_index == query_iter_index - ) # (batch_size, query_length, total_kv_length) - else: - iteration_mask = ( - kv_iter_index <= query_iter_index - ) # (batch_size, query_length, total_kv_length) - valid_mask = (kv_valid_mask_upto_iter == 1) # (batch_size, 1, total_kv_length) - - # Combine both conditions - bool_attention_mask = ( - position_mask & iteration_mask & valid_mask - ) # (batch_size, query_length, total_kv_length) - - # Create attention mask - start with all masked (min_dtype), then unmask where can_attend is True - attention_mask = torch.full( - (batch_size, query_length, kv_length_this_iter), - min_dtype, - device=device, - dtype=dtype, - ) - attention_mask[bool_attention_mask] = 0.0 # unmasked - - return attention_mask[:, None, :, :] - - def _process_sparse_iteration( - self, - sparse_input: torch.Tensor, - position_ids: torch.Tensor, - valid_mask: torch.LongTensor, - cache_position: torch.Tensor, - attention_mask: torch.Tensor, - iter_depth: int, - past_key_values: Optional[TaHCache], - use_cache: bool, - output_attentions: bool, - output_hidden_states: bool, - model: Optional[PreTrainedModel] = None, - **kwargs, - ) -> CausalLMOutputWithPast: - """ - Process a single iteration through the base model with active/sparse inputs. - - Args: - sparse_input: Active input embeddings (batch_size, num_active, hidden_size) - position_ids: Active position ids (batch_size, num_active) - valid_mask: Long tensor mask indicating the padding scenario of the original input tokens. 0 means masked. - cache_position: Sequential positions in cache (num_active,) - attention_mask: SDPA attention mask (batch_size, num_active, total_kv_length) - iter_depth: Current iteration depth - past_key_values: Cache object - use_cache: Whether to use cache - output_attentions: Whether to output attention weights - output_hidden_states: Whether to output hidden states - **kwargs: Additional arguments - - Returns: - Model output for this iteration - """ - - # Set iteration depth and position metadata in cache - if past_key_values is not None: - past_key_values.current_iter_depth = iter_depth - past_key_values.position_ids_to_cache = position_ids - past_key_values.valid_mask_to_cache = valid_mask - - # Process through base model with active inputs - outputs = model( - inputs_embeds=sparse_input, - position_ids=position_ids, - cache_position=cache_position, # noqa: not used for now - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - **kwargs, - ) - - return outputs - - def save_pretrained(self, save_directory, **kwargs): - """ - Save the TaH model by directly saving the base model to avoid wrapper prefixes. - Also saves the TaHConfig for automatic loading. - - Args: - save_directory: Directory where to save the model - **kwargs: Additional arguments for saving - """ - # Save adapter and base model - save_adapter(self, save_directory, **kwargs) - # Save iter_decider - # If iter_decider is a noise wrapper, save its base (do not persist wrapper) - try: - from tah.model.iter_decider import NoisyWrapperIterDecider - iter_to_save = self.iter_decider.base_iter_decider if isinstance(self.iter_decider, NoisyWrapperIterDecider) else self.iter_decider - except Exception: - iter_to_save = self.iter_decider - save_iter_decider(iter_to_save, save_directory) - - - # Save TaH config with special handling for type objects - config_dict = asdict(self.tah_config) - serializable_config = type_to_dict_string(config_dict) - - config_path = os.path.join(save_directory, "tah_config.json") - with open(config_path, "w", encoding="utf-8") as f: - json.dump(serializable_config, f, indent=2, ensure_ascii=False) - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: str, - *args, - tah_config: Optional[TaHConfig] = None, - **kwargs, - ): - """ - Load a pretrained TaH model. - - Args: - pretrained_model_name_or_path: Path to the saved TaH model directory - tah_config: Optional TaHConfig to override specific saved config values - *args, **kwargs: Arguments for model loading - - Returns: - TaHForCausalLM instance - """ - # Move to device after initializations are all done - device_map = kwargs.pop("device_map", None) - - # Load saved config from checkpoint if it exists - config_path = os.path.join(pretrained_model_name_or_path, "tah_config.json") - saved_config = None - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - config_dict = json.load(f) - - # Convert serialized type objects back to actual types - config_dict = dict_string_to_type(config_dict) - - # Filter out keys that are not valid TaHConfig fields - valid_fields = {f.name for f in fields(TaHConfig)} - config_dict = {k: v for k, v in config_dict.items() if k in valid_fields} - - saved_config = TaHConfig(**config_dict) - logger.info(f"Loaded TaH config from {config_path}") - - # Determine final config by selectively overriding saved config with provided config - if tah_config is not None: - if saved_config is not None: - # Start with saved config and override specific fields from provided config - final_config_dict = asdict(saved_config) - provided_config_dict = asdict(tah_config) - - # Override only non-None values from provided config - for key, value in provided_config_dict.items(): - if (value is not None) and (value != {}): - final_config_dict[key] = value - logger.info( - f"Overriding config field '{key}' with provided value: {value}" - ) - - final_config = TaHConfig(**final_config_dict) - else: - # No saved config, use provided config - final_config = tah_config - logger.info("No saved config found, using provided tah_config") - else: - if saved_config is not None: - # Use saved config - final_config = saved_config - else: - # No saved config and no provided config, use default - logger.warning( - f"No tah_config.json found in {pretrained_model_name_or_path} and no tah_config provided. " - "Using default TaHConfig." - ) - final_config = TaHConfig() - - # Load base model - base_model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path, *args, **kwargs - ) - - # Load tah model - iter_decider_path = None - if "load_path" in final_config.iter_decider_kwargs: - iter_decider_path = final_config.iter_decider_kwargs.pop("load_path") - - # Create TaH model - tah_model = cls(base_model, config=final_config) - - # Reload adapter specific weights/models (delegated) - load_adapter(tah_model, pretrained_model_name_or_path, final_config, *args, **kwargs) - - # Decide whether to skip loading iter_decider weights based on class difference - skip_iter_decider_loading = False - load_base_iter_decider = False - if 'saved_config' in locals() and (saved_config is not None): - if getattr(saved_config, 'iter_decider', None) != final_config.iter_decider: - skip_iter_decider_loading = True - # If new base_iter_decider_cls equals the old model's iter_decider class, still load from old path - old_iter_decider_cls_name = getattr(saved_config, 'iter_decider', None) - final_kwargs = getattr(final_config, 'iter_decider_kwargs', None) - if isinstance(final_kwargs, dict): - final_base = final_kwargs.get('base_iter_decider_cls') - if (final_base is not None) and (final_base == old_iter_decider_cls_name): - skip_iter_decider_loading = False - load_base_iter_decider = True - - # Load iter_decider - if iter_decider_path is not None: - loaded_iter_decider = load_iter_decider(iter_decider_path, class_name=final_config.iter_decider, init_args=final_config.iter_decider_kwargs) - tah_model.iter_decider = loaded_iter_decider - logger.info("Loaded iter_decider from newly provided load_path") - else: - if skip_iter_decider_loading: - logger.info("Detected different iter_decider class; skipping old weight loading and using new iter_decider from final_config") - else: - - if load_base_iter_decider: - loaded_iter_decider = load_iter_decider( - pretrained_model_name_or_path, - class_name=final_config.iter_decider_kwargs.get('base_iter_decider_cls', None), - init_args=final_config.iter_decider_kwargs.get('base_iter_decider_kwargs', {}) - ) - tah_model.iter_decider.base_iter_decider = loaded_iter_decider - else: - loaded_iter_decider = load_iter_decider(pretrained_model_name_or_path, class_name=final_config.iter_decider, init_args=final_config.iter_decider_kwargs) - tah_model.iter_decider = loaded_iter_decider - logger.info("Loaded iter_decider from model checkpoint") - - # Load eval_iter_decider - eval_iter_decider = getattr(final_config, "eval_iter_decider", None) - if eval_iter_decider is not None: - resolved = None - if isinstance(eval_iter_decider, str): - # Support hierarchical path referencing the built training iter_decider - # Example: "iter_decider.primary_iter_decider.final_iter_decider" - if eval_iter_decider.startswith("iter_decider"): - path = eval_iter_decider.split(".") - obj = tah_model - for seg in path: - if not seg: - continue - if seg == "self": - obj = tah_model - else: - obj = getattr(obj, seg) - resolved = obj - # Class-name path - else: - eval_decider_cls = get_iter_decider_class(eval_iter_decider) - resolved = eval_decider_cls(**getattr(final_config, 'eval_iter_decider_kwargs', {})) - - tah_model.eval_iter_decider = resolved if resolved is not None else tah_model.iter_decider - else: - tah_model.eval_iter_decider = tah_model.iter_decider - - - # Move to device if device map is provided - if device_map is not None: - device_map = get_device_map(tah_model, device_map, tah_model.dtype) - dispatch_model_kwargs = { - "device_map": device_map, - "offload_dir": None, - "offload_index": None, - "offload_buffers": False, - "skip_keys": tah_model.simple_base_model._skip_keys_device_placement - } - tah_model = dispatch_model(tah_model, **dispatch_model_kwargs) - - return tah_model diff --git a/tah/model/registry.py b/tah/model/registry.py deleted file mode 100644 index f8123b6..0000000 --- a/tah/model/registry.py +++ /dev/null @@ -1,134 +0,0 @@ -import inspect -from typing import Dict, Type, Callable, Optional, Tuple, TypeVar, TYPE_CHECKING - -if TYPE_CHECKING: - from tah.model.iter_decider import IterDecider - from tah.model.input_updater import InputUpdater - from tah.model.output_updater import OutputUpdater - from tah.model.iter_label import IterLabelGenerator - from tah.model.loss import LossFunc - -T = TypeVar('T') - -def create_registry(registry_name: str, case_insensitive: bool = False) -> Tuple[Dict[str, Type[T]], Callable, Callable[[str], Type[T]]]: - """ - Create a registry system with register and get functions. - - Args: - registry_name: Name for error messages - case_insensitive: Whether to store lowercase versions of names - - Returns: - Tuple of (registry_dict, register_function, get_function) - """ - registry: Dict[str, Type[T]] = {} - - def register(cls_or_name=None, name: Optional[str] = None): - """Register a class in the registry. Supports multiple usage patterns.""" - def _register(c: Type[T]) -> Type[T]: - # Determine the name to use - if isinstance(cls_or_name, str): - # Called as @register("name") - class_name = cls_or_name - elif name is not None: - # Called as @register(name="name") or register(cls, name="name") - class_name = name - else: - # Use class name - class_name = c.__name__ - - # Store in registry - registry[class_name] = c - if case_insensitive: - registry[class_name.lower()] = c - return c - - if cls_or_name is not None and not isinstance(cls_or_name, str): - # Called as @register or register(cls) - return _register(cls_or_name) - else: - # Called as @register("name") or @register(name="name") - return _register - - def get_class(name: str) -> Type[T]: - """Get class by name from registry.""" - key = name if name in registry else (name.lower() if case_insensitive else name) - if key not in registry: - available = list(registry.keys()) - raise ValueError(f"Unknown {registry_name} class: {name}. Available: {available}") - return registry[name] - - return registry, register, get_class - - -# Create model registry (case insensitive for backward compatibility) -ITER_DECIDER_REGISTRY, register_iter_decider, get_iter_decider_class = create_registry("iter_decider", case_insensitive=True) - -# Create updater registry -INPUT_UPDATER_REGISTRY, register_input_updater, get_input_updater_class = create_registry("input_updater", case_insensitive=True) - -# Create loss func registry -LOSS_FUNC_REGISTRY, register_loss_func, get_loss_func_class = create_registry("loss_func", case_insensitive=True) - -# Create output updater registry -OUTPUT_UPDATER_REGISTRY, register_output_updater, get_output_updater_class = create_registry("output_updater", case_insensitive=True) - -# Create iter label generator registry -ITER_LABEL_GENERATOR_REGISTRY, register_iter_label_generator, get_iter_label_generator_class = create_registry("iter_label_generator", case_insensitive=True) - - -# Add specific type annotations for the get functions -if TYPE_CHECKING: - def get_iter_decider_class(name: str) -> Type["IterDecider"]: ... - def get_input_updater_class(name: str) -> Type["InputUpdater"]: ... - def get_output_updater_class(name: str) -> Type["OutputUpdater"]: ... - def get_loss_func_class(name: str) -> Type["LossFunc"]: ... - def get_iter_label_generator_class(name: str) -> Type["IterLabelGenerator"]: ... - - -def capture_init_args(cls): - """ - Decorator to capture initialization arguments of a model class. - - Args: - cls: The class to decorate - - Returns: - The decorated class with automatic init args capture - """ - original_init = cls.__init__ - - def new_init(self, *args, **kwargs): - # Store all initialization arguments - self._init_args = {} - - # Get parameter names from the original __init__ method - sig = inspect.signature(original_init) - param_names = list(sig.parameters.keys())[1:] # Skip 'self' - - # Map positional args to parameter names - for i, arg in enumerate(args): - if i < len(param_names): - self._init_args[param_names[i]] = arg - - # Add keyword args - self._init_args.update(kwargs) - - # Call the original __init__ - original_init(self, *args, **kwargs) - - cls.__init__ = new_init - return cls - - -def mark_wrapper_iter_decider(cls): - """Decorator to mark an IterDecider as a wrapper over another decider. - - This flag allows builder logic (e.g., InterleavedIterDecider) to detect that the - target class expects a base-decider instance or class and to perform special wiring. - """ - try: - setattr(cls, "_is_wrapper_iter_decider", True) - except Exception: - pass - return cls \ No newline at end of file diff --git a/tah/model/tah_config.py b/tah/model/tah_config.py index 13c7ac5..6515f69 100644 --- a/tah/model/tah_config.py +++ b/tah/model/tah_config.py @@ -1,29 +1,51 @@ +"""Persistent configuration for ``TaHForCausalLM``. + +Public TaH carried a config field per pluggable component (input_updater, +output_updater, iter_label_generator, iter_attention_mode). In tah-release +those slots have a single implementation each, inlined into the wrapper, so +their config fields are no longer load-bearing and are dropped here. + +Old checkpoints whose ``tah_config.json`` still contains those keys load fine: +``TaHForCausalLM.from_pretrained`` filters the JSON to fields known by this +dataclass before instantiating it. Conversely, dropping the fields means new +saves don't carry inert names that suggest configurability where there is +none. +""" +from __future__ import annotations + from dataclasses import dataclass, field -from typing import Any, Dict +from typing import Any, Dict, Optional + @dataclass class TaHConfig: - """Configuration for TaH model components.""" - # Overidable configs + # Wrapper-level settings ------------------------------------------------ embedding_key: str = "model.embed_tokens" - max_iter: int = None - iter_decider: str = None - input_updater: str = None - output_updater: str = None - train_loss: str = None - eval_loss: str = None - # Optional: use a different iter_decider for evaluation/inference - eval_iter_decider: str = None - adapter: str = None - iter_label_generator: str = None - iter_attention_mode: str = "duo" # Attention visibility mode: "duo", "root", or "same_iter" - - # Non-overidable configs + max_iter: Optional[int] = None + + # Iter decider — one of {"IterLabelDecider", "MLPIterDecider"}. + iter_decider: Optional[str] = None iter_decider_kwargs: Dict[str, Any] = field(default_factory=dict) + + # Optional alias used at eval/inference time. Either a class name (built + # afresh) or an attribute path like "iter_decider" (alias of the trained + # decider). When None, the trained decider is reused. + eval_iter_decider: Optional[str] = None + eval_iter_decider_kwargs: Dict[str, Any] = field(default_factory=dict) + + # Input updater is fixed to top-k softmax over logits + embedding-row mix. + # Only ``topk`` from this dict is read by the wrapper; the rest is kept as + # a dict for forwards-compat with old saved configs. input_updater_kwargs: Dict[str, Any] = field(default_factory=dict) - output_updater_kwargs: Dict[str, Any] = field(default_factory=dict) + + # Adapter is fixed to LoRA. ``adapter_kwargs`` are forwarded to + # ``peft.LoraConfig`` after popping the TaH-specific + # ``base_grad`` / ``adapter_grad`` knobs. + adapter: str = "lora" + adapter_kwargs: Dict[str, Any] = field(default_factory=dict) + + # Losses — one of {"NextTokenPredLoss", "IterDeciderLoss"}. + train_loss: Optional[str] = None train_loss_kwargs: Dict[str, Any] = field(default_factory=dict) + eval_loss: Optional[str] = None eval_loss_kwargs: Dict[str, Any] = field(default_factory=dict) - eval_iter_decider_kwargs: Dict[str, Any] = field(default_factory=dict) - adapter_kwargs: Dict[str, Any] = field(default_factory=dict) - iter_label_generator_kwargs: Dict[str, Any] = field(default_factory=dict) diff --git a/tah/model/tah_model.py b/tah/model/tah_model.py new file mode 100644 index 0000000..d4a87cd --- /dev/null +++ b/tah/model/tah_model.py @@ -0,0 +1,875 @@ +"""TaH: Selective Latent Iterations for Reasoning Language Models. + +Wraps a Hugging Face causal LM (Qwen3, etc.) so that a learned subset of +tokens runs additional internal forward passes ("iterations") to refine the +prediction. The decision to iterate is per-token, made by: + +* :class:`tah.model.iter_decider.IterLabelDecider` — step-1 SFT: continue iff + the dataset's oracle ``iter_count_labels`` say so. +* :class:`tah.model.iter_decider.MLPIterDecider` — step-2 SFT + eval/serving: + a small classifier over hidden states + top-k logits. + +Each iteration writes its KV into a separate slot of the per-layer +:class:`tah.model.causal_cache.TaHCache` so that future iterations can see +prior ones (causally) without disturbing the iter-0 cache. + +Single-implementation interfaces from the public TaH layout are inlined here: + +* ``topk_softmax_input_update`` — was ``input_updater.TrivialUpdater``. +* ``additive_logits_update`` — was ``output_updater.AdditiveLogitsUpdater``. +* The dense max-merge of per-iteration ``iter_count_labels`` + — was ``iter_label.FixedIterLabelGenerator``. +* LoRA setup / per-iteration enable + — was ``adapter.setup_adapter`` / ``configure_lora_for_iteration``. +""" +from __future__ import annotations + +import json +import os +from dataclasses import asdict, dataclass, fields +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from peft import LoraConfig, get_peft_model +from transformers import AutoModelForCausalLM, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import logging + +from tah.model.causal_cache import TaHCache +from tah.model.iter_decider import ( + ITER_DECIDER_BY_NAME, + IterDecider, + load_iter_decider, + save_iter_decider, +) +from tah.model.loss import LOSS_BY_NAME, LossFunc +from tah.model.tah_config import TaHConfig + +logger = logging.get_logger(__name__) + + +# ──────────────────────────────────────────────────────────────────────────── +# Inlined helpers (single-implementation slots from public TaH). +# ──────────────────────────────────────────────────────────────────────────── + + +def topk_softmax_input_update( + logits: torch.Tensor, + embedding_weight: torch.Tensor, + topk: int, +) -> torch.Tensor: + """Map per-token logits to a soft input embedding. + + Top-k softmax over ``logits`` (last dim is vocab), look up the + corresponding embedding rows, return the probability-weighted sum. Last + dim of ``logits`` is consumed; all leading dims are preserved. + """ + k = min(topk, logits.size(-1)) + topk_values, topk_indices = torch.topk(logits, k=k, dim=-1) + topk_probs = torch.softmax(topk_values, dim=-1) + topk_embeds = embedding_weight[topk_indices] # (..., k, H) + return torch.sum(topk_probs.unsqueeze(-1) * topk_embeds, dim=-2) + + +def additive_logits_update( + logits: torch.Tensor, + prev_logits: Optional[torch.Tensor], +) -> torch.Tensor: + """Residual accumulation of output logits across iterations.""" + return logits if prev_logits is None else prev_logits + logits + + +# ──────────────────────────────────────────────────────────────────────────── +# Active-token sparse <-> dense scatter helpers. +# +# Each iteration of the wrapper processes only those positions whose +# `current_iter_mask` is True. We pack them into a dense (B, max_active, …) +# block (right-padded), run the base model on that block, and scatter results +# back into the dense (B, T, …) buffers. +# ──────────────────────────────────────────────────────────────────────────── + + +def gather_active( + current_iter_mask: torch.BoolTensor, + *tensors, + pad_value=0, +) -> Tuple[torch.Tensor, ...]: + """Vectorised stable-sort gather: pack active positions to the left. + + For each batch row, the True positions of ``current_iter_mask`` are + gathered in order from each input tensor and right-padded to the + per-batch maximum active length. Inputs may be ``None`` (passed + through). Floating and long tensors get filled with ``pad_value`` at + pad positions; the caller may overwrite (e.g. ``-100`` for labels). + Returns ``(pad_mask, *gathered)`` where ``pad_mask`` is a + ``(B, max_active)`` bool that's True at padding positions. + """ + B, S = current_iter_mask.shape + device = current_iter_mask.device + SENTINEL = S + base_idx = torch.arange(S, device=device).expand(B, S).masked_fill(~current_iter_mask, SENTINEL) + sorted_idx, _ = torch.sort(base_idx, dim=1, stable=True) + max_len = int(current_iter_mask.sum(1).max().item()) if current_iter_mask.any() else 0 + gather_idx = sorted_idx[:, :max_len] + pad_mask = gather_idx.eq(SENTINEL) + gather_idx_clamped = gather_idx.clamp(max=max(S - 1, 0)) + + out = [pad_mask] + for t in tensors: + if t is None: + out.append(None) + continue + if t.dim() == 2: # (B, S) + g = torch.gather(t, 1, gather_idx_clamped).masked_fill(pad_mask, pad_value) + else: # (B, S, ...) + extra = t.shape[2:] + g_idx = gather_idx_clamped.view(B, max_len, *([1] * len(extra))).expand(B, max_len, *extra) + g = torch.gather(t, 1, g_idx).masked_fill( + pad_mask.view(B, max_len, *([1] * len(extra))), pad_value, + ) + out.append(g) + return tuple(out) + + +def scatter_back( + current_iter_mask: torch.BoolTensor, + src: torch.Tensor, + dest: torch.Tensor, + *, + in_place: bool = False, + assignment_mask: Optional[torch.BoolTensor] = None, +) -> torch.Tensor: + """Scatter a per-batch right-padded ``src`` back into a dense ``dest``. + + ``src`` is shape ``(B, max_active, ...)`` with row ``b`` valid up to + ``current_iter_mask[b].sum()``; the rest is padding. The valid prefix of + each row is placed at the True positions of ``current_iter_mask[b]`` in + ``dest`` (which has shape ``(B, S, ...)``). + + When ``assignment_mask`` is supplied, only positions where both + ``current_iter_mask`` AND ``assignment_mask`` are True are updated; the + rest of ``dest`` is preserved. + + Vectorised — no per-batch Python loop. + """ + B, max_active = src.shape[:2] + active_counts = current_iter_mask.sum(1) + # Pad positions in ``src``: column k of row b is padding iff k >= active_counts[b]. + pad_mask = torch.arange(max_active, device=src.device).unsqueeze(0) >= active_counts.unsqueeze(1) + valid_src = src[~pad_mask] # (sum(current_iter_mask), ...) — same row-major order as the mask + + out = dest if in_place else dest.clone() + if assignment_mask is None: + out[current_iter_mask] = valid_src + return out + + # Two-stage: first write all active values into a dense intermediate, then + # copy only the positions enabled by the assignment_mask. + intermediate = torch.zeros_like(out) + intermediate[current_iter_mask] = valid_src + final_mask = current_iter_mask & assignment_mask + out[final_mask] = intermediate[final_mask] + return out + + +# ──────────────────────────────────────────────────────────────────────────── +# Output dataclass and config helpers. +# ──────────────────────────────────────────────────────────────────────────── + + +@dataclass +class TaHCausalLMOutputWithPast(CausalLMOutputWithPast): + """Adds per-token ``iter_count`` (and optional generated label tensor).""" + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + iter_count: Optional[torch.LongTensor] = None + iter_count_labels: Optional[torch.LongTensor] = None + + +# ──────────────────────────────────────────────────────────────────────────── +# Config (de)serialisation. ``torch.dtype`` and bare ``type`` objects sneak +# into iter_decider_kwargs (e.g. dtype=torch.bfloat16); these helpers let us +# round-trip them through json without losing identity. +# ──────────────────────────────────────────────────────────────────────────── + +_DTYPE_BY_STR = { + "torch.float32": torch.float32, + "torch.float16": torch.float16, + "torch.bfloat16": torch.bfloat16, +} + + +def _config_to_serialisable(obj): + """Walk a config dict; encode ``torch.dtype`` and ``type`` as small dicts.""" + if isinstance(obj, dict): + return {k: _config_to_serialisable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_config_to_serialisable(v) for v in obj] + if isinstance(obj, type): + return {"__type__": True, "__module__": obj.__module__, "__name__": obj.__name__} + if isinstance(obj, torch.dtype): + return {"__dtype__": True, "__str__": str(obj)} + return obj + + +def _config_from_serialisable(obj): + """Inverse of :func:`_config_to_serialisable`.""" + import importlib + if isinstance(obj, dict): + if obj.get("__type__") is True: + return getattr(importlib.import_module(obj["__module__"]), obj["__name__"]) + if obj.get("__dtype__") is True: + return _DTYPE_BY_STR.get(obj["__str__"], torch.float32) + return {k: _config_from_serialisable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_config_from_serialisable(v) for v in obj] + return obj + + +def _resolve_attr_path(root, dotted: str): + obj = root + for seg in dotted.split("."): + if not seg or seg == "self": + continue + obj = getattr(obj, seg) + return obj + + +def _with_max_iter(kwargs: Optional[dict], max_iter: int) -> dict: + """Return ``kwargs`` with ``max_iter`` filled in (existing values win).""" + out = dict(kwargs or {}) + out.setdefault("max_iter", max_iter) + return out + + +def _build_loss(name: str, kwargs: dict, max_iter: int) -> LossFunc: + """Instantiate a loss class by name. Both shipped classes + (NextTokenPredLoss, IterDeciderLoss) accept ``max_iter`` via kwargs.""" + return LOSS_BY_NAME[name](**_with_max_iter(kwargs, max_iter)) + + +def _build_iter_decider(name: str, kwargs: dict, max_iter: int) -> IterDecider: + """Instantiate an iter_decider class by name with ``max_iter`` filled in.""" + return ITER_DECIDER_BY_NAME[name](**_with_max_iter(kwargs, max_iter)) + + +def _resolve_tah_config(checkpoint_dir: str, override: Optional[TaHConfig]) -> TaHConfig: + """Pick the TaHConfig to use when loading a checkpoint. + + Order of precedence: + 1. ``checkpoint_dir/tah_config.json`` (filtered to known fields, + and with serialised ``type`` / ``torch.dtype`` sentinels restored). + 2. Non-default fields of ``override`` overlay the saved config when both + are present (fields with value ``None`` or ``{}`` are treated as "use + saved"). + 3. ``TaHConfig()`` if neither is available — emits a warning since the + resulting model won't match any checkpoint shape. + """ + cfg_path = os.path.join(checkpoint_dir, "tah_config.json") + saved: Optional[TaHConfig] = None + if os.path.exists(cfg_path): + with open(cfg_path, encoding="utf-8") as f: + raw = _config_from_serialisable(json.load(f)) + valid = {f.name for f in fields(TaHConfig)} + saved = TaHConfig(**{k: v for k, v in raw.items() if k in valid}) + + if override is not None and saved is not None: + merged = asdict(saved) + for k, v in asdict(override).items(): + if v is not None and v != {}: + merged[k] = v + return TaHConfig(**merged) + if override is not None: + return override + if saved is not None: + return saved + logger.warning("No tah_config.json in %s and no tah_config given; using defaults", checkpoint_dir) + return TaHConfig() + + +# ──────────────────────────────────────────────────────────────────────────── +# The TaH wrapper. +# ──────────────────────────────────────────────────────────────────────────── + + +class TaHForCausalLM(PreTrainedModel): + """Selective-iteration wrapper around a HF causal LM. + + During each forward pass, every token starts at ``iter_depth = 0`` (the + base model's regular forward). For tokens where ``iter_decider`` votes + "continue", the wrapper builds a soft input embedding from the current + logits (top-k softmax mix) and re-runs the base model with LoRA enabled. + Up to ``max_iter`` such rounds are performed; finished tokens accumulate + their final logits via additive residuals, in-flight tokens carry on. + + Notes: + * The base model is wrapped with PEFT LoRA in ``__init__``. The adapter + is enabled only at ``iter_depth >= 1``; iter-0 runs the base weights. + * Per-iteration KV caches are kept in a single :class:`TaHCache` so that + attention masks can be built that allow each iteration to see all + prior iterations of the same positions plus the iter-0 prefix. + * ``iter_attention_mode`` is fixed to "duo" (the only mode used in the + canonical recipes); other modes from public TaH have been removed. + """ + + def __init__(self, base_model: PreTrainedModel, config: Optional[TaHConfig] = None): + # SDPA is the only attention impl we exercise in the recurrent loop. + base_model._supports_sdpa = True + super().__init__(base_model.config) + self.config = base_model.config + self.supports_gradient_checkpointing = True + + if config is None: + config = TaHConfig() + self.tah_config = config + + # Validate embedding key early. + try: + _resolve_attr_path(base_model, config.embedding_key) + except AttributeError as e: + raise ValueError(f"embedding_key {config.embedding_key!r} not found in base model") from e + self.embedding_key = config.embedding_key + + self.max_iter = config.max_iter + self.input_topk = int(config.input_updater_kwargs.get("topk", 100)) + + # Iter decider — kept as an nn.Module subclass since two impls are used. + self.iter_decider = _build_iter_decider( + config.iter_decider, config.iter_decider_kwargs, self.max_iter + ) + + # Optional eval-time decider override (always a path like "iter_decider" + # in the canonical recipes; we resolve once during from_pretrained too). + self.eval_iter_decider = self._resolve_eval_iter_decider(config) + + # Loss objects — one for train, one for eval. Step-1 uses the same + # NextTokenPredLoss for both; step-2 uses IterDeciderLoss for train, + # NextTokenPredLoss for eval. + self.train_loss = _build_loss(config.train_loss, config.train_loss_kwargs, self.max_iter) + self.eval_loss = ( + _build_loss(config.eval_loss, config.eval_loss_kwargs, self.max_iter) + if config.eval_loss else self.train_loss + ) + + # Base model attaches AFTER PEFT wrap. + self.simple_base_model = base_model + self._setup_lora(config) + + # LoRA enabled-state cache, so we don't toggle every step needlessly. + self._lora_enabled: Optional[bool] = None + + # ── lora ────────────────────────────────────────────────────────────── + + def _setup_lora(self, config: TaHConfig) -> None: + """Wrap the base model with PEFT LoRA. Adapter is gated per-iteration.""" + if config.adapter != "lora": + raise ValueError(f"adapter must be 'lora', got {config.adapter!r}") + # base_grad / adapter_grad are TaH-specific knobs the upstream PEFT + # LoraConfig doesn't accept; copy and pop before forwarding. + peft_kwargs = dict(config.adapter_kwargs) + base_grad = peft_kwargs.pop("base_grad", True) + adapter_grad = peft_kwargs.pop("adapter_grad", True) + self.simple_base_model = get_peft_model(self.simple_base_model, LoraConfig(**peft_kwargs)) + self._set_lora_grad_flags(base_grad, adapter_grad) + + def _set_lora_grad_flags(self, base_grad: bool, adapter_grad: bool) -> None: + """Enable/disable gradients on lora-* params vs everything else. + + No-op when both flags default to True (the common case at training + time, where HF Trainer manages requires_grad per parameter group). + """ + if base_grad is True and adapter_grad is True: + return + for name, p in self.simple_base_model.base_model.named_parameters(): + p.requires_grad = adapter_grad if "lora" in name.lower() else base_grad + + def _set_lora_enabled(self, enabled: bool) -> None: + if self._lora_enabled is enabled: + return + if enabled: + self.simple_base_model.base_model.enable_adapter_layers() + else: + self.simple_base_model.base_model.disable_adapter_layers() + self._lora_enabled = enabled + + def _resolve_eval_iter_decider(self, config: TaHConfig) -> IterDecider: + name = config.eval_iter_decider + if not name: + return self.iter_decider + if isinstance(name, str) and name.startswith("iter_decider"): + return _resolve_attr_path(self, name) + return _build_iter_decider(name, config.eval_iter_decider_kwargs, self.max_iter) + + # ── handles & device ────────────────────────────────────────────────── + + @property + def device(self) -> torch.device: + return self.simple_base_model.device + + @property + def embed_tokens(self): + # PEFT places the original model under .base_model.model. + return _resolve_attr_path(self.simple_base_model.base_model.model, self.embedding_key) + + @property + def _active_iter_decider(self) -> IterDecider: + """Decider used at the current train/eval mode.""" + if self.training: + return self.iter_decider + return self.eval_iter_decider or self.iter_decider + + # ── small forward-pass helpers ──────────────────────────────────────── + + @staticmethod + def _stack_hidden_states(outputs, device: torch.device) -> Optional[torch.Tensor]: + """Convert HF ``output_hidden_states`` (tuple of (B, T, H)) into ``(B, T, L, H)``.""" + hs = getattr(outputs, "hidden_states", None) + if not hs: + return None + layer_stack = torch.stack([h.to(device=device) for h in hs], dim=0) # (L, B, T, H) + return layer_stack.permute(1, 2, 0, 3) # (B, T, L, H) + + def _force_one_continuation( + self, + decision: torch.Tensor, + active_labels_shifted: Optional[torch.Tensor], + active_valid_mask: torch.Tensor, + iter_depth: int, + ) -> torch.Tensor: + """If labels exist and the decider says "stop everywhere", force one + labeled position to continue so the iter-decider keeps getting + gradient. No-op when there's no label or no remaining iter budget. + """ + if ( + active_labels_shifted is None + or decision is None + or decision.numel() == 0 + or iter_depth >= self.max_iter + ): + return decision + label_mask = (active_labels_shifted != -100)[active_valid_mask == 1] + if label_mask.any() and not decision[label_mask].any(): + candidates = torch.nonzero(label_mask, as_tuple=False).flatten() + chosen = candidates[torch.randint(0, candidates.numel(), (1,), device=decision.device)] + decision[chosen] = True + return decision + + @staticmethod + def _max_merge_iter_labels( + full: torch.Tensor, + active: torch.Tensor, + current_iter_mask: torch.BoolTensor, + ) -> torch.Tensor: + """Update the dense ``full`` iter-labels view with ``active`` proposals. + + ``-100`` positions in ``active`` are treated as 0 (no proposal); the + merge takes the per-position max so each token's label monotonically + accumulates across iterations. + """ + proposal = torch.zeros_like(active) + valid = active != -100 + proposal[valid] = active[valid] + tmp = torch.zeros_like(full) + scatter_back(current_iter_mask, src=proposal, dest=tmp, in_place=True) + return torch.maximum(full, tmp) + + @staticmethod + def _forward_kwargs(kwargs: dict) -> dict: + """Filter caller-supplied forward kwargs down to those the loss / callback + plumbing actually reads.""" + return {k: v for k, v in kwargs.items() if k in ("global_step", "num_items_in_batch")} + + # ── forward ─────────────────────────────────────────────────────────── + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[TaHCache] = None, + labels: Optional[torch.LongTensor] = None, + iter_count_labels: Optional[torch.LongTensor] = None, + use_cache: bool = False, + new_sequence: bool = False, # forwarded to oracle deciders if any + **kwargs, + ) -> TaHCausalLMOutputWithPast: + """One TaH forward over a batch. + + Args: + input_ids: ``(B, T)`` + attention_mask: ``(B, T_total)``; if longer than T, the last T + entries are taken as the per-query mask. + position_ids: optional ``(B, T)``; if absent, computed from the + cumulative sum of ``attention_mask`` plus the cache prefix. + labels: ``(B, T)`` for cross-entropy; ``-100`` is ignored. + iter_count_labels: ``(B, T)`` oracle labels for the iter decider. + use_cache: whether to return ``past_key_values`` for use in + subsequent forwards (decoding). + new_sequence: signals oracle-style deciders that the batch is fresh. + + Returns: + :class:`TaHCausalLMOutputWithPast`. + """ + # Public TaH disabled these unconditionally; we keep the contract. + assert not kwargs.get("output_attentions"), "TaH does not support output_attentions" + assert not kwargs.get("output_hidden_states"), "TaH does not support output_hidden_states" + + # Causal-LM shifted labels (next-token prediction). + if labels is not None: + labels_shifted = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() + labels_all_shifted = F.pad(input_ids.clone(), (0, 1), value=-100)[..., 1:].contiguous() + else: + labels_shifted = None + labels_all_shifted = None + + B, T = input_ids.shape + V = self.config.vocab_size + device = input_ids.device + + input_embeds = self.embed_tokens(input_ids) # (B, T, H) + dtype = input_embeds.dtype + + cumulative_logits = torch.zeros(B, T, V, device=device, dtype=dtype) + final_output_logits = torch.zeros(B, T, V, device=device, dtype=dtype) + actual_iter_counts = torch.zeros(B, T, dtype=torch.long, device=device) + + cache = past_key_values if past_key_values is not None else TaHCache().to(device=device, dtype=dtype) + + if attention_mask is not None: + valid_mask = attention_mask[:, -T:].clone().to(dtype=torch.long) + assert valid_mask.shape == (B, T), f"attention_mask shape {attention_mask.shape} bad for T={T}" + else: + valid_mask = torch.ones_like(input_ids, dtype=torch.long) + + if position_ids is None: + cache_iter0_valid = cache.get_valid_mask_upto_iter(layer_idx=0, upto_iter_idx=0, init_batch_size=B).to(device) + position_ids = ( + torch.cumsum(torch.cat((cache_iter0_valid, valid_mask), dim=-1), dim=-1)[:, -T:] - 1 + ).clamp(min=0) + else: + position_ids = position_ids.clone() + + loss_func = self.train_loss if self.training else self.eval_loss + loss_func.prepare_loss(B, T, device, dtype) + + # FixedIterLabelGenerator behaviour, inlined: the caller-supplied + # `iter_count_labels` are the per-token oracle counts; we maintain a + # dense max-merge view across iterations for analysis/loss. + track_iter_labels = (iter_count_labels is not None) and (labels_shifted is not None) + if track_iter_labels: + full_iter_labels = torch.zeros(B, T, dtype=torch.long, device=device) + else: + full_iter_labels = None + + # Per-iteration loop ------------------------------------------------ + current_iter_mask = torch.ones_like(input_ids, dtype=torch.bool) + finished_mask = torch.zeros_like(current_iter_mask, dtype=torch.bool) + iter_depth = 0 + + while iter_depth < self.max_iter and current_iter_mask.any(): + self._set_lora_enabled(iter_depth >= 1) + + # ── Phase 1: pack active positions + run the base model once. + ( + pad_mask, + active_input_embeds, active_cumulative_logits, active_position_ids, + active_valid_mask, active_labels_shifted, active_iter_count_labels, + active_labels_all_shifted, + ) = gather_active( + current_iter_mask, + input_embeds, cumulative_logits, position_ids, + valid_mask, labels_shifted, iter_count_labels, labels_all_shifted, + ) + # Label-typed tensors need the loss ignore index at pad positions + # rather than the default 0. + if active_labels_shifted is not None: + active_labels_shifted = active_labels_shifted.masked_fill(pad_mask, -100) + if active_labels_all_shifted is not None: + active_labels_all_shifted = active_labels_all_shifted.masked_fill(pad_mask, -100) + + if active_valid_mask.shape[1] == 0: + break + + sdpa_attn_mask = self._build_attention_mask( + active_position_ids, active_valid_mask, cache, iter_depth, dtype=dtype + ) + cache.current_iter_depth = iter_depth + cache.position_ids_to_cache = active_position_ids + cache.valid_mask_to_cache = active_valid_mask + outputs = self.simple_base_model( + inputs_embeds=active_input_embeds, + position_ids=active_position_ids, + attention_mask=sdpa_attn_mask, + past_key_values=cache, + use_cache=True if iter_depth < self.max_iter - 1 else use_cache, + output_hidden_states=True, + ) + + iter_depth += 1 + + # ── Phase 2: residual-accumulate logits + scatter back to dense. + updated_active_logits = additive_logits_update( + outputs.logits.to(device=device), active_cumulative_logits, + ) + cumulative_logits = scatter_back( + current_iter_mask, src=updated_active_logits, dest=cumulative_logits, + ) + + # ── Phase 3: ask the iter decider whether to keep iterating. + all_hidden = self._stack_hidden_states(outputs, device) + decider = self._active_iter_decider + valid_active = active_valid_mask == 1 + decision, continue_logits = decider( + logits=updated_active_logits[valid_active], + iter_depth=iter_depth, + all_hidden_states=all_hidden[valid_active] if all_hidden is not None else None, + labels_shifted=active_labels_all_shifted[valid_active] if active_labels_all_shifted is not None else None, + iter_count_labels=active_iter_count_labels[valid_active] if active_iter_count_labels is not None else None, + ) + decision = self._force_one_continuation(decision, active_labels_shifted, active_valid_mask, iter_depth) + if continue_logits is not None: + continue_logits = continue_logits.to(device=device) + + # ── Phase 4: convert the per-active decision into per-token finished + # / next-iter masks; bump iter counts for everyone we just processed. + active_finished_mask = torch.ones_like(active_valid_mask, dtype=torch.bool) + active_finished_mask[valid_active] = ~decision + scatter_back(current_iter_mask, src=active_finished_mask, dest=finished_mask, in_place=True) + actual_iter_counts[current_iter_mask] += 1 + + # ── Phase 5: optional intra-iter loss accumulation (IterDeciderLoss). + if labels_shifted is not None and loss_func._is_intra_iter_loss: + loss_func.intra_iter_loss_func( + active_logits=updated_active_logits, + current_iter_mask=current_iter_mask, + active_labels_shifted=active_labels_shifted, + active_valid_continue_logits=continue_logits, + active_valid_mask=active_valid_mask, + iter_depth=iter_depth, + active_iter_count_labels=active_iter_count_labels, + iter_decider_threshold=decider.threshold, + model=self, + **self._forward_kwargs(kwargs), + ) + + # ── Phase 6: write tokens that just finalised into the output buffer. + if active_finished_mask.any(): + scatter_back( + current_iter_mask, src=updated_active_logits, dest=final_output_logits, + in_place=True, assignment_mask=finished_mask, + ) + + # ── Phase 7: max-merge iter-count labels into the dense view. + if track_iter_labels and active_iter_count_labels is not None: + full_iter_labels = self._max_merge_iter_labels( + full_iter_labels, active_iter_count_labels, current_iter_mask, + ) + + # ── Phase 8: prepare next-iter inputs by feeding logits → embeddings. + next_iter_mask = (~finished_mask) & current_iter_mask & (valid_mask == 1) + if next_iter_mask.any(): + active_next = (~active_finished_mask) & valid_active + active_input_embeds[active_next] = topk_softmax_input_update( + logits=updated_active_logits[active_next], + embedding_weight=self.embed_tokens.weight, + topk=self.input_topk, + ).to(device=device) + input_embeds = torch.zeros_like(input_embeds) + scatter_back( + current_iter_mask, src=active_input_embeds, dest=input_embeds, + in_place=True, assignment_mask=next_iter_mask, + ) + + current_iter_mask = next_iter_mask + if not current_iter_mask.any(): + break + + # Final cross-entropy / iter-decider loss. + loss = None + if labels_shifted is not None: + loss_kwargs = self._forward_kwargs(kwargs) + loss_kwargs["iter_count_labels"] = full_iter_labels if full_iter_labels is not None else iter_count_labels + loss_kwargs["model"] = self + if hasattr(self, "logger_callback"): + loss_kwargs["logger_callback"] = self.logger_callback + loss = loss_func.final_loss_func( + logits=final_output_logits, + labels_shifted=labels_shifted, + iter_count=actual_iter_counts, + training=self.training, + **loss_kwargs, + ) + # Optional avg-iter-count logging. + if hasattr(self, "logger_callback"): + num_items = kwargs.get("num_items_in_batch") + with torch.no_grad(): + if num_items is not None: + valid_iter = labels_shifted.detach() != -100 + iter_sum = actual_iter_counts.detach()[valid_iter].float().sum() + self.logger_callback.avg_iter_count += float((iter_sum / num_items).item()) + else: + self.logger_callback.avg_iter_count = float(actual_iter_counts.detach().float().mean().item()) + + return TaHCausalLMOutputWithPast( + loss=loss, + logits=final_output_logits, + past_key_values=cache if use_cache else None, + iter_count=actual_iter_counts, + iter_count_labels=full_iter_labels, + ) + + # ── attention mask ──────────────────────────────────────────────────── + + def _build_attention_mask( + self, + active_position_ids: torch.Tensor, + active_valid_mask: torch.LongTensor, + cache: TaHCache, + iter_depth: int, + dtype: torch.dtype, + ) -> Optional[torch.Tensor]: + """Build the SDPA attention mask for one iteration (``"duo"`` mode). + + At depth ``i``, position ``p`` may attend to KV slot ``(cp, ci)`` + iff ``cp <= p`` AND ``ci <= i`` AND that KV slot is valid. Returns a + ``(B, 1, query_len, total_kv)`` additive mask (0 = unmasked, + ``min_dtype`` = masked) or ``None`` if there is nothing to attend to. + """ + B, T = active_position_ids.shape + device = active_position_ids.device + + if cache is not None and cache.has_layer(layer_idx=0): + iter_index_cache = cache.get_cache_iter_index_upto_iter(layer_idx=0, upto_iter_idx=iter_depth) + pos_cache = cache.get_position_id_upto_iter(layer_idx=0, upto_iter_idx=iter_depth, init_batch_size=B) + valid_cache = cache.get_valid_mask_upto_iter(layer_idx=0, upto_iter_idx=iter_depth, init_batch_size=B) + kv_cache_len = iter_index_cache.shape[-1] + else: + iter_index_cache = torch.empty((0,), device=device, dtype=torch.long) + pos_cache = torch.empty((B, 0), device=device, dtype=torch.long) + valid_cache = torch.empty((B, 0), device=device, dtype=torch.long) + kv_cache_len = 0 + + kv_total = kv_cache_len + T + if kv_total == 0: + return None + min_dtype = torch.finfo(dtype).min + + kv_pos = torch.cat((pos_cache, active_position_ids), dim=-1)[:, None, :] # (B, 1, kv_total) + kv_valid = torch.cat((valid_cache, active_valid_mask), dim=-1)[:, None, :] # (B, 1, kv_total) + kv_iter = torch.cat( + (iter_index_cache, torch.full((T,), iter_depth, dtype=torch.long, device=device)), + dim=-1, + )[None, None, :] # (1, 1, kv_total) + + q_pos = active_position_ids[:, :, None] # (B, T, 1) + q_iter = torch.full_like(q_pos, iter_depth) # (B, T, 1) + + attn_bool = (kv_pos <= q_pos) & (kv_iter <= q_iter) & (kv_valid == 1) + attn = torch.full((B, T, kv_total), min_dtype, device=device, dtype=dtype) + attn[attn_bool] = 0.0 + return attn[:, None, :, :] + + # ── persistence ─────────────────────────────────────────────────────── + + def save_pretrained(self, save_directory: str, **kwargs) -> None: + """Save base model + LoRA + iter_decider + ``tah_config.json``.""" + os.makedirs(save_directory, exist_ok=True) + + # LoRA adapter directory + lora_dir = os.path.join(save_directory, "lora") + os.makedirs(lora_dir, exist_ok=True) + self.simple_base_model.save_pretrained(lora_dir, **kwargs) + + # Base model with cleaned keys (strip PEFT's `.base_layer` and `lora_*`). + base_model = self.simple_base_model.base_model.model + original_state_dict = base_model.state_dict + + def cleaned_state_dict(): + sd = original_state_dict() + return { + k.replace(".base_layer", ""): v + for k, v in sd.items() if "lora" not in k.lower() + } + + base_model.state_dict = cleaned_state_dict + try: + base_model.save_pretrained(save_directory, **kwargs) + finally: + base_model.state_dict = original_state_dict + + save_iter_decider(self.iter_decider, save_directory) + + # iter_decider_kwargs may contain torch.dtype / type objects (e.g. + # ``dtype=torch.bfloat16``); stringify them before json.dump. + with open(os.path.join(save_directory, "tah_config.json"), "w", encoding="utf-8") as f: + json.dump(_config_to_serialisable(asdict(self.tah_config)), f, indent=2, ensure_ascii=False) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + *args, + tah_config: Optional[TaHConfig] = None, + **kwargs, + ) -> "TaHForCausalLM": + """Load a saved TaH model. + + Accepts either a local directory or a Hugging Face Hub repo id; in + the latter case the snapshot is fetched (or its cached copy is + located) so we can read ``tah_config.json`` and ``iter_decider.bin`` + from a real path on disk. + + Resolution order for the final TaH config: + 1. ``tah_config.json`` from the checkpoint, if present. + 2. Fields of the ``tah_config`` argument that are non-None override + those values (lets callers tweak inference-time knobs). + 3. ``TaHConfig()`` defaults if nothing else is available. + """ + if not os.path.isdir(pretrained_model_name_or_path): + from huggingface_hub import snapshot_download + pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path) + + final_cfg = _resolve_tah_config(pretrained_model_name_or_path, tah_config) + + # Pop a one-shot iter_decider override (the recipe form is + # `iter_decider_kwargs.load_path: …`). + iter_decider_path = None + if "load_path" in (final_cfg.iter_decider_kwargs or {}): + iter_decider_path = final_cfg.iter_decider_kwargs.pop("load_path") + + base_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + model = cls(base_model, config=final_cfg) + + # Re-attach LoRA weights. + adapter_path = os.path.join(pretrained_model_name_or_path, "lora") + if os.path.isdir(adapter_path): + model.simple_base_model.load_adapter(adapter_path, adapter_name="default") + model._set_lora_grad_flags( + base_grad=final_cfg.adapter_kwargs.get("base_grad", True), + adapter_grad=final_cfg.adapter_kwargs.get("adapter_grad", True), + ) + + # Load iter_decider weights. ``load_iter_decider`` always returns a + # CPU module; move it to the base model's device so the iter loop + # doesn't trip on a device mismatch. + if iter_decider_path is not None: + model.iter_decider = load_iter_decider( + iter_decider_path, + class_name=final_cfg.iter_decider, + init_args=final_cfg.iter_decider_kwargs, + ) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, "iter_decider.bin")): + model.iter_decider = load_iter_decider( + pretrained_model_name_or_path, + class_name=final_cfg.iter_decider, + init_args=final_cfg.iter_decider_kwargs, + ) + model.iter_decider = model.iter_decider.to(device=model.device, dtype=model.dtype) + model.eval_iter_decider = model._resolve_eval_iter_decider(final_cfg) + if model.eval_iter_decider is not model.iter_decider: + model.eval_iter_decider = model.eval_iter_decider.to(device=model.device, dtype=model.dtype) + return model diff --git a/tah/model/tracker.py b/tah/model/tracker.py deleted file mode 100644 index 2a5d95b..0000000 --- a/tah/model/tracker.py +++ /dev/null @@ -1,189 +0,0 @@ -import torch -from typing import Any, Dict, List -import pandas as pd - -class TaHTracker: - """Utility to track TaH model internal states. - - Currently it records top-k logits for each call of - :func:`TaHForCausalLM._process_sparse_iteration` which corresponds to one - iteration in the recurrent loop. The tracker can be attached to an - ``TaHForCausalLM`` instance without modifying its code. - """ - - def __init__(self, top_k: int = 5) -> None: - self.top_k = top_k - self.records: List[Dict[str, Any]] = [] - self._orig_fn = None - self._model = None - self._call_idx = 0 - # Keep original output_updater forward and a pending context queue - self._orig_updater_fn = None - self._updater = None - self._pending_contexts: List[Dict[str, Any]] = [] - - def attach(self, model: Any) -> None: - """Attach tracker to ``model``. - - Parameters - ---------- - model : TaHForCausalLM - Model to track. ``model._process_sparse_iteration`` will be wrapped - so that logits from every iteration are logged. - """ - if self._model is not None: - raise RuntimeError("Tracker already attached to a model") - - self._model = model - self._orig_fn = model._process_sparse_iteration - - def wrapper(*args, **kwargs): - outputs = self._orig_fn(*args, **kwargs) - # Cache minimal context for the subsequent output_updater call - iter_depth = kwargs.get("iter_depth") - if iter_depth is None and len(args) > 5: - iter_depth = args[5] - - valid_mask = kwargs.get("valid_mask") - if valid_mask is None and len(args) > 2: - valid_mask = args[2] - - cache = kwargs.get("past_key_values") - if cache is None and len(args) > 6: - cache = args[6] - - # Queue context to be consumed by output_updater wrapper - self._pending_contexts.append({ - "iter_depth": iter_depth, - "valid_mask": valid_mask, - "cache": cache, - }) - return outputs - - model._process_sparse_iteration = wrapper - - # Also wrap the output_updater to record active_updated_cumulative_logits - if hasattr(model, "output_updater") and model.output_updater is not None: - self._updater = model.output_updater - self._orig_updater_fn = model.output_updater.forward - - def updater_wrapper(*u_args, **u_kwargs): - updated_logits = self._orig_updater_fn(*u_args, **u_kwargs) - - # Obtain the most recent pending context if available - context = None - if self._pending_contexts: - context = self._pending_contexts.pop(0) - - if context is not None and updated_logits is not None: - k = min(self.top_k, updated_logits.size(-1)) - last_token_logits = updated_logits[:, -1, :] - values, indices = torch.topk(last_token_logits, k=k, dim=-1) - perplexity, entropy = self.logits_to_perplexity_entropy(last_token_logits) - - iter_depth = context.get("iter_depth") - valid_mask = context.get("valid_mask") - cache = context.get("cache") - - for batch_idx in range(updated_logits.size(0)): - # Skip if last position is padding/inactive - if valid_mask is not None and valid_mask[batch_idx, -1] == 0: - continue - - record = { - "batch_idx": batch_idx, - "call_index": self._call_idx, - "iter_depth": iter_depth, - "step_index": (cache.get_seq_length() if cache is not None else None), - "perplexity": perplexity[batch_idx].item(), - "entropy": entropy[batch_idx].item(), - "topk_values": values.detach().cpu()[batch_idx, :].tolist(), - "topk_indices": indices.detach().cpu()[batch_idx, :].tolist(), - } - self.records.append(record) - self._call_idx += 1 - - return updated_logits - - model.output_updater.forward = updater_wrapper - - def detach(self) -> None: - """Remove hooks and restore the original model method.""" - if self._model is not None and self._orig_fn is not None: - self._model._process_sparse_iteration = self._orig_fn - if self._updater is not None and self._orig_updater_fn is not None: - self._updater.forward = self._orig_updater_fn - self._model = None - self._orig_fn = None - self._updater = None - self._orig_updater_fn = None - - def clear(self) -> None: - """Clear all recorded states.""" - self.records.clear() - self._call_idx = 0 - self._pending_contexts.clear() - - @staticmethod - def logits_to_perplexity_entropy(logits: torch.Tensor) -> torch.Tensor: - """Compute the perplexity of a logits tensor. - - Perplexity is calculated as :math:`\exp(H(p))` where - :math:`H(p)` is the entropy of the probability distribution obtained - by applying softmax to the logits. - - The returned value is the exponential of the entropy over - all tokens in ``logits``, keeping the same ... dimensions as logits except for the vocab dimension. - - Parameters - ---------- - logits : torch.Tensor - Tensor of shape ``(..., vocab_size)`` containing pre-softmax - activations. - - Returns - ------- - torch.Tensor - The computed perplexity, shape ``...`` (same as logits without vocab dimension). - torch.Tensor - The computed entropy, shape ``...`` (same as logits without vocab dimension). - """ - # Convert logits to probabilities - probs = torch.softmax(logits, dim=-1) - # Compute log-probabilities in a numerically stable way - log_probs = torch.log_softmax(logits, dim=-1) - # Entropy per token: -sum(p * log p) - entropy = -(probs * log_probs).sum(dim=-1) - # Perplexity = exp(entropy) - return torch.exp(entropy).detach().cpu(), entropy.detach().cpu() - - def to_pandas(self, selected_keys: List[str] | None = None): - """Convert tracked records to a ``pandas.DataFrame``. - - Parameters - ---------- - selected_keys : List[str] | None - Optional list of keys to keep from each record. If omitted all - keys are included. - - Returns - ------- - pandas.DataFrame - DataFrame containing the requested fields for every tracked call. - """ - if not self.records: - return pd.DataFrame() - - # Determine which keys to include - ALL_KEYS = list(self.records[0].keys()) - - if selected_keys is None: - keys = ALL_KEYS - else: - keys = list(selected_keys) - - # Build data limited to desired keys, falling back to None when key missing - data = [{k: rec.get(k) for k in keys} for rec in self.records] - - return pd.DataFrame(data)[keys] - \ No newline at end of file diff --git a/tah/model/utils.py b/tah/model/utils.py index a5d5182..c03a8e7 100644 --- a/tah/model/utils.py +++ b/tah/model/utils.py @@ -1,235 +1,203 @@ +"""Helpers used across ``tah/`` — config (de)serialisation, generation, debug +colouring, and a couple of small attribute / param utilities for trainers. + +Public TaH bundled a much larger ``utils.py`` (device-map juggling, an unused +class-string-to-type helper, an ``get_attr_recursive`` whose only caller was +``recurrent_transformer.py``); those have been removed. """ -Utility functions for TaH model components. -""" -import importlib -from typing import Optional, List, Type, TYPE_CHECKING, Union +from __future__ import annotations + +import os +import random +from typing import Optional, TYPE_CHECKING + +import numpy as np import torch import torch.nn.functional as F -from transformers import AutoTokenizer, PreTrainedModel import transformers -import random -import numpy as np -import os - -from accelerate import infer_auto_device_map -from accelerate.utils import get_balanced_memory - +from transformers import AutoTokenizer if TYPE_CHECKING: - from tah.model.recurrent_transformer import TaHForCausalLM + from tah.model.tah_model import TaHForCausalLM -def get_attr_by_path(root_obj, attr_path: str): - """Resolve dotted attribute path on an object; returns None if any hop is missing.""" - current_obj = root_obj - for name in attr_path.split('.'): - if not hasattr(current_obj, name): - return None - current_obj = getattr(current_obj, name) - return current_obj +# ──────────────────────────────────────────────────────────────────────────── +# Determinism / param introspection (used by SFT_TaH.py). +# ──────────────────────────────────────────────────────────────────────────── -def freeze_components(model, component_paths, accelerator): - """Freeze parameters of components specified by dotted paths (e.g., 'model.cascade_model').""" - if not component_paths: - return - for raw_path in component_paths: - # Accept paths starting with 'model.' or direct module names under model - path = raw_path - if path.startswith('model.'): - path = path[len('model.'):] - target = get_attr_by_path(model, path) - if target is None: - accelerator.print(f"Warning: freeze_component '{raw_path}' not found on model.") - continue - - params = list(target.parameters()) if hasattr(target, 'parameters') else [] - if not params: - accelerator.print(f"Warning: freeze_component '{raw_path}' has no parameters to freeze.") - continue - - for p in params: - p.requires_grad = False - num_params = sum(p.numel() for p in params) - accelerator.print(f"Froze component '{raw_path}' ({num_params:,} params).") - -def compute_trainable_param_size_gb(model) -> float: - total_bytes = 0 - for p in model.parameters(): - if p.requires_grad: - total_bytes += p.numel() * p.element_size() - return total_bytes / (1024 ** 3) - - -def set_all_seeds(seed=42): - """Set all random seeds for reproducibility""" - # Python built-in random +def set_all_seeds(seed: int = 42) -> None: + """Set Python, NumPy, PyTorch (CPU+CUDA) and HF Transformers seeds.""" random.seed(seed) - # NumPy np.random.seed(seed) - # PyTorch torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) # for multi-GPU - # Transformers + torch.cuda.manual_seed_all(seed) transformers.set_seed(seed) - # Environment variables - os.environ['PYTHONHASHSEED'] = str(seed) - + os.environ["PYTHONHASHSEED"] = str(seed) print(f"All random seeds set to {seed}") -def get_attr_recursive(obj, attr_path): - """ - Recursively get attribute from object using dot notation. +def _get_attr_by_path(root_obj, attr_path: str): + """Resolve dotted attribute path; return None if any hop is missing.""" + obj = root_obj + for name in attr_path.split("."): + if not hasattr(obj, name): + return None + obj = getattr(obj, name) + return obj - Args: - obj: Object to get attribute from - attr_path: Dot-separated attribute path (e.g., "model.embed_tokens") - Returns: - The requested attribute +def freeze_components(model, component_paths, accelerator) -> None: + """Freeze parameters of components named by dotted paths under the model. - Raises: - AttributeError: If attribute doesn't exist + Path segments may optionally start with ``model.`` (stripped). Missing + components are reported via ``accelerator.print`` rather than raising, + matching the public-TaH behaviour. """ - attrs = attr_path.split(".") - for attr in attrs: - obj = getattr(obj, attr) - return obj + if not component_paths: + return + for raw_path in component_paths: + path = raw_path[len("model."):] if raw_path.startswith("model.") else raw_path + target = _get_attr_by_path(model, path) + if target is None: + accelerator.print(f"Warning: freeze_component {raw_path!r} not found on model.") + continue + params = list(target.parameters()) if hasattr(target, "parameters") else [] + if not params: + accelerator.print(f"Warning: freeze_component {raw_path!r} has no parameters to freeze.") + continue + for p in params: + p.requires_grad = False + accelerator.print(f"Froze component {raw_path!r} ({sum(p.numel() for p in params):,} params).") -def class_string_to_type(cls_str: str) -> Type: - """ - Convert a string of class to a class - """ - module_name, class_name = cls_str.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, class_name) -def type_to_dict_string(obj): - """ - Convert type objects to serializable strings - """ +def compute_trainable_param_size_gb(model) -> float: + return sum(p.numel() * p.element_size() for p in model.parameters() if p.requires_grad) / (1024 ** 3) - if isinstance(obj, dict): - return {k: type_to_dict_string(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [type_to_dict_string(item) for item in obj] - elif isinstance(obj, type): - return { - "__type__": True, - "__module__": obj.__module__, - "__name__": obj.__name__ - } - elif isinstance(obj, torch.dtype): - return { - "__dtype__": True, - "__str__": str(obj) - } - else: - return obj - -def dict_string_to_type(obj): - """ - Convert serialized strings to type objects - """ - if isinstance(obj, dict): - if obj.get("__type__") is True: - # This is a serialized type object - module = importlib.import_module(obj["__module__"]) - return getattr(module, obj["__name__"]) - elif obj.get("__dtype__") is True: - # This is a serialized torch.dtype object - dtype_str = obj["__str__"] - # Map string representations back to torch.dtype objects - dtype_map = { - "torch.float32": torch.float32, - "torch.float16": torch.float16, - "torch.bfloat16": torch.bfloat16 - } - return dtype_map.get(dtype_str, torch.float32) # Default to float32 if not found - else: - return {k: dict_string_to_type(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [dict_string_to_type(item) for item in obj] - else: - return obj +# ──────────────────────────────────────────────────────────────────────────── +# Sampling. Used by ``TaHForCasualLM_generate`` below; kept separate so other +# generation code paths can pull just the sampler. +# ──────────────────────────────────────────────────────────────────────────── def sample_next_token( - logits, temperature=1.0, top_p=1.0, top_k=0, min_p=0.0, do_sample=True -): - """ - Sample next token from logits with various sampling strategies. - - Args: - logits: Tensor of shape (batch_size, vocab_size) - logits for next token prediction - temperature: Float > 0. Controls randomness. Lower = more deterministic. Default: 1.0 - top_p: Float between 0 and 1. Nucleus sampling - keep tokens with cumulative probability <= top_p. Default: 1.0 - top_k: Int >= 0. Keep only top k tokens. 0 means no filtering. Default: 0 - min_p: Float between 0 and 1. Remove tokens with probability < min_p * max_probability. Default: 0.0 - do_sample: Bool. If False, use greedy sampling (argmax). Default: True - - Returns: - token_ids: Sampled token IDs as a tensor of shape (batch_size,) + logits: torch.Tensor, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + min_p: float = 0.0, + do_sample: bool = True, +) -> torch.Tensor: + """Sample one token per row from ``(B, V)`` logits with the usual knobs. + + ``do_sample=False`` ⇒ argmax. Otherwise: temperature → softmax → optional + min_p / top_k / top_p filtering → multinomial. """ - # Handle greedy sampling cases if not do_sample: return torch.argmax(logits, dim=-1) - # Apply temperature if temperature != 1.0: logits = logits / temperature - - # Convert to probabilities probs = F.softmax(logits, dim=-1) - # Apply min_p filtering if min_p > 0.0: - max_probs = torch.max(probs, dim=-1, keepdim=True)[0] # (batch_size, 1) - min_prob_threshold = min_p * max_probs - probs = torch.where(probs >= min_prob_threshold, probs, torch.zeros_like(probs)) - # Renormalize - probs = probs / torch.sum(probs, dim=-1, keepdim=True) + max_probs = probs.max(dim=-1, keepdim=True).values + probs = torch.where(probs >= min_p * max_probs, probs, torch.zeros_like(probs)) + probs = probs / probs.sum(dim=-1, keepdim=True) - # Apply top_k filtering if top_k > 0: - top_k = min(top_k, probs.size(-1)) # Safety check - top_k_probs, _ = torch.topk(probs, top_k, dim=-1) # (batch_size, top_k) - threshold = top_k_probs[..., -1:] # (batch_size, 1) - the k-th largest value - indices_to_remove = probs < threshold - probs = torch.where(indices_to_remove, torch.zeros_like(probs), probs) - # Renormalize - probs = probs / torch.sum(probs, dim=-1, keepdim=True) - - # Apply top_p (nucleus) filtering + k = min(top_k, probs.size(-1)) + topk_probs, _ = torch.topk(probs, k, dim=-1) + probs = torch.where(probs < topk_probs[..., -1:], torch.zeros_like(probs), probs) + probs = probs / probs.sum(dim=-1, keepdim=True) + if top_p < 1.0: sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1) - cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + cumulative = torch.cumsum(sorted_probs, dim=-1) + sorted_remove = cumulative > top_p + sorted_remove[..., 1:] = sorted_remove[..., :-1].clone() + sorted_remove[..., 0] = 0 + remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_remove) + probs = torch.where(remove, torch.zeros_like(probs), probs) + probs = probs / probs.sum(dim=-1, keepdim=True) + + return torch.multinomial(probs, num_samples=1).squeeze(-1) - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - # Create a mask for indices to remove - indices_to_remove = torch.zeros_like(probs, dtype=torch.bool) - indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove) - probs = torch.where(indices_to_remove, torch.zeros_like(probs), probs) - # Renormalize - probs = probs / torch.sum(probs, dim=-1, keepdim=True) +# ──────────────────────────────────────────────────────────────────────────── +# Generation loop + token coloring (used by the playground demo). +# ──────────────────────────────────────────────────────────────────────────── - # Sample from the filtered distribution - return torch.multinomial(probs, num_samples=1).squeeze(-1) # (batch_size,) + +class IterCountColors: + """ANSI-colour the per-token iteration counts during generation.""" + + _COLORS = { + 1: "\033[0m", # reset (white) + 2: "\033[92m", # green + 3: "\033[94m", # blue + 4: "\033[91m", # red + 5: "\033[95m", # magenta + 6: "\033[93m", # yellow + } + + @classmethod + def get_color(cls, n: int) -> str: + return cls._COLORS.get(n, "\033[96m") # cyan for >6 + + @classmethod + def print_token(cls, token_text: str, n: int) -> None: + print(f"{cls.get_color(n)}{token_text}\033[0m", end="", flush=True) + + @classmethod + def get_legend(cls) -> str: + return "Color legend: " + ", ".join( + f"{cls.get_color(n)}={n} iter\033[0m" for n in (1, 2, 3, 4, 5, 6, 7) + ) + + +def _forward_and_print( + tah_model: "TaHForCausalLM", + tokenizer: AutoTokenizer, + model_inputs: dict, + cache, + *, + new_sequence: bool, + verbose: bool, + **kwargs, +): + """One forward pass; if ``verbose``, print each non-padded token coloured + by its iteration count.""" + forward_kwargs = { + "input_ids": model_inputs["input_ids"], + "past_key_values": cache, + "use_cache": True, + "new_sequence": new_sequence, + **{k: v for k, v in model_inputs.items() if k != "input_ids" and v is not None}, + **kwargs, + } + outputs = tah_model(**forward_kwargs) + + if verbose and outputs.iter_count is not None: + attention_mask = model_inputs.get("attention_mask") + counts = outputs.iter_count[0] + if attention_mask is not None: + valid = attention_mask[0, -counts.shape[0]:] == 1 + tokens = [tokenizer.decode([t]) for t in model_inputs["input_ids"][0][valid]] + counts = counts[valid] + else: + tokens = [tokenizer.decode([t]) for t in model_inputs["input_ids"][0]] + for tok, c in zip(tokens, counts): + IterCountColors.print_token(tok, int(c.item())) + return outputs def TaHForCasualLM_generate( tah_model: "TaHForCausalLM", tokenizer: AutoTokenizer, model_inputs: dict, - iter_count: Optional[torch.Tensor] = None, + *, max_new_tokens: int = 1024, do_sample: bool = True, temperature: float = 1.0, @@ -238,316 +206,59 @@ def TaHForCasualLM_generate( min_p: float = 0.0, verbose: bool = True, **kwargs, -) -> tuple[list[list[int]], list[str]]: - """ - Generation function for TaH model with sampling support for batched inputs. - - Args: - tah_model: TaHForCausalLM model instance - tokenizer: tokenizer instance - model_inputs: dict containing 'input_ids', 'attention_mask', and other model inputs - iter_count: torch.Tensor of shape (batch_size, seq_len) or None - max_new_tokens: maximum number of new tokens to generate - do_sample: whether to use sampling or greedy decoding - temperature: sampling temperature (> 0.0) - top_p: nucleus sampling probability threshold - top_k: top-k sampling parameter (0 = disabled) - min_p: minimum probability threshold relative to the most likely token - verbose: whether to print debug output during generation - **kwargs: additional keyword arguments to pass to the model - Returns: - generated_tokens: list of lists, each containing generated token IDs for each batch item - generated_texts: list of decoded texts for each batch item +): + """Greedy / sampling generation for ``TaHForCausalLM`` with batched inputs. + + Returns ``(output_tokens, generated_texts)`` — the former is per-batch + list-of-token-ids, the latter their decoded strings. """ device = model_inputs["input_ids"].device - batch_size = model_inputs["input_ids"].shape[0] + B = model_inputs["input_ids"].shape[0] tah_model.eval() - # Initialize generation state cache = None - output_tokens = [[] for _ in range(batch_size)] - finished = torch.zeros(batch_size, dtype=torch.bool, device=device) - - # Keep track of current attention mask for extension - current_attention_mask = model_inputs.get("attention_mask", None) + output_tokens: list[list[int]] = [[] for _ in range(B)] + finished = torch.zeros(B, dtype=torch.bool, device=device) + current_attn = model_inputs.get("attention_mask") if verbose: print("Input tokens with iteration counts:") - with torch.no_grad(): - # Phase 1: Prefill - process initial input sequence - outputs = _forward_and_display( - tah_model, - tokenizer, - model_inputs, - iter_count, - cache, - is_prefill=True, - verbose=verbose, - new_sequence=True, - **kwargs, + outputs = _forward_and_print( + tah_model, tokenizer, model_inputs, cache, + new_sequence=True, verbose=verbose, **kwargs, ) cache = outputs.past_key_values - if verbose: print("\n\nGenerating new tokens:") - # Phase 2: Decoding - generate new tokens one by one - for step in range(max_new_tokens): - # Sample next token from current outputs for all batch items - last_token_logits = outputs.logits[:, -1, :] # (batch_size, vocab_size) - next_token_ids = sample_next_token( - logits=last_token_logits, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - do_sample=do_sample, - ) # (batch_size,) - - # Check for EOS token and update finished status + for _ in range(max_new_tokens): + last_logits = outputs.logits[:, -1, :] + next_ids = sample_next_token( + last_logits, temperature=temperature, top_p=top_p, + top_k=top_k, min_p=min_p, do_sample=do_sample, + ) if tokenizer.eos_token_id is not None: - eos_mask = next_token_ids == tokenizer.eos_token_id - finished = finished | eos_mask - - # Add tokens to output for non-finished sequences - for batch_idx in range(batch_size): - if not finished[batch_idx]: - output_tokens[batch_idx].append(next_token_ids[batch_idx].item()) - - # Check if all sequences are finished + finished = finished | (next_ids == tokenizer.eos_token_id) + for i in range(B): + if not finished[i]: + output_tokens[i].append(int(next_ids[i].item())) if finished.all(): break - # Prepare inputs for next token - next_input_ids = next_token_ids.unsqueeze(1) - next_model_inputs = {"input_ids": next_input_ids} - - # Extend attention mask - if current_attention_mask is not None: - new_token_mask = torch.ones( - batch_size, 1, dtype=current_attention_mask.dtype, device=device + next_inputs = {"input_ids": next_ids.unsqueeze(1)} + if current_attn is not None: + current_attn = torch.cat( + [current_attn, torch.ones(B, 1, dtype=current_attn.dtype, device=device)], dim=1, ) - current_attention_mask = torch.cat( - [current_attention_mask, new_token_mask], dim=1 - ) - next_model_inputs["attention_mask"] = current_attention_mask - - # Forward pass for next token - outputs = _forward_and_display( - tah_model, - tokenizer, - next_model_inputs, - iter_count=None, # Use automatic iteration from iter_decider - cache=cache, - is_prefill=False, - verbose=verbose, - new_sequence=False, - **kwargs, - ) - cache = outputs.past_key_values - - if verbose: - print("\033[0m") # Reset color - - # Decode generated texts - generated_texts = [ - tokenizer.decode(tokens) if tokens else "" - for tokens in output_tokens - ] + next_inputs["attention_mask"] = current_attn - return output_tokens, generated_texts - -def get_device_map(model: "TaHForCausalLM", device_map: Union[str, torch.device, int], dtype: torch.dtype): - """ - Get the device map for the model. Input device map should choose from: - - a string: "auto", "balanced", "balanced_low_0", "sequential", or a device name like "cpu", "cuda:0" - - a torch.device object - - an int (device index) - - a dict mapping module names to devices - This function normalizes the device_map to a dict, or infers it if using auto-mapping. - """ - # change device_map into a map if we passed an int, a str or a torch.device - if isinstance(device_map, torch.device): - device_map = {"": device_map} - elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: - try: - device_map = {"": torch.device(device_map)} - except RuntimeError: - raise ValueError( - "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " - f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." - ) - elif isinstance(device_map, int): - if device_map < 0: - raise ValueError( - "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + outputs = _forward_and_print( + tah_model, tokenizer, next_inputs, cache, + new_sequence=False, verbose=verbose, **kwargs, ) - else: - device_map = {"": device_map} - else: - no_split_modules = model.simple_base_model._get_no_split_modules(device_map) - no_split_modules.append(model.iter_decider.__class__.__name__) - no_split_modules.append(model.input_updater.__class__.__name__) - - device_map_kwargs = { - "no_split_module_classes": no_split_modules, - } - - max_mem = get_balanced_memory( - model, - dtype=dtype, - **device_map_kwargs, - ) - device_map = infer_auto_device_map( - model, - max_memory=max_mem, - dtype=dtype, - **device_map_kwargs - ) - return device_map - - - - -def _forward_and_display( - tah_model: "TaHForCausalLM", - tokenizer: AutoTokenizer, - model_inputs: dict, - iter_count: Optional[torch.Tensor], - cache: Optional[object], - is_prefill: bool = False, - new_sequence: bool = False, - verbose: bool = True, - **kwargs, -) -> object: - """ - Uniform function for forward pass and token display for both prefill and decoding. - - Args: - tah_model: TaH model instance - tokenizer: tokenizer instance - model_inputs: dict containing 'input_ids', 'attention_mask', and other model inputs - iter_count: iteration counts for tokens - cache: past key values cache - is_prefill: whether this is the prefill phase or decoding phase - new_sequence: whether this is a new sequence - verbose: whether to display token colors and debug output - - Returns: - Model outputs - """ - # Extract input_ids and prepare forward pass arguments - input_ids = model_inputs["input_ids"] - - # Prepare forward pass arguments with all available model inputs - forward_kwargs = { - "input_ids": input_ids, - "iter_count": iter_count, - "past_key_values": cache, - "use_cache": True, - "new_sequence": new_sequence, - **kwargs, - } - - # Add attention mask if available - if "attention_mask" in model_inputs and model_inputs["attention_mask"] is not None: - forward_kwargs["attention_mask"] = model_inputs["attention_mask"] - - # Add any other inputs that might be present - for key, value in model_inputs.items(): - if key not in ["input_ids", "attention_mask"] and value is not None: - forward_kwargs[key] = value - - # Forward pass - # TODO: Add position ids - outputs = tah_model(**forward_kwargs) + cache = outputs.past_key_values - # Display tokens with actual iteration counts, respecting attention mask if verbose: - tokens = [tokenizer.decode([token_id]) for token_id in input_ids[0]] - - # Get attention mask to avoid printing padding tokens - attention_mask = model_inputs.get("attention_mask", None) - if attention_mask is not None: - # Only print tokens where attention mask is 1 (non-padding) - valid_positions = attention_mask[0, -outputs.iter_count[0].shape[0]:] == 1 - tokens = [token for i, token in enumerate(tokens) if valid_positions[i]] - - if hasattr(outputs, "iter_count") and outputs.iter_count is not None: - actual_counts = outputs.iter_count[0] - if attention_mask is not None: - actual_counts = actual_counts[valid_positions] - for token, actual_count in zip(tokens, actual_counts): - IterCountColors.print_token(token, actual_count.item()) - elif iter_count is not None: - # Fallback to input iter_count - iter_counts_to_use = iter_count[0] - if attention_mask is not None: - iter_counts_to_use = iter_counts_to_use[valid_positions] - for token, count in zip(tokens, iter_counts_to_use): - IterCountColors.print_token(token, count.item()) - else: - # Default to 1 iteration - for token in tokens: - IterCountColors.print_token(token, 1) - - return outputs - - -class IterCountColors: - """Utility class for handling iteration count based coloring.""" - - @staticmethod - def get_color(iter_count_val): - """ - Get ANSI color code for given iteration count. - - Args: - iter_count_val: The iteration count value - - Returns: - ANSI color code string - """ - colors = { - 1: "\033[0m", # Default/reset (white) - 2: "\033[92m", # Green - 3: "\033[94m", # Blue - 4: "\033[91m", # Red - 5: "\033[95m", # Magenta - 6: "\033[93m", # Yellow - } - return colors.get(iter_count_val, "\033[96m") # Cyan for counts > 6 - - @staticmethod - def print_token(token_text, iter_count_val): - """ - Print token with color based on iteration count. - - Args: - token_text: The token text to print - iter_count_val: The iteration count value for coloring - """ - color = IterCountColors.get_color(iter_count_val) - reset = "\033[0m" - print(f"{color}{token_text}{reset}", end="", flush=True) - - @staticmethod - def get_legend(): - """ - Get color legend string for iteration counts. - - Returns: - String describing the color mapping with colors applied - """ - reset = "\033[0m" - legend_parts = [ - f"{IterCountColors.get_color(1)}Default=1 iter{reset}", - f"{IterCountColors.get_color(2)}Green=2 iter{reset}", - f"{IterCountColors.get_color(3)}Blue=3 iter{reset}", - f"{IterCountColors.get_color(4)}Red=4 iter{reset}", - f"{IterCountColors.get_color(5)}Magenta=5 iter{reset}", - f"{IterCountColors.get_color(6)}Yellow=6 iter{reset}", - f"{IterCountColors.get_color(7)}Cyan=7+ iter{reset}", - ] - return "Color legend: " + ", ".join(legend_parts) + print("\033[0m") + return output_tokens, [tokenizer.decode(toks) if toks else "" for toks in output_tokens] diff --git a/tah/train/__init__.py b/tah/train/__init__.py index 5660148..4dc7187 100644 --- a/tah/train/__init__.py +++ b/tah/train/__init__.py @@ -1,16 +1,11 @@ -""" -TaH Training Module - -This module contains training-related components for TaH models. -""" +"""TaH training: HF Trainer subclass, data collator, callback for iter-aware logging.""" from .data_collator import CustomTaHDataCollator -from .trainer import CustomTaHTrainer, LoggerCallback, weighted_cross_entropy, fixed_cross_entropy +from .trainer import CustomTaHTrainer, LoggerCallback, fixed_cross_entropy __all__ = [ - "CustomTaHDataCollator", + "CustomTaHDataCollator", "CustomTaHTrainer", "LoggerCallback", - "weighted_cross_entropy", - "fixed_cross_entropy" -] \ No newline at end of file + "fixed_cross_entropy", +] diff --git a/tah/train/data_collator.py b/tah/train/data_collator.py index 339a361..e2dcae8 100644 --- a/tah/train/data_collator.py +++ b/tah/train/data_collator.py @@ -1,127 +1,81 @@ -from typing import Optional, Any, Union -from transformers import PreTrainedTokenizerBase -from transformers.data.data_collator import PaddingStrategy, DataCollatorForSeq2Seq +"""Data collator for TaH SFT. + +Wraps :class:`~transformers.DataCollatorForSeq2Seq` to handle the additional +``iter_count_labels`` field that the labelling pipeline produces. The field is +padded alongside ``input_ids`` (same length, same padding side) using the +label ignore-index ``-100`` for padding positions, then converted to a +``LongTensor``. + +Public TaH had separate code paths for "no padding" inputs and for ``list`` +vs ``ndarray`` element types; in practice every dataset feeds tokenised +inputs as ``list[int]`` and asks for padding, so the cleaned collator only +implements that one path. ``iter_count_pad_value`` is removed (the base +collator's ``label_pad_token_id`` controls both labels and iter_count_labels). +""" +from __future__ import annotations + +from typing import Any, Optional, Union + import numpy as np +import torch +from transformers import PreTrainedTokenizerBase +from transformers.data.data_collator import DataCollatorForSeq2Seq, PaddingStrategy + class CustomTaHDataCollator: - """ - Custom data collator for TaH that handles iter_count field along with standard fields. - """ - def __init__(self, tokenizer: PreTrainedTokenizerBase, - model: Optional[Any] = None, - padding: Union[bool, str, PaddingStrategy] = True, - max_length: Optional[int] = None, - pad_to_multiple_of: Optional[int] = None, - label_pad_token_id: int = -100, - iter_count_pad_value: int = -1, - return_tensors: str = "pt"): - """ - Initialize custom data collator for TaH. - - Args: - tokenizer: Tokenizer instance - model: Optional model instance - padding: Padding strategy - max_length: Maximum length for padding - pad_to_multiple_of: Pad to multiple of this value - label_pad_token_id: Padding token ID for labels (default: -100) - iter_count_pad_value: Padding value for iter_count (default: -1) - return_tensors: Type of tensors to return (default: "pt") - """ + """Pads ``input_ids`` / ``attention_mask`` / ``labels`` (via base collator) + plus the TaH-specific ``iter_count_labels`` field.""" + + def __init__( + self, + tokenizer: PreTrainedTokenizerBase, + padding: Union[bool, str, PaddingStrategy] = True, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + label_pad_token_id: int = -100, + return_tensors: str = "pt", + ): self.tokenizer = tokenizer - self.model = model self.padding = padding self.max_length = max_length self.pad_to_multiple_of = pad_to_multiple_of self.label_pad_token_id = label_pad_token_id - self.iter_count_pad_value = iter_count_pad_value self.return_tensors = return_tensors - - # Create base data collator for handling standard fields self.base_collator = DataCollatorForSeq2Seq( - tokenizer=tokenizer, - padding=padding, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - label_pad_token_id=label_pad_token_id, - return_tensors=return_tensors + tokenizer=tokenizer, padding=padding, max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, label_pad_token_id=label_pad_token_id, + return_tensors=return_tensors, ) - + def __call__(self, features, return_tensors=None): - if return_tensors is None: - return_tensors = self.return_tensors - - # Extract iter_count_labels from features if present - iter_count_labels_list = [] - has_iter_count_labels = False - if features and 'iter_count_labels' in features[0]: - has_iter_count_labels = True - iter_count_labels_list = [feature.pop('iter_count_labels') for feature in features] - - # Use base collator for standard fields (input_ids, attention_mask, labels) + return_tensors = return_tensors or self.return_tensors + + iter_labels_list = [] + if features and "iter_count_labels" in features[0]: + iter_labels_list = [f.pop("iter_count_labels") for f in features] + batch = self.base_collator(features, return_tensors=return_tensors) - - # Handle iter_count_labels field if present - if has_iter_count_labels and iter_count_labels_list: - # Get padding configuration - no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD - - if no_padding: - # No padding case - if isinstance(iter_count_labels_list[0], list): - batch["iter_count_labels"] = list(iter_count_labels_list) - else: - batch["iter_count_labels"] = [ - np.concatenate([iter_count_labels, []]) - for iter_count_labels in iter_count_labels_list - ] - else: - # Padding case - strictly align with input_ids padding length - if "input_ids" in batch: - max_iter_length = batch["input_ids"].shape[1] - else: - # Fallback: infer from current list - max_iter_length = max(len(v) for v in iter_count_labels_list) - - # Apply pad_to_multiple_of if specified - if self.pad_to_multiple_of is not None: - max_iter_length = ( - (max_iter_length + self.pad_to_multiple_of - 1) - // self.pad_to_multiple_of - * self.pad_to_multiple_of - ) - - # Determine padding side - padding_side = self.tokenizer.padding_side - pad_value = self.label_pad_token_id - - # Pad iter_count_labels sequences - if isinstance(iter_count_labels_list[0], list): - batch["iter_count_labels"] = [ - iter_count_labels + [pad_value] * (max_iter_length - len(iter_count_labels)) - if padding_side == "right" - else [pad_value] * (max_iter_length - len(iter_count_labels)) + iter_count_labels - for iter_count_labels in iter_count_labels_list - ] - else: - batch["iter_count_labels"] = [ - np.concatenate([ - iter_count_labels, - np.array([pad_value] * (max_iter_length - len(iter_count_labels)), dtype=np.int64) - ]) if padding_side == "right" - else np.concatenate([ - np.array([pad_value] * (max_iter_length - len(iter_count_labels)), dtype=np.int64), - iter_count_labels - ]) - for iter_count_labels in iter_count_labels_list - ] - - # Convert iter_count_labels to tensors if needed - if has_iter_count_labels and batch.get("iter_count_labels", None) is not None: - if return_tensors == "pt": - import torch - batch["iter_count_labels"] = torch.tensor(batch["iter_count_labels"], dtype=torch.long) - else: - batch["iter_count_labels"] = np.array(batch["iter_count_labels"], dtype=np.int64) + if not iter_labels_list: + return batch + + # Pad iter_count_labels to match input_ids length on the same side as + # the tokenizer's padding side; -100 marks ignored positions for the loss. + target_len = batch["input_ids"].shape[1] if "input_ids" in batch else max(len(v) for v in iter_labels_list) + if self.pad_to_multiple_of is not None: + target_len = (target_len + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of * self.pad_to_multiple_of + + right_pad = self.tokenizer.padding_side == "right" + pad_val = self.label_pad_token_id + padded = [] + for v in iter_labels_list: + v = list(v) if not isinstance(v, list) else v + n_pad = target_len - len(v) + row = (v + [pad_val] * n_pad) if right_pad else ([pad_val] * n_pad + v) + padded.append(row) + + if return_tensors == "pt": + batch["iter_count_labels"] = torch.tensor(padded, dtype=torch.long) + else: + batch["iter_count_labels"] = np.asarray(padded, dtype=np.int64) return batch diff --git a/tah/train/trainer.py b/tah/train/trainer.py index edb3c2f..a7c8adc 100644 --- a/tah/train/trainer.py +++ b/tah/train/trainer.py @@ -1,25 +1,39 @@ +"""HF Trainer subclass + iter-aware logging callback for TaH SFT. + +The Trainer override exists for two reasons: + +1. ``_save`` calls ``model.save_pretrained`` so the TaH-specific layout + (``tah_config.json``, ``iter_decider.bin``, ``lora/``) is written instead + of the default HF flat-state-dict save. + +2. ``_get_dataloader`` adds ``shuffle=True`` to the DataLoader kwargs (the + public TaH had this; we preserve the behaviour). + +The callback (``LoggerCallback``) just wires the wrapper's running counters +(``avg_iter_count``, ``iter_decider_accuracy``) into the per-step ``logs`` +dict that HF Trainer flushes to wandb / tensorboard. Public TaH had extra +plumbing for composite/scheduled losses (``CombinedLoss``, ``InterleavedLoss``) +which never existed in the released checkpoint and is dropped. + +``fixed_cross_entropy`` is the loss helper used by ``NextTokenPredLoss``; +``weighted_cross_entropy`` from public TaH is removed (the hard-token weight +plumbing it served has been removed from ``NextTokenPredLoss`` too). +""" +from __future__ import annotations + import os -from transformers import Trainer, TrainingArguments, PreTrainedTokenizerBase, TrainerCallback, TrainerState, TrainerControl +from functools import partial +from typing import Optional + import torch -import torch.nn as nn -from typing import Optional, Any, Union, Tuple from torch.utils.data import DataLoader, Dataset -from functools import partial +from transformers import Trainer, TrainerCallback, TrainerControl, TrainerState, TrainingArguments from transformers.trainer_utils import seed_worker -# from tah.evaluate.eval_unified import allocate_gpus_and_run_jobs class CustomTaHTrainer(Trainer): - """ - self-defined Trainer class to ensure that the custom save_pretrained method of the TaH model is called when saving the model. - """ - def __init__(self, *args, **kwargs): - # Extract prediction_config from kwargs before passing to parent - self.prediction_config = kwargs.pop('prediction_config', None) - super().__init__(*args, **kwargs) - self.gradient_accumulation_steps = getattr(self.args, 'gradient_accumulation_steps', 1) - - # override _get_dataloader method to add shuffle=True parameter + """Trainer with TaH-aware ``_save`` and a shuffling dataloader override.""" + def _get_dataloader( self, dataset: Dataset, @@ -29,77 +43,54 @@ def _get_dataloader( is_training: bool = False, dataloader_key: Optional[str] = None, ) -> DataLoader: - """Override _get_dataloader method to add shuffle=True parameter.""" - data_collator = self.data_collator - if hasattr(self, '_remove_unused_columns'): + if hasattr(self, "_remove_unused_columns"): dataset = self._remove_unused_columns(dataset, description=description) else: - data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description) + data_collator = self._get_collator_with_removed_columns( + self.data_collator, description=description + ) - dataloader_params = { + params = { "batch_size": batch_size, "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, "persistent_workers": self.args.dataloader_persistent_workers, - "shuffle": True, # Add shuffle=True parameter + "shuffle": True, } - if not isinstance(dataset, torch.utils.data.IterableDataset): if sampler_fn is not None: - dataloader_params["sampler"] = sampler_fn(dataset) - dataloader_params.pop("shuffle", None) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + params["sampler"] = sampler_fn(dataset) + params.pop("shuffle", None) + params["drop_last"] = self.args.dataloader_drop_last + params["prefetch_factor"] = self.args.dataloader_prefetch_factor if is_training: - dataloader_params["worker_init_fn"] = partial( + params["worker_init_fn"] = partial( seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index ) else: - # For IterableDataset, remove shuffle parameter to avoid conflict - dataloader_params.pop("shuffle", None) - - dataloader = DataLoader(dataset, **dataloader_params) + params.pop("shuffle", None) - # Accelerator.free_memory() will destroy the references, so - # we need to store the non-prepared version for eval dataloaders. + dataloader = DataLoader(dataset, **params) if dataloader_key is not None and self.args.dataloader_persistent_workers: - if hasattr(self, "_eval_dataloaders"): - self._eval_dataloaders[dataloader_key] = dataloader - else: - self._eval_dataloaders = {dataloader_key: dataloader} - + store = getattr(self, "_eval_dataloaders", None) or {} + store[dataloader_key] = dataloader + self._eval_dataloaders = store return self.accelerator.prepare(dataloader) - + def _save(self, output_dir=None, state_dict=None): - """Override _save method to ensure that the custom save_pretrained method of the TaH model is called when saving the model.""" - # use output directory or default output directory - output_dir = output_dir if output_dir is not None else self.args.output_dir + """Use TaH's custom save_pretrained when available; default Trainer behavior otherwise.""" + output_dir = output_dir or self.args.output_dir os.makedirs(output_dir, exist_ok=True) - - # save model - if hasattr(self.model, 'save_pretrained') and hasattr(self.model, 'config'): - # for TaH model, use custom save_pretrained method - print(f"use TaH custom save_pretrained method to save model to: {output_dir}") - print(f" - save base model and config...") + if hasattr(self.model, "save_pretrained") and hasattr(self.model, "config"): + print(f"Saving TaH model to: {output_dir}") self.model.save_pretrained(output_dir) else: - # for normal model, use default method - print(f"use default method to save model to: {output_dir}") super()._save(output_dir, state_dict) - - # save tokenizer - if getattr(self, 'tokenizer', None) is not None: + if getattr(self, "tokenizer", None) is not None: self.tokenizer.save_pretrained(output_dir) - def evaluate(self, *args, **kwargs): - """Override evaluate method to add predict with generation.""" - base_metrics = super().evaluate() - - # Only run generation evaluation if prediction_config is provided and this is the main process - return base_metrics - def fixed_cross_entropy( source: torch.Tensor, @@ -108,158 +99,31 @@ def fixed_cross_entropy( ignore_index: int = -100, **kwargs, ) -> torch.Tensor: + """Cross-entropy with optional sum/N normalisation matching HF's per-item averaging.""" reduction = "sum" if num_items_in_batch is not None else "mean" - loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) + loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction) if reduction == "sum": loss = loss / num_items_in_batch return loss -def weighted_cross_entropy( - logits: torch.Tensor, - labels: torch.Tensor, - token_weights: torch.Tensor, - num_items_in_batch: Optional[int] = None, - ignore_index: int = -100, - **kwargs, -) -> torch.Tensor: - """ - Compute weighted cross-entropy loss where each token can have a different weight. - - Args: - logits: Model predictions of shape (batch_size * seq_len, vocab_size) - labels: Target labels of shape (batch_size * seq_len,) - token_weights: Weight for each token of shape (batch_size * seq_len,) - num_items_in_batch: Number of valid items in batch for averaging - ignore_index: Label index to ignore (default: -100) - - Returns: - Weighted cross-entropy loss - """ - # Compute per-token cross-entropy loss (no reduction) - loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index) - per_token_loss = loss_fct(logits, labels) - - # Apply mask to ignore specified labels - valid_mask = labels != ignore_index - - # Compute weighted loss - if valid_mask.any(): - valid_losses = per_token_loss[valid_mask] - valid_weights = token_weights[valid_mask] - - # Weighted sum - weighted_sum = (valid_losses * valid_weights).sum() - - if num_items_in_batch is not None: - # Average by number of items in batch - loss = weighted_sum / num_items_in_batch - else: - # Average by sum of weights - loss = weighted_sum / valid_weights.sum() - else: - loss = torch.tensor(0.0, device=logits.device, requires_grad=True) - - return loss class LoggerCallback(TrainerCallback): - def __init__(self, trainer): - self.trainer = trainer - self.avg_iter_count = 0 - self.iter_decider_accuracy = 0.0 - self.iter_decider_precision = 0.0 - - def _update_iter_decider_training_state(self, state: TrainerState, args: TrainingArguments): - """Propagate current step/epoch into model.iter_decider if supported.""" - model = getattr(self.trainer, 'model', None) - if model is None: - return - iter_decider = getattr(model, 'iter_decider', None) - if iter_decider is None: - return - # Initialize scheduling meta once at train begin - if hasattr(iter_decider, 'num_grow_steps') and (getattr(iter_decider, 'num_grow_steps', None) in [None, 0]): - # Align grow steps with trainer's planned max_steps - if hasattr(state, 'max_steps') and state.max_steps is not None: - iter_decider.num_grow_steps = state.max_steps - if hasattr(iter_decider, 'num_epochs') and (getattr(iter_decider, 'num_epochs', None) in [None, 0]): - # Align epochs with args - if hasattr(args, 'num_train_epochs') and args.num_train_epochs is not None: - try: - iter_decider.num_epochs = int(args.num_train_epochs) - except Exception: - pass - - # Update dynamic training state for this step/epoch - if hasattr(iter_decider, 'update_training_state') and callable(iter_decider.update_training_state): - current_step = getattr(state, 'global_step', 0) or 0 - current_epoch = int(state.epoch) if getattr(state, 'epoch', None) is not None else 0 - iter_decider.update_training_state(current_step=current_step, current_epoch=current_epoch) + """Push the wrapper's running iter-aware counters into the trainer's log stream. - def _update_loss_training_state(self, state: TrainerState, args: TrainingArguments): - """Propagate current step/epoch into InterleavedLoss (or nested losses) if supported.""" - model = getattr(self.trainer, 'model', None) - if model is None: - return - loss_objs = [] - if hasattr(model, 'train_loss') and model.train_loss is not None: - loss_objs.append(model.train_loss) - if hasattr(model, 'eval_loss') and model.eval_loss is not None: - loss_objs.append(model.eval_loss) - - if not loss_objs: - return - - current_step = getattr(state, 'global_step', 0) or 0 - current_epoch = int(state.epoch) if getattr(state, 'epoch', None) is not None else 0 - - def _maybe_update(obj): - if obj is None: - return - # Update if this loss exposes the method - if hasattr(obj, 'update_training_state') and callable(obj.update_training_state): - try: - obj.update_training_state(current_step=current_step, current_epoch=current_epoch) - except Exception: - pass - # Recurse into composite losses (e.g., CombinedLoss, InterleavedLoss) - for attr in ('primary_loss', 'secondary_loss'): - if hasattr(obj, attr): - _maybe_update(getattr(obj, attr)) - - for loss_obj in loss_objs: - _maybe_update(loss_obj) - - def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - # Ensure iter_decider receives initial scheduling meta and state - self._update_iter_decider_training_state(state, args) - self._update_loss_training_state(state, args) - - def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - # Keep iter_decider in sync each step - self._update_iter_decider_training_state(state, args) - self._update_loss_training_state(state, args) + The wrapper writes ``self.avg_iter_count`` (every forward) and + ``self.iter_decider_accuracy`` (intra-iter loss path) onto its + ``logger_callback`` attribute; we forward and reset them on every + ``on_log`` event. + """ - def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - # Also refresh at epoch boundaries - self._update_iter_decider_training_state(state, args) - self._update_loss_training_state(state, args) + def __init__(self): + self.avg_iter_count = 0.0 + self.iter_decider_accuracy = 0.0 def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - """ - Log the average iteration count after each step. - - Kwargs: 'model', 'processing_class', 'optimizer', 'lr_scheduler', 'train_dataloader', 'eval_dataloader', 'logs' - """ - kwargs['logs']['avg_iter_count'] = self.avg_iter_count - if self.iter_decider_accuracy is not None and self.iter_decider_accuracy > 0.0: - kwargs['logs']['iter_decider_accuracy'] = self.iter_decider_accuracy - # kwargs['logs']['iter_decider_precision'] = self.iter_decider_precision - - # if hasattr(self.trainer.model, 'iter_decider'): - # iter_decider = self.trainer.model.iter_decider - # if hasattr(iter_decider, 'transition_weight'): - # kwargs['logs']['transition_weight'] = iter_decider.transition_weight - - self.avg_iter_count = 0 + logs = kwargs["logs"] + logs["avg_iter_count"] = self.avg_iter_count + if self.iter_decider_accuracy > 0.0: + logs["iter_decider_accuracy"] = self.iter_decider_accuracy + self.avg_iter_count = 0.0 self.iter_decider_accuracy = 0.0 - self.iter_decider_precision = 0.0 \ No newline at end of file diff --git a/tah/utils/__init__.py b/tah/utils/__init__.py new file mode 100644 index 0000000..a2a7d0a --- /dev/null +++ b/tah/utils/__init__.py @@ -0,0 +1,5 @@ +"""TaH utils: SFT data preprocessing.""" + +from tah.utils.data_prepare import preprocess_dataset + +__all__ = ["preprocess_dataset"] diff --git a/tah/utils/data_prepare.py b/tah/utils/data_prepare.py old mode 100755 new mode 100644 index baf0cba..ee0de9c --- a/tah/utils/data_prepare.py +++ b/tah/utils/data_prepare.py @@ -1,567 +1,202 @@ -from typing import Dict -import numpy as np -from datasets import load_from_disk -from functools import partial -from accelerate import Accelerator +"""SFT dataset preprocessing for TaH. + +Loads the labeller-produced dataset (with ``real_token``, ``mask``, +``mismatch`` columns) and turns each example into the +``(input_ids, attention_mask, labels, iter_count_labels)`` tuple consumed by +:class:`tah.train.CustomTaHDataCollator`. + +Public TaH supported many alternative ``iter_count_strategy`` modes (random, +top_entropy, top_ce, ds_divergence, nonmismatch, divergent, all, maxiter) +that the canonical SFT recipes never selected. The cleaned version keeps the +``mismatch`` strategy only — the same one used by both +``sft_tah_step1.yaml`` and ``sft_tah_step2.yaml``. + +Two columns the canonical labeller never writes are also dropped (``divergent``, +``ds_divergence``); ``entropy`` and ``cross_entropy`` are still removed before +collation when present (the labeller can be configured to emit them, but the +mismatch strategy doesn't read them). +""" +from __future__ import annotations +from functools import partial +from typing import Dict, Optional -def _infer_strategy_from_iter_decider(model_config: Dict, is_eval: bool = False) -> str | None: - """ - Infer hard token strategy from iter_decider configuration in model_config. - Returns one of {"mismatch", "divergent"} or None when not applicable. - """ - # Select decider source - decider_key = 'eval_iter_decider' if is_eval and 'eval_iter_decider' in model_config else 'iter_decider' - decider_kwargs_key = 'eval_iter_decider_kwargs' if is_eval and 'eval_iter_decider_kwargs' in model_config else 'iter_decider_kwargs' - - iter_decider = model_config.get(decider_key) - iter_decider_kwargs = model_config.get(decider_kwargs_key, {}) or {} - - if iter_decider is None: - return None - - # Direct FixedLabelIterDecider - if iter_decider == 'FixedLabelIterDecider': - label_type = (iter_decider_kwargs or {}).get('label_type', None) - return label_type if label_type in {"mismatch", "divergent"} else None - - # SmoothTransitionIterDecider: look at initial decider - if iter_decider == 'SmoothTransitionIterDecider': - init_cls = iter_decider_kwargs.get('initial_iter_decider_cls') - init_kwargs = iter_decider_kwargs.get('initial_iter_decider_kwargs', {}) or {} - if init_cls == 'FixedLabelIterDecider': - label_type = init_kwargs.get('label_type', None) - return label_type if label_type in {"mismatch", "divergent"} else None +import numpy as np +from accelerate import Accelerator +from datasets import load_from_disk - return None +# Columns produced by the labeller that are inputs to preprocessing but not +# meaningful at training time, so we strip them after building the SFT batch. +_LABELLER_INPUT_COLUMNS = ( + "data_id", "real_text", "real_token", "mask", "mismatch", + "divergent", "entropy", "cross_entropy", "ds_divergence", "problem_idx", +) -def calculate_hard_token_ratio(dataset, hard_token_strategy: str, accelerator: Accelerator): - """ - Optimized calculation of hard token ratio using batch processing and vectorized operations. - """ - - accelerator.print(f"Calculating hard token ratio for {hard_token_strategy} strategy...") - - def calculate_stats_batch(examples): - """Calculate statistics for a batch of examples.""" - batch_labels = [] - batch_iter_count_labels = [] - - for i in range(len(examples['labels'])): - batch_labels.extend(examples['labels'][i]) - batch_iter_count_labels.extend(examples['iter_count_labels'][i]) - - # Convert to numpy for vectorized operations - labels_np = np.array(batch_labels) - iter_count_labels_np = np.array(batch_iter_count_labels) - - # Calculate masks - valid_mask = labels_np != -100 - hard_mask = iter_count_labels_np > 1 - - valid_tokens = np.sum(valid_mask) - hard_tokens = np.sum(valid_mask & hard_mask) - - return { - 'valid_tokens': [valid_tokens], - 'hard_tokens': [hard_tokens] - } - - # Use map with batched=True for efficient processing - # Ensure only the main process performs the computation first to populate cache - with accelerator.main_process_first(): - stats_dataset = dataset.map( - calculate_stats_batch, - batched=True, - batch_size=1000, - remove_columns=dataset.column_names, - desc="Calculating token statistics" - ) - - # Sum up the results - total_valid_tokens = sum(stats_dataset['valid_tokens']) - total_hard_tokens = sum(stats_dataset['hard_tokens']) - - if total_valid_tokens > 0: - avg_hard_ratio = total_hard_tokens / total_valid_tokens - accelerator.print(f"hard token ratio for {hard_token_strategy}: {avg_hard_ratio:.4f}") - accelerator.print(f" - Total hard tokens: {total_hard_tokens}") - accelerator.print(f" - Total valid tokens: {total_valid_tokens}") - return avg_hard_ratio - - return None -def preprocess_for_sft_batch( - examples: Dict, - max_length: int, +def _build_sft_example( + input_ids: list, + prompt_mask: list, + mismatch_mask: Optional[list], max_iter: int, - iter_count_strategy: str, - iter_count_strategy_kwargs: Dict, - iter_count_label_strategy: str, - query_iter_count: int = 1, - max_length_action: str = "cutoff", -) -> Dict: + max_length: Optional[int], + truncate: bool, + query_iter_count: int, +) -> Optional[Dict]: + """Convert one labelled example into an SFT example. + + Returns None for examples that should be filtered out (no supervised + tokens, or longer than ``max_length`` when truncation is disabled). """ - Optimized batch processing function for better performance with large datasets. - """ - batch_size = len(examples['real_token']) - batch_input_ids = [] - batch_iter_count_labels = [] - batch_attention_mask = [] - batch_labels = [] - - for i in range(batch_size): - input_ids = examples['real_token'][i] - prompt_mask = examples['mask'][i] - - # skip this data if prompt_mask is all zeros - if not any(prompt_mask): - continue - - mismatch_mask = examples['mismatch'][i] if 'mismatch' in examples else None - divergent_mask = examples['divergent'][i] if 'divergent' in examples else None - entropy = examples['entropy'][i] if 'entropy' in examples else None - cross_entropy = examples['cross_entropy'][i] if 'cross_entropy' in examples else None - ds_divergence = examples['ds_divergence'][i] if 'ds_divergence' in examples else None - - # Apply truncation or filtering if max_length is specified - if max_length is not None and len(input_ids) > max_length: - action = (max_length_action or "cutoff").lower() - if action == "filter": - # skip this sample entirely when it exceeds max_length - continue - # default: cutoff - input_ids = input_ids[:max_length] - prompt_mask = prompt_mask[:max_length] - mismatch_mask = mismatch_mask[:max_length] if mismatch_mask is not None else None - divergent_mask = divergent_mask[:max_length] if divergent_mask is not None else None - entropy = entropy[:max_length] if entropy is not None else None - cross_entropy = cross_entropy[:max_length] if cross_entropy is not None else None - - # Vectorized label creation - replace prompt tokens with -100 - labels = np.array(input_ids, dtype=np.int64) - prompt_mask_np = np.array(prompt_mask) - labels[prompt_mask_np == 0] = -100 - labels = labels.tolist() - - attention_mask = [1] * len(input_ids) - - iter_count_labels = np.ones(len(input_ids), dtype=np.int32) - - # Initialize iter_count/iter_count_labels independently based on strategies - if max_iter > 1: - def _sample_random_indices(length: int, k: int) -> np.ndarray: - if k <= 0: - return np.array([], dtype=np.int64) - k = min(k, length) - return np.random.choice(length, k, replace=False) - - def _select_indices_by_entropy(entropy_list, k: int) -> np.ndarray: - if k <= 0: - return np.array([], dtype=np.int64) - entropy_np = np.array(entropy_list) - k = min(k, len(entropy_np)) - return np.argsort(entropy_np)[-k:] - - def _assign_binned_iter_values(sorted_idx: np.ndarray, values: np.ndarray, max_iter: int) -> None: - num_tokens = len(sorted_idx) - if num_tokens == 0: - return - - num_bins = max_iter - 1 - base = num_tokens // num_bins if num_bins > 0 else num_tokens - rem = num_tokens % num_bins if num_bins > 0 else 0 - start = 0 - for bin_id in range(max(num_bins, 1)): - size = base + (1 if bin_id < rem else 0) - if size <= 0: - continue - end = start + size - bin_indices = sorted_idx[start:end] - iter_value = max_iter - bin_id if num_bins > 0 else 2 - iter_value = max(2, min(max_iter, iter_value)) - values[bin_indices] = iter_value - start = end - - def compute_iter_values(strategy: str, strategy_kwargs: Dict | None = None, base_counts: np.ndarray | None = None, prompt_mask_np: np.ndarray = None) -> np.ndarray: - values = np.ones(len(input_ids), dtype=np.int32) - if strategy == "copy": - # copy only makes sense for labels; fall back to ones if base_counts is None - return base_counts.copy() if base_counts is not None else values - if strategy == "random": - random_token_ratio = strategy_kwargs.get('random_token_ratio', 0.1) - # Only select from valid tokens (prompt_mask==1) - valid_idx = np.nonzero(prompt_mask_np)[0] - if len(valid_idx) > 0: - num = int(len(valid_idx) * random_token_ratio) - idx = _sample_random_indices(len(valid_idx), num) - if len(idx) > 0: - selected_idx = valid_idx[idx] - values[selected_idx] = np.random.randint(2, max_iter + 1, len(selected_idx)) - elif strategy == "top_entropy": - entropy_token_ratio = strategy_kwargs.get('top_entropy_ratio', 0.1) - # Only select from valid tokens (prompt_mask==1) - valid_idx = np.nonzero(prompt_mask_np)[0] - if len(valid_idx) > 0: - valid_entropy = [entropy[i] for i in valid_idx] - num = int(len(valid_idx) * entropy_token_ratio) - relative_idx = _select_indices_by_entropy(valid_entropy, num) - if len(relative_idx) > 0: - selected_idx = valid_idx[relative_idx] - values[selected_idx] = np.random.randint(2, max_iter + 1, len(selected_idx)) - elif strategy == "mismatch": - mismatch_np = np.array(mismatch_mask) - if np.any(mismatch_np > 1): - values = mismatch_np + 1 - elif cross_entropy is not None: - # If given top_ce_ratio, select the top ce tokens from mismatch tokens - ce_np = np.array(cross_entropy) - # Only consider valid tokens (prompt_mask==1) - mismatch_and_valid = mismatch_np & prompt_mask_np - mismatch_idx = np.nonzero(mismatch_and_valid)[0] - if len(mismatch_idx) > 0: - ce_on_mismatch = ce_np[mismatch_idx] - order = np.argsort(-ce_on_mismatch) - sorted_idx = mismatch_idx[order] - # If top_ce_ratio is specified, only keep top ratio of tokens - if 'top_ce_ratio' in strategy_kwargs: - top_ce_ratio = strategy_kwargs['top_ce_ratio'] - num_top_ce = max(1, int(len(sorted_idx) * top_ce_ratio)) - sorted_idx = sorted_idx[:num_top_ce] - - # Use the extracted function to assign binned iter values - _assign_binned_iter_values(sorted_idx, values, max_iter) - else: - # Only consider valid tokens (prompt_mask==1) - mismatch_and_valid = mismatch_np & prompt_mask_np - idx = np.nonzero(mismatch_and_valid)[0] - if len(idx) > 0: - values[idx] = np.random.randint(2, max_iter + 1, len(idx)) - elif strategy == "top_ce": - # Select top ce_ratio tokens with highest cross entropy as hard tokens - if cross_entropy is not None and 'top_ce_ratio' in strategy_kwargs: - ce_np = np.array(cross_entropy) - top_ce_ratio = strategy_kwargs['top_ce_ratio'] - - # Get all valid token indices (excluding prompt tokens) - valid_idx = np.nonzero(prompt_mask_np)[0] - if len(valid_idx) > 0: - # Get cross entropy values for valid tokens - ce_on_valid = ce_np[valid_idx] - # Sort by cross entropy (highest first) - order = np.argsort(-ce_on_valid) - sorted_idx = valid_idx[order] - - # Select top ratio tokens - num_top_ce = max(1, int(len(sorted_idx) * top_ce_ratio)) - sorted_idx = sorted_idx[:num_top_ce] - - # Use the extracted function to assign binned iter values - _assign_binned_iter_values(sorted_idx, values, max_iter) - elif strategy == "ds_divergence": - # Select top ds_ratio tokens with highest ds_divergence as hard tokens - if ds_divergence is not None and 'top_ds_ratio' in strategy_kwargs: - ds_np = np.array(ds_divergence) - top_ds_ratio = strategy_kwargs['top_ds_ratio'] - # Get all valid token indices (excluding prompt tokens) - valid_idx = np.nonzero(prompt_mask_np)[0] - if len(valid_idx) > 0: - # Get ds_divergence values for valid tokens - ds_on_valid = ds_np[valid_idx] - # Sort by ds_divergence (highest first) - order = np.argsort(-ds_on_valid) - sorted_idx = valid_idx[order] - # Select top ratio tokens - num_top_ds = max(1, int(len(sorted_idx) * top_ds_ratio)) - sorted_idx = sorted_idx[:num_top_ds] - # Use the extracted function to assign binned iter values - _assign_binned_iter_values(sorted_idx, values, max_iter) - elif strategy == "nonmismatch": - nonmismatch_np = np.array(mismatch_mask) == 0 - if np.any(nonmismatch_np > 1): - values = nonmismatch_np + 1 - else: - # Only consider valid tokens (prompt_mask==1) - nonmismatch_and_valid = nonmismatch_np & prompt_mask_np - idx = np.nonzero(nonmismatch_and_valid)[0] - if len(idx) > 0: - values[idx] = np.random.randint(2, max_iter + 1, len(idx)) - elif strategy == "divergent": - # Only consider valid tokens (prompt_mask==1) - divergent_and_valid = np.array(divergent_mask) & prompt_mask_np - idx = np.nonzero(divergent_and_valid)[0] - if len(idx) > 0: - values[idx] = np.random.randint(2, max_iter + 1, len(idx)) - elif strategy == "all": - # Only assign to valid tokens (prompt_mask==1) - valid_idx = np.nonzero(prompt_mask_np)[0] - if len(valid_idx) > 0: - values[valid_idx] = np.random.randint(2, max_iter + 1, len(valid_idx)) - elif strategy == "maxiter": - # Only assign to valid tokens (prompt_mask==1) - valid_idx = np.nonzero(prompt_mask_np)[0] - if len(valid_idx) > 0: - values[valid_idx] = max_iter - return values + if not any(prompt_mask): + return None - # Compute arrays (iter_count is local helper, not returned) - iter_count = compute_iter_values(iter_count_strategy, iter_count_strategy_kwargs, prompt_mask_np=prompt_mask_np) - iter_count_labels = compute_iter_values(iter_count_label_strategy, base_counts=iter_count, prompt_mask_np=prompt_mask_np) - - # Ensure prompt tokens use configured iteration for both arrays - iter_count[prompt_mask_np == 0] = query_iter_count - iter_count_labels[prompt_mask_np == 0] = query_iter_count - - batch_input_ids.append(input_ids) - batch_iter_count_labels.append(iter_count_labels.tolist()) - batch_attention_mask.append(attention_mask) - batch_labels.append(labels) - + if max_length is not None and len(input_ids) > max_length: + if not truncate: + return None + input_ids = input_ids[:max_length] + prompt_mask = prompt_mask[:max_length] + mismatch_mask = mismatch_mask[:max_length] if mismatch_mask is not None else None + + labels = np.array(input_ids, dtype=np.int64) + prompt_np = np.array(prompt_mask) + labels[prompt_np == 0] = -100 # ignore prompt tokens for next-token loss + + iter_labels = np.ones(len(input_ids), dtype=np.int32) + if max_iter > 1 and mismatch_mask is not None: + mismatch_np = np.array(mismatch_mask) + if np.any(mismatch_np > 1): + # The labeller already emitted oracle iter counts directly. + iter_labels = mismatch_np + 1 + else: + # Otherwise, mismatch is a 0/1 mask of "hard" positions; assign a + # uniform random extra iter count to each, restricted to valid + # (non-prompt) positions. + mismatch_and_valid = mismatch_np & prompt_np + idx = np.nonzero(mismatch_and_valid)[0] + if len(idx) > 0: + iter_labels[idx] = np.random.randint(2, max_iter + 1, len(idx)) + iter_labels[prompt_np == 0] = query_iter_count + + return { + "input_ids": input_ids, + "attention_mask": [1] * len(input_ids), + "labels": labels.tolist(), + "iter_count_labels": iter_labels.tolist(), + } + + +def _preprocess_for_sft_batch(examples: Dict, *, max_iter: int, max_length: Optional[int], truncate: bool, query_iter_count: int) -> Dict: + """Vectorise :func:`_build_sft_example` over a HF Datasets batch.""" + n = len(examples["real_token"]) + out = {"input_ids": [], "attention_mask": [], "labels": []} if max_iter > 1: - return { - 'input_ids': batch_input_ids, - 'attention_mask': batch_attention_mask, - 'labels': batch_labels, - 'iter_count_labels': batch_iter_count_labels, - } - else: - return { - 'input_ids': batch_input_ids, - 'attention_mask': batch_attention_mask, - 'labels': batch_labels, - } + out["iter_count_labels"] = [] + has_mismatch = "mismatch" in examples + for i in range(n): + result = _build_sft_example( + input_ids=examples["real_token"][i], + prompt_mask=examples["mask"][i], + mismatch_mask=examples["mismatch"][i] if has_mismatch else None, + max_iter=max_iter, + max_length=max_length, + truncate=truncate, + query_iter_count=query_iter_count, + ) + if result is None: + continue + for k in out: + out[k].append(result[k]) + return out + + +def _filter_keep(examples: Dict, *, max_length: Optional[int], truncate: bool) -> list[bool]: + """Drop rows with empty supervision or (when not truncating) over-length.""" + masks, tokens = examples["mask"], examples["real_token"] + mismatches = examples.get("mismatch") + enforce_len = bool(max_length) and not truncate + apply_window = bool(max_length) and truncate + keep = [] + for i in range(len(tokens)): + end = max_length if apply_window else len(tokens[i]) + mask_i = masks[i][:end] + ok = any(mask_i) + if ok and mismatches is not None: + mismatch_i = mismatches[i][:end] + ok = ok and any(mask_i[j] == 1 and mismatch_i[j] == 1 for j in range(len(mask_i))) + if enforce_len: + ok = ok and (len(tokens[i]) <= max_length) + keep.append(ok) + return keep def preprocess_dataset(training_config: Dict, data_config: Dict, model_config: Dict, accelerator: Accelerator): - """Load and preprocess the dataset with optimizations for large datasets.""" - accelerator.print("Loading dataset...") - # get data config - train_data_path = data_config['train_data_path'] - eval_data_path = data_config.get('eval_data_path', None) - train_data_ratio = data_config.get('train_data_ratio', 1.0) - eval_data_ratio = data_config.get('eval_data_ratio', 0.05) - - # get hard token strategy and relative weight - # Prefer strategy inferred from iter_decider config; fallback to training config - hard_token_strategy = _infer_strategy_from_iter_decider(model_config, is_eval=False) - eval_hard_token_strategy = _infer_strategy_from_iter_decider(model_config, is_eval=True) - hard_relative_weight = training_config.get('hard_token_relative_weight', 1.0) + """Load + preprocess train (and optionally eval) datasets for SFT.""" + del training_config # unused since the hard_token_relative_weight path was removed + accelerator.print("Loading dataset…") + train_ds_path = data_config["train_data_path"] + eval_ds_path = data_config.get("eval_data_path") + train_ratio = data_config.get("train_data_ratio", 1.0) + eval_ratio = data_config.get("eval_data_ratio", 0.05) + + max_iter = model_config.get("max_iter", 1) + max_length = data_config.get("max_length") + truncate = (data_config.get("max_length_action", "cutoff") or "cutoff").lower() != "filter" + query_iter_count = int(data_config.get("query_iter_count", 1)) - # get model config - max_iter = model_config.get('max_iter', 1) - - # Load train dataset with accelerator.main_process_first(): - raw_train_dataset = load_from_disk(train_data_path) - - if train_data_ratio != 1.0: - accelerator.print(f"Using {train_data_ratio} of train dataset") - raw_train_dataset = raw_train_dataset.select(range(int(len(raw_train_dataset) * train_data_ratio))) - - # Check if eval_data_path is provided and exists - raw_eval_dataset = None - use_separate_eval = False - if eval_data_path and eval_data_path.strip(): + train_ds = load_from_disk(train_ds_path) + if train_ratio != 1.0: + accelerator.print(f"Subsampling train dataset to {train_ratio:.2%}") + train_ds = train_ds.select(range(int(len(train_ds) * train_ratio))) + + eval_ds = None + if eval_ds_path: try: with accelerator.main_process_first(): - raw_eval_dataset = load_from_disk(eval_data_path) - use_separate_eval = True - accelerator.print(f"Using separate evaluation dataset from: {eval_data_path}") - accelerator.print(f"Training with full train dataset") + eval_ds = load_from_disk(eval_ds_path) + accelerator.print(f"Loaded separate eval dataset from {eval_ds_path}") except Exception as e: - accelerator.print(f"Warning: Could not load eval dataset from {eval_data_path}: {e}") - accelerator.print("Will split train dataset instead") - - if not use_separate_eval: - accelerator.print("No separate eval dataset provided, will split train dataset using ratio") - - # Get max_length and action from config if available - max_length = data_config.get('max_length', None) - max_length_action = (data_config.get('max_length_action', 'cutoff') or 'cutoff').lower() - if max_length_action not in {"cutoff", "filter"}: - max_length_action = "cutoff" - - # Optimized sequence length analysis using numpy for train dataset - accelerator.print("Analyzing sequence lengths...") - # Use map with batched=True for faster processing - def get_lengths_batch(examples): - return {"lengths": [len(tokens) for tokens in examples['real_token']]} - - with accelerator.main_process_first(): - lengths_dataset = raw_train_dataset.map( - get_lengths_batch, - batched=True, - batch_size=1000, - num_proc=16, # Use multiple processes - remove_columns=raw_train_dataset.column_names - ) - - raw_lengths = lengths_dataset['lengths'] - + accelerator.print(f"Warning: could not load eval dataset {eval_ds_path}: {e}") + if max_length: - accelerator.print(f"Using max_length: {max_length}") - long_sequences = sum(1 for length in raw_lengths if length > max_length) - if long_sequences > 0: - if max_length_action == "filter": - accelerator.print(f"Warning: {long_sequences} sequences will be filtered out") - else: - accelerator.print(f"Warning: {long_sequences} sequences will be truncated") - - # Prefilter samples in a single pass to avoid batched map row count mismatch - def _prefilter_batch(examples): - masks = examples['mask'] - tokens = examples['real_token'] - mismatches = examples['mismatch'] if 'mismatch' in examples else None - enforce_len = bool(max_length) and (max_length_action == "filter") - apply_window = bool(max_length) and (max_length_action == "cutoff") - keep = [] - for i in range(len(tokens)): - end = max_length if apply_window else len(tokens[i]) - mask_i = masks[i][:end] - mismatch_i = mismatches[i][:end] if mismatches is not None else None - ok = any(mask_i) - if ok and mismatch_i is not None: - has_mismatch_on_mask = any((mask_i[j] == 1) and (mismatch_i[j] == 1) for j in range(len(mask_i))) - ok = ok and has_mismatch_on_mask - if enforce_len: - ok = ok and (len(tokens[i]) <= max_length) - keep.append(ok) - return keep + accelerator.print(f"max_length={max_length} ({'truncate' if truncate else 'filter'})") + # Prefilter: drop rows with no supervision or (when filter mode) over-length. + filt = partial(_filter_keep, max_length=max_length, truncate=truncate) with accelerator.main_process_first(): - raw_train_dataset = raw_train_dataset.filter( - _prefilter_batch, - batched=True, - batch_size=1000, - num_proc=16, - desc="Prefiltering train dataset" - ) - - if use_separate_eval and raw_eval_dataset is not None: - with accelerator.main_process_first(): - raw_eval_dataset = raw_eval_dataset.filter( - _prefilter_batch, - batched=True, - batch_size=1000, - num_proc=16, - desc="Prefiltering eval dataset" - ) - - accelerator.print("Preprocessing datasets...") - num_proc = 16 # use 16 processes to parallel process - batch_size = 2000 # batch size - - # New: separate strategies for iter_count and iter_count_labels - # Backward-compat: fall back to model-inferred hard_token_strategy for iter_count; labels default to copy - iter_count_strategy = data_config.get('iter_count_strategy', hard_token_strategy or 'auto') - iter_count_strategy_kwargs = data_config.get('iter_count_strategy_kwargs', {}) - iter_count_label_strategy = data_config.get('iter_count_label_strategy', 'copy') - query_iter_count = int(data_config.get('query_iter_count', 1)) - preprocess_fn = partial( - preprocess_for_sft_batch, - max_iter=max_iter, - max_length=max_length, - iter_count_strategy=iter_count_strategy, - iter_count_strategy_kwargs=iter_count_strategy_kwargs, - iter_count_label_strategy=iter_count_label_strategy, - query_iter_count=query_iter_count, - max_length_action=max_length_action, + train_ds = train_ds.filter(filt, batched=True, batch_size=1000, num_proc=16, desc="Prefiltering train") + if eval_ds is not None: + eval_ds = eval_ds.filter(filt, batched=True, batch_size=1000, num_proc=16, desc="Prefiltering eval") + + # Map to SFT examples. + accelerator.print("Preprocessing datasets…") + preprocess = partial( + _preprocess_for_sft_batch, max_iter=max_iter, max_length=max_length, + truncate=truncate, query_iter_count=query_iter_count, ) - - # Determine columns to remove - remove_cols = ['data_id', 'real_text', 'real_token', 'mask'] - if 'divergent' in raw_train_dataset.column_names: - remove_cols.append('divergent') - # Remove entropy to avoid tokenizer.pad trying to collate it - if 'entropy' in raw_train_dataset.column_names: - remove_cols.append('entropy') - # Remove ds_divergence for the same reason as entropy - if 'ds_divergence' in raw_train_dataset.column_names: - remove_cols.append('ds_divergence') - if 'mismatch' in raw_train_dataset.column_names: - remove_cols.append('mismatch') - - if use_separate_eval: - # Process train and eval datasets separately - accelerator.print("Processing train dataset...") - with accelerator.main_process_first(): - processed_train_dataset = raw_train_dataset.map( - preprocess_fn, - batched=True, - batch_size=batch_size, - num_proc=num_proc, - remove_columns=remove_cols, - desc="Processing train dataset" - ) - - accelerator.print("Processing eval dataset...") - # Determine eval columns to remove - eval_remove_cols = ['data_id', 'real_text', 'real_token', 'mask'] - if 'divergent' in raw_eval_dataset.column_names: - eval_remove_cols.append('divergent') - if 'mismatch' in raw_eval_dataset.column_names: - eval_remove_cols.append('mismatch') - if 'problem_idx' in raw_eval_dataset.column_names: - eval_remove_cols.append('problem_idx') - if 'entropy' in raw_eval_dataset.column_names: - eval_remove_cols.append('entropy') - if 'ds_divergence' in raw_eval_dataset.column_names: - eval_remove_cols.append('ds_divergence') - eval_preprocess_fn = partial( - preprocess_for_sft_batch, - max_iter=max_iter, - max_length=max_length, - iter_count_strategy=iter_count_strategy, - iter_count_strategy_kwargs=iter_count_strategy_kwargs, - iter_count_label_strategy=iter_count_label_strategy, - query_iter_count=query_iter_count, - max_length_action=max_length_action, + def _drop_input_cols(ds): + return [c for c in _LABELLER_INPUT_COLUMNS if c in ds.column_names] + + with accelerator.main_process_first(): + train_ds = train_ds.map( + preprocess, batched=True, batch_size=2000, num_proc=16, + remove_columns=_drop_input_cols(train_ds), desc="SFT train preprocessing", ) - - with accelerator.main_process_first(): - processed_eval_dataset = raw_eval_dataset.map( - eval_preprocess_fn, - batched=True, - batch_size=batch_size, - num_proc=num_proc, - remove_columns=eval_remove_cols, - desc="Processing eval dataset" + if eval_ds is not None: + eval_ds = eval_ds.map( + preprocess, batched=True, batch_size=2000, num_proc=16, + remove_columns=_drop_input_cols(eval_ds), desc="SFT eval preprocessing", ) - else: - # Process the combined dataset first, then split - accelerator.print("Processing combined dataset...") - with accelerator.main_process_first(): - processed_dataset = raw_train_dataset.map( - preprocess_fn, - batched=True, - batch_size=batch_size, - num_proc=num_proc, - remove_columns=remove_cols, - desc="Processing dataset" - ) - - # Split dataset into train and eval after preprocessing - accelerator.print("Splitting dataset into train and eval...") - if eval_data_ratio > 0: - split_dataset = processed_dataset.train_test_split(test_size=eval_data_ratio, seed=42) - processed_train_dataset = split_dataset['train'] - processed_eval_dataset = split_dataset['test'] - else: - processed_train_dataset = processed_dataset - processed_eval_dataset = None - - if hard_relative_weight != 1.0: - # Calculate hard token ratio using processed train dataset - avg_hard_ratio = calculate_hard_token_ratio( - processed_train_dataset, hard_token_strategy, accelerator - ) - else: - avg_hard_ratio = None - - accelerator.print(f"Train dataset size: {len(processed_train_dataset)}") - if processed_eval_dataset is not None: - accelerator.print(f"Eval dataset size: {len(processed_eval_dataset)}") - - return processed_train_dataset, processed_eval_dataset, avg_hard_ratio + # If no separate eval set was provided, split a fraction from train. + if eval_ds is None and eval_ratio > 0: + accelerator.print(f"Splitting {eval_ratio:.2%} of train as eval") + split = train_ds.train_test_split(test_size=eval_ratio, seed=42) + train_ds = split["train"] + eval_ds = split["test"] + + accelerator.print(f"Train dataset size: {len(train_ds)}") + if eval_ds is not None: + accelerator.print(f"Eval dataset size: {len(eval_ds)}") + return train_ds, eval_ds diff --git a/tah/utils/sampling.py b/tah/utils/sampling.py deleted file mode 100644 index 268096f..0000000 --- a/tah/utils/sampling.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -from typing import Union - -def sample_token(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1) -> Union[int, torch.Tensor]: - """Sample a token from logits using temperature, top-p, and top-k sampling. - Args: - logits: Token logits of shape [vocab_size] or [batch_size, vocab_size] - temperature: Temperature for sampling (>0). Higher values produce more random samples. - top_p: Top-p probability threshold for nucleus sampling (0 < top_p ≤ 1) - top_k: Top-k threshold for sampling (if -1, no top-k filtering is applied) - Returns: - Sampled token ID (int for single sample, tensor for batch) - """ - if not isinstance(logits, torch.Tensor): - raise TypeError("logits must be a torch.Tensor") - - if logits.dim() not in [1, 2]: - raise ValueError("logits must have shape [vocab_size] or [batch_size, vocab_size]") - - # Handle single dimension input - is_single_input = logits.dim() == 1 - if is_single_input: - logits = logits.unsqueeze(0) - - batch_size = logits.shape[0] - - # For greedy sampling (temperature=0), just return argmax - if temperature == 0 or temperature <= 1e-5: - tokens = torch.argmax(logits, dim=-1) - return tokens.item() if is_single_input else tokens - - # Convert to probabilities - probs = torch.nn.functional.softmax(logits / temperature, dim=-1) - - # Apply top-k filtering first (if specified) - if top_k != -1: - # Get top-k values and indices - top_k_values, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]), dim=-1) - - # Create a mask to zero out non-top-k probabilities - mask = torch.zeros_like(probs, dtype=torch.bool) - mask.scatter_(-1, top_k_indices, True) - - # Zero out non-top-k probabilities - probs = probs * mask.float() - - # Renormalize probabilities - probs = probs / probs.sum(dim=-1, keepdim=True) - - # Apply top-p (nucleus) sampling - if top_p < 1.0: - # Sort probabilities in descending order - sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) - - # Calculate cumulative probabilities - cumulative_probs = torch.cumsum(sorted_probs, dim=-1) - - # Create a mask for probabilities to keep - # Values above top_p threshold are masked out - mask = cumulative_probs <= top_p - - # Always keep at least one token - mask[:, 0] = True - - # Zero out masked positions to exclude them from sampling - sorted_probs = sorted_probs * mask.float() - - # Renormalize probabilities - sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True) - - # Sample from the filtered distribution - sampled_indices = torch.multinomial(sorted_probs, num_samples=1) - - # Map back to original vocabulary indices - tokens = torch.gather(sorted_indices, dim=-1, index=sampled_indices) - tokens = tokens.squeeze(-1) # Remove sample dimension - else: - # Direct sampling if no top-p filtering - tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) - - return tokens.item() if is_single_input else tokens diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/_harness.py b/tests/_harness.py new file mode 100644 index 0000000..d6a7d25 --- /dev/null +++ b/tests/_harness.py @@ -0,0 +1,211 @@ +"""Test harness shared across component tests. + +Each component test follows the same pattern: + + 1. ``capture(name, fn, *args)`` runs ``fn(*args)`` against the public TaH at + ``$TAH_PUBLIC_ROOT`` (default ``/tmp/TaH-pub``) inside a subprocess and + pickles the inputs+outputs to ``tests/baselines/.pt``. + 2. ``compare(name, fn, *args)`` runs ``fn(*args)`` against the cleaned + ``tah-release`` in this process and asserts every output tensor matches + the recorded baseline within ``ACC_TOL``. + 3. ``bench(label, fn, *args, ref_fn)`` measures wall-clock for ``fn`` vs + ``ref_fn`` over ``WARMUP+ITERS`` calls and prints a one-line speedup. + +Snapshots are PyTorch ``.pt`` files containing a dict ``{"args": ..., "out": +...}``. Tensors are saved on CPU; we move them to ``DEVICE`` at compare time. + +The split exists because public TaH and the cleaned tah-release share the +package name ``tah``; both cannot live on ``sys.path`` at once. +""" +from __future__ import annotations + +import os +import pickle +import subprocess +import sys +import time +from pathlib import Path +from typing import Any, Callable + +import torch + +REPO = Path(__file__).resolve().parents[1] +BASELINE_DIR = REPO / "tests" / "baselines" +BASELINE_DIR.mkdir(parents=True, exist_ok=True) + +PUBLIC_ROOT = Path(os.environ.get("TAH_PUBLIC_ROOT", "/tmp/TaH-pub")) +DEVICE = os.environ.get("TAH_TEST_DEVICE", "cuda" if torch.cuda.is_available() else "cpu") +ACC_TOL = float(os.environ.get("TAH_ACC_TOL", "1e-5")) +WARMUP = int(os.environ.get("TAH_BENCH_WARMUP", "5")) +ITERS = int(os.environ.get("TAH_BENCH_ITERS", "30")) + + +def _to_cpu(x: Any) -> Any: + if isinstance(x, torch.Tensor): + return x.detach().cpu() + if isinstance(x, dict): + return {k: _to_cpu(v) for k, v in x.items()} + if isinstance(x, (list, tuple)): + t = type(x) + return t(_to_cpu(v) for v in x) + return x + + +def _to_device(x: Any, device: str) -> Any: + if isinstance(x, torch.Tensor): + return x.to(device) + if isinstance(x, dict): + return {k: _to_device(v, device) for k, v in x.items()} + if isinstance(x, (list, tuple)): + t = type(x) + return t(_to_device(v, device) for v in x) + return x + + +def baseline_path(name: str) -> Path: + return BASELINE_DIR / f"{name}.pt" + + +def have_baseline(name: str) -> bool: + return baseline_path(name).exists() + + +def capture(name: str, code: str, payload: dict | None = None) -> dict: + """Run ``code`` (a Python source string) inside a subprocess scoped to + ``PUBLIC_ROOT``, capture its returned dict, save as a baseline, and return. + + ``code`` MUST define ``def run(payload):`` returning a JSON/pickle-able + dict. ``payload`` is forwarded as the single argument. + """ + payload = payload or {} + payload_pkl = pickle.dumps(payload) + out_path = baseline_path(name) + runner = ( + "import os, sys, pickle\n" + f"sys.path.insert(0, {str(PUBLIC_ROOT)!r})\n" + "import torch\n" + f"{code}\n" + "payload = pickle.loads(sys.stdin.buffer.read())\n" + "out = run(payload)\n" + "from tests._harness import _to_cpu # noqa: E402 -- imported lazily\n" + ) + # We can't import tests._harness inside the subprocess because the public + # TaH path is first on sys.path. Inline a CPU mover instead. + runner = ( + f"import os, sys, pickle\n" + f"sys.path.insert(0, {str(PUBLIC_ROOT)!r})\n" + f"import torch\n" + f"def _to_cpu(x):\n" + f" if isinstance(x, torch.Tensor): return x.detach().cpu()\n" + f" if isinstance(x, dict): return {{k: _to_cpu(v) for k, v in x.items()}}\n" + f" if isinstance(x, (list, tuple)): return type(x)(_to_cpu(v) for v in x)\n" + f" return x\n" + f"{code}\n" + f"payload = pickle.loads(sys.stdin.buffer.read())\n" + f"out = run(payload)\n" + f"sys.stdout.buffer.write(pickle.dumps({{'args': payload, 'out': _to_cpu(out)}}))\n" + ) + proc = subprocess.run( + [sys.executable, "-c", runner], + input=payload_pkl, + capture_output=True, + env={**os.environ, "PYTHONPATH": str(PUBLIC_ROOT)}, + check=False, + ) + if proc.returncode != 0: + raise RuntimeError( + f"baseline capture for {name!r} failed:\n" + f"--- stderr ---\n{proc.stderr.decode(errors='replace')}\n" + f"--- stdout ---\n{proc.stdout.decode(errors='replace')}" + ) + snap = pickle.loads(proc.stdout) + torch.save(snap, out_path) + return snap + + +def load_baseline(name: str) -> dict: + if not have_baseline(name): + raise FileNotFoundError( + f"no baseline at {baseline_path(name)} — run capture(...) first" + ) + return torch.load(baseline_path(name), weights_only=False) + + +def assert_close(name: str, actual: Any, expected: Any, atol: float = ACC_TOL, rtol: float = 1e-4): + """Recursively compare actual vs expected; raise if any leaf disagrees.""" + if isinstance(expected, torch.Tensor): + if not isinstance(actual, torch.Tensor): + raise AssertionError(f"{name}: expected Tensor, got {type(actual).__name__}") + a = actual.detach().cpu() + e = expected.detach().cpu() + if a.shape != e.shape: + raise AssertionError(f"{name}: shape mismatch {tuple(a.shape)} vs {tuple(e.shape)}") + if a.dtype != e.dtype: + # Best-effort dtype unify before compare; some bool/long mismatches are OK + try: + a = a.to(e.dtype) + except Exception: + raise AssertionError(f"{name}: dtype mismatch {a.dtype} vs {e.dtype}") from None + if e.is_floating_point(): + diff = (a.float() - e.float()).abs() + max_abs = float(diff.max().item()) if diff.numel() else 0.0 + tol = atol + rtol * float(e.float().abs().max().item() if e.numel() else 0.0) + if max_abs > tol: + raise AssertionError( + f"{name}: max abs diff {max_abs:.3e} > tol {tol:.3e}" + ) + else: + if not torch.equal(a, e): + n_diff = int((a != e).sum().item()) + raise AssertionError(f"{name}: {n_diff} integer/bool elements differ") + return + if isinstance(expected, dict): + if not isinstance(actual, dict): + raise AssertionError(f"{name}: expected dict, got {type(actual).__name__}") + for k in expected: + if k not in actual: + raise AssertionError(f"{name}: missing key {k!r}") + assert_close(f"{name}.{k}", actual[k], expected[k], atol, rtol) + for k in actual: + if k not in expected: + raise AssertionError(f"{name}: unexpected key {k!r}") + return + if isinstance(expected, (list, tuple)): + if not isinstance(actual, type(expected)): + raise AssertionError(f"{name}: container mismatch {type(actual).__name__} vs {type(expected).__name__}") + if len(expected) != len(actual): + raise AssertionError(f"{name}: length {len(actual)} vs {len(expected)}") + for i, (a, e) in enumerate(zip(actual, expected)): + assert_close(f"{name}[{i}]", a, e, atol, rtol) + return + if expected != actual: + raise AssertionError(f"{name}: {actual!r} vs {expected!r}") + + +def bench(label: str, fn: Callable, ref_fn: Callable | None = None, *, warmup: int = WARMUP, iters: int = ITERS) -> dict: + """Time ``fn`` (and optionally ``ref_fn``) and print a one-line summary. + + Returns a dict with ``ms`` (cleaned) and ``ref_ms`` (or None) so callers can + assert non-regression. Caller is responsible for any required CUDA syncing + inside ``fn`` / ``ref_fn``. + """ + def _time(f): + for _ in range(warmup): + f() + if torch.cuda.is_available(): + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(iters): + f() + if torch.cuda.is_available(): + torch.cuda.synchronize() + return (time.perf_counter() - t0) / iters * 1e3 + + ms = _time(fn) + if ref_fn is not None: + ref_ms = _time(ref_fn) + speedup = ref_ms / ms if ms > 0 else float("inf") + print(f" bench[{label}]: clean={ms:.3f}ms ref={ref_ms:.3f}ms speedup={speedup:.2f}x") + return {"ms": ms, "ref_ms": ref_ms, "speedup": speedup} + print(f" bench[{label}]: clean={ms:.3f}ms") + return {"ms": ms, "ref_ms": None, "speedup": None} diff --git a/tests/baselines/.gitignore b/tests/baselines/.gitignore new file mode 100644 index 0000000..2a4d57a --- /dev/null +++ b/tests/baselines/.gitignore @@ -0,0 +1,4 @@ +# Component baseline snapshots (regenerable from public TaH). +# Keep the directory in git but never commit individual .pt files. +*.pt +!.gitignore diff --git a/tests/bench.py b/tests/bench.py new file mode 100644 index 0000000..b62cc56 --- /dev/null +++ b/tests/bench.py @@ -0,0 +1,300 @@ +"""Microbenchmarks for tah/model + tah/evaluate hot paths. + +Run:: + + python tests/bench.py components # iterated helpers, ~1s each + python tests/bench.py e2e # real TaH-plus-1.7B forward (~5s) + python tests/bench.py all # both, then print a summary + python tests/bench.py all --json out.json # also write a json report + +Each component is timed as ``(warmup × WARMUP) → measured × ITERS``, with a +``torch.cuda.synchronize`` before the timer start and after the last run, so +the reported numbers exclude lazy-init / first-call compile costs. + +The benchmark is the verification surface for the speed-impacting refactors +landing alongside this file: capture a baseline before each change, run the +same command after, look at the printed delta. Component shapes are chosen +to roughly mirror what the wrapper sees during a prefill on TaH-plus-1.7B +(``B=2, T=64, V=151936, H=2048, L=28``). +""" +from __future__ import annotations + +import argparse +import json +import os +import time +from dataclasses import asdict, dataclass, field +from typing import Callable, Dict, List, Optional + +import torch + + +# ──────────────────────────────────────────────────────────────────────────── +# Bench infra +# ──────────────────────────────────────────────────────────────────────────── + + +@dataclass +class BenchResult: + name: str + ms: float + iters: int + warmup: int + extra: Dict[str, float] = field(default_factory=dict) + + +def time_fn(fn: Callable[[], None], *, warmup: int = 5, iters: int = 30) -> float: + """Median-of-iters wall-clock per call in ms; CUDA-synced if available.""" + for _ in range(warmup): + fn() + if torch.cuda.is_available(): + torch.cuda.synchronize() + samples: List[float] = [] + for _ in range(iters): + if torch.cuda.is_available(): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn() + if torch.cuda.is_available(): + torch.cuda.synchronize() + samples.append((time.perf_counter() - t0) * 1e3) + samples.sort() + return samples[len(samples) // 2] # median + + +def _print_table(results: List[BenchResult]) -> None: + if not results: + return + name_w = max(len(r.name) for r in results) + 2 + print(f"{'name':<{name_w}}{'ms':>10}{'iters':>10}{'extra':>20}") + print("-" * (name_w + 40)) + for r in results: + extra = " ".join(f"{k}={v:.3g}" for k, v in r.extra.items()) if r.extra else "" + print(f"{r.name:<{name_w}}{r.ms:>10.3f}{r.iters:>10}{extra:>20}") + + +# ──────────────────────────────────────────────────────────────────────────── +# Component benchmarks +# ──────────────────────────────────────────────────────────────────────────── + + +def bench_components(device: str) -> List[BenchResult]: + from tah.model.iter_decider import MLPIterDecider + from tah.model.loss import IterDeciderLoss, NextTokenPredLoss + from tah.model.tah_model import ( + additive_logits_update, + gather_active, + scatter_back, + topk_softmax_input_update, + ) + + # Shapes chosen to mirror a realistic TaH-plus-1.7B prefill. + B, T, V, H, L = 2, 64, 151936, 2048, 28 + K = 100 # input_updater topk + + results: List[BenchResult] = [] + g = torch.Generator(device=device).manual_seed(0) + dtype = torch.bfloat16 if device == "cuda" else torch.float32 + + # ── input_updater ─────────────────────────────────────────────────── + logits = torch.randn(B, T, V, generator=g, device=device, dtype=dtype) + embed_w = torch.randn(V, H, generator=g, device=device, dtype=dtype) + results.append(BenchResult( + name="topk_softmax_input_update", + ms=time_fn(lambda: topk_softmax_input_update(logits, embed_w, K)), + iters=30, warmup=5, + extra={"shape": float(B * T * V), "topk": float(K)}, + )) + + # ── output_updater (additive) ─────────────────────────────────────── + a = torch.randn(B, T, V, generator=g, device=device, dtype=dtype) + b = torch.randn(B, T, V, generator=g, device=device, dtype=dtype) + results.append(BenchResult( + name="additive_logits_update", + ms=time_fn(lambda: additive_logits_update(a, b)), + iters=30, warmup=5, + )) + + # ── gather_active / scatter_back ──────────────────────────────────── + # Mask: half the positions active (typical mid-iteration state). + mask = torch.zeros(B, T, dtype=torch.bool, device=device) + mask[:, :T // 2] = True + mask = mask[:, torch.randperm(T)] + embeds = torch.randn(B, T, H, generator=g, device=device, dtype=dtype) + pos_ids = torch.arange(T, device=device).expand(B, T).clone() + valid = torch.ones(B, T, dtype=torch.long, device=device) + results.append(BenchResult( + name="gather_active", + ms=time_fn(lambda: gather_active(mask, embeds, pos_ids, valid)), + iters=30, warmup=5, + extra={"active_frac": float(mask.float().mean().item())}, + )) + + # scatter_back source comes from gather_active, so warm both up together. + pad_mask, *gathered = gather_active(mask, embeds) + src = gathered[0] + dest = torch.zeros_like(embeds) + results.append(BenchResult( + name="scatter_back", + ms=time_fn(lambda: scatter_back(mask, src=src, dest=dest)), + iters=30, warmup=5, + )) + assignment = torch.zeros(B, T, dtype=torch.bool, device=device) + assignment[:, T // 4: 3 * T // 4] = True + results.append(BenchResult( + name="scatter_back+mask", + ms=time_fn(lambda: scatter_back(mask, src=src, dest=dest, assignment_mask=assignment)), + iters=30, warmup=5, + )) + + # ── MLPIterDecider.forward ────────────────────────────────────────── + torch.manual_seed(7) + decider = MLPIterDecider( + topk=K, hidden_states_size=H, + hidden_states_layer_nums=[2, 10, 18, 26], + hidden_dims=[512, 512, 512, 512, 512, 512], + expansion_factor=4, dropout_rate=0.1, normalize_input=False, + threshold=0.9, max_iter=2, dtype=dtype, + ).to(device) + decider_logits = logits[mask] # (n_active, V) + decider_hidden = torch.randn(decider_logits.shape[0], L, H, generator=g, device=device, dtype=dtype) + results.append(BenchResult( + name="MLPIterDecider.forward", + ms=time_fn(lambda: decider(logits=decider_logits, iter_depth=1, all_hidden_states=decider_hidden)), + iters=30, warmup=5, + extra={"n_active": float(decider_logits.shape[0])}, + )) + + # ── NextTokenPredLoss.final_loss_func ─────────────────────────────── + labels = torch.randint(0, V, (B, T), generator=g, device=device, dtype=torch.long) + iter_count = torch.randint(0, 3, (B, T), generator=g, device=device, dtype=torch.long) + nl = NextTokenPredLoss() + nl.prepare_loss(B, T, device, torch.float32) + results.append(BenchResult( + name="NextTokenPredLoss.final", + ms=time_fn(lambda: nl.final_loss_func( + logits=logits, labels_shifted=labels, iter_count=iter_count, + training=False, num_items_in_batch=int((labels != -100).sum().item()), + )), + iters=30, warmup=5, + )) + + # ── IterDeciderLoss.intra_iter_loss_func ──────────────────────────── + # The wrapper hands the loss a flattened tensor of `valid` continue logits; + # mirror that here so the bench matches the production call shape. + n_valid = int(mask.sum().item()) + continue_logits = torch.randn(n_valid, generator=g, device=device, dtype=torch.float32) + iter_count_labels = torch.randint(0, 3, (B, T), generator=g, device=device, dtype=torch.long) + valid_mask = mask.long() + bce = IterDeciderLoss(pos_weight=2.0, skip_last_iter=True, max_iter=2) + bce.prepare_loss(B, T, device, torch.float32) + + def _intra(): + bce.iter_decider_loss_per_token = torch.zeros(B, T, device=device, dtype=torch.float32) + bce.intra_iter_loss_func( + active_logits=None, current_iter_mask=mask, active_labels_shifted=None, + active_valid_continue_logits=continue_logits, active_valid_mask=valid_mask, + iter_depth=1, active_iter_count_labels=iter_count_labels, + iter_decider_threshold=0.5, + ) + results.append(BenchResult( + name="IterDeciderLoss.intra", + ms=time_fn(_intra), iters=30, warmup=5, + extra={"n_valid": float(n_valid)}, + )) + + return results + + +# ──────────────────────────────────────────────────────────────────────────── +# End-to-end TaH-plus-1.7B forward +# ──────────────────────────────────────────────────────────────────────────── + + +def bench_e2e(device: str) -> List[BenchResult]: + if device != "cuda": + print("[skip] e2e bench requires CUDA") + return [] + from transformers import AutoTokenizer + from tah.model.tah_model import TaHForCausalLM + from tah.model.utils import TaHForCasualLM_generate + + ckpt = os.environ.get("TAH_CHECKPOINT", "nics-efc/TaH-plus-1.7B") + tok = AutoTokenizer.from_pretrained(ckpt) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + model = TaHForCausalLM.from_pretrained( + ckpt, torch_dtype=torch.bfloat16, device_map=device, attn_implementation="sdpa", + ) + + # Forward latency on a small prompt (representative of one decode step + # except for KV cache, which we skip here for a clean per-call number). + text = "Compute 17 + 25. Reply with a single integer." + inp = tok(text, return_tensors="pt").to(model.device) + + def _forward(): + with torch.no_grad(): + model(**inp, use_cache=False) + + fwd = BenchResult( + name="TaHForCausalLM.forward", + ms=time_fn(_forward, warmup=3, iters=10), + iters=10, warmup=3, + extra={"input_len": float(inp["input_ids"].shape[1])}, + ) + + # Generation: 32 tokens greedy, full TaHForCasualLM_generate path. + def _generate(): + with torch.no_grad(): + TaHForCasualLM_generate( + tah_model=model, tokenizer=tok, model_inputs=dict(inp), + max_new_tokens=32, do_sample=False, verbose=False, + ) + + gen = BenchResult( + name="TaHForCasualLM_generate(32)", + ms=time_fn(_generate, warmup=2, iters=5), + iters=5, warmup=2, + extra={"new_tokens": 32.0}, + ) + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return [fwd, gen] + + +# ──────────────────────────────────────────────────────────────────────────── +# CLI +# ──────────────────────────────────────────────────────────────────────────── + + +def main() -> None: + parser = argparse.ArgumentParser(description="TaH benchmarks") + parser.add_argument("scope", choices=("components", "e2e", "all")) + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--json", default=None, help="Write all results to a json file") + args = parser.parse_args() + + print(f"Device: {args.device}\n") + all_results: List[BenchResult] = [] + if args.scope in ("components", "all"): + print("=== components ===") + cr = bench_components(args.device) + _print_table(cr) + print() + all_results.extend(cr) + if args.scope in ("e2e", "all"): + print("=== end-to-end ===") + er = bench_e2e(args.device) + _print_table(er) + print() + all_results.extend(er) + + if args.json: + with open(args.json, "w") as f: + json.dump({"device": args.device, "results": [asdict(r) for r in all_results]}, f, indent=2) + print(f"\nWrote {args.json}") + + +if __name__ == "__main__": + main() diff --git a/tests/bench_compile.py b/tests/bench_compile.py new file mode 100644 index 0000000..2be6f2c --- /dev/null +++ b/tests/bench_compile.py @@ -0,0 +1,129 @@ +"""Compile vs eager bench for the wrapper's hot helpers. + +For each candidate, time eager → compile-warm → compile-steady. ``compile-warm`` +is the *first* call (full guard build), ``compile-steady`` is the median of +30 subsequent calls. The deltas tell us whether the helper is worth wrapping +in ``torch.compile`` for inference (steady-state speedup must justify +warmup cost; for serving ~1000s of forwards, even a small steady speedup +amortises). + +Run:: + + python tests/bench_compile.py +""" +from __future__ import annotations + +import time + +import torch + + +def _time(fn, *, iters: int = 30, warmup: int = 5) -> float: + for _ in range(warmup): + fn() + if torch.cuda.is_available(): + torch.cuda.synchronize() + samples = [] + for _ in range(iters): + if torch.cuda.is_available(): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn() + if torch.cuda.is_available(): + torch.cuda.synchronize() + samples.append((time.perf_counter() - t0) * 1e3) + samples.sort() + return samples[len(samples) // 2] + + +def _first_call_ms(fn) -> float: + if torch.cuda.is_available(): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn() + if torch.cuda.is_available(): + torch.cuda.synchronize() + return (time.perf_counter() - t0) * 1e3 + + +def main() -> None: + if not torch.cuda.is_available(): + print("compile bench requires CUDA") + return + + from tah.model.iter_decider import MLPIterDecider + from tah.model.tah_model import ( + additive_logits_update, + gather_active, + scatter_back, + topk_softmax_input_update, + ) + + device = "cuda" + dtype = torch.bfloat16 + B, T, V, H, L = 2, 64, 151936, 2048, 28 + K = 100 + + g = torch.Generator(device=device).manual_seed(0) + logits = torch.randn(B, T, V, generator=g, device=device, dtype=dtype) + embed_w = torch.randn(V, H, generator=g, device=device, dtype=dtype) + a = torch.randn(B, T, V, generator=g, device=device, dtype=dtype) + b = torch.randn(B, T, V, generator=g, device=device, dtype=dtype) + + print(f"{'name':<35}{'eager':>10}{'compile1':>10}{'compileN':>10}{'speedup':>10}") + print("-" * 75) + + def _row(name, eager_fn, compile_fn): + eager_ms = _time(eager_fn) + compile1_ms = _first_call_ms(compile_fn) + compileN_ms = _time(compile_fn) + speedup = eager_ms / compileN_ms if compileN_ms > 0 else float("inf") + print(f"{name:<35}{eager_ms:>10.3f}{compile1_ms:>10.3f}{compileN_ms:>10.3f}{speedup:>10.2f}x") + + # topk_softmax_input_update + f = lambda: topk_softmax_input_update(logits, embed_w, K) + f_c = torch.compile(topk_softmax_input_update) + _row("topk_softmax_input_update", f, lambda: f_c(logits, embed_w, K)) + + # additive_logits_update + f = lambda: additive_logits_update(a, b) + f_c = torch.compile(additive_logits_update) + _row("additive_logits_update", f, lambda: f_c(a, b)) + + # gather_active (dynamic max_active, may recompile) + mask = torch.zeros(B, T, dtype=torch.bool, device=device) + mask[:, :T // 2] = True + embeds = torch.randn(B, T, H, generator=g, device=device, dtype=dtype) + f = lambda: gather_active(mask, embeds) + f_c = torch.compile(gather_active, dynamic=True) + _row("gather_active(dynamic)", f, lambda: f_c(mask, embeds)) + + # scatter_back (also dynamic max_active) + pad_mask, gathered = gather_active(mask, embeds)[0], gather_active(mask, embeds)[1] + src = gathered + dest = torch.zeros_like(embeds) + f = lambda: scatter_back(mask, src=src, dest=dest) + f_c = torch.compile(scatter_back, dynamic=True) + _row("scatter_back(dynamic)", f, lambda: f_c(mask, src=src, dest=dest)) + + # MLPIterDecider.forward + torch.manual_seed(7) + decider = MLPIterDecider( + topk=K, hidden_states_size=H, + hidden_states_layer_nums=[2, 10, 18, 26], + hidden_dims=[512, 512, 512, 512, 512, 512], + expansion_factor=4, dropout_rate=0.1, normalize_input=False, + threshold=0.9, max_iter=2, dtype=dtype, + ).to(device).eval() + decider_logits = logits[mask] + decider_hidden = torch.randn(decider_logits.shape[0], L, H, generator=g, device=device, dtype=dtype) + decider_compiled = torch.compile(decider, dynamic=True) + _row( + "MLPIterDecider.forward", + lambda: decider(logits=decider_logits, iter_depth=1, all_hidden_states=decider_hidden), + lambda: decider_compiled(logits=decider_logits, iter_depth=1, all_hidden_states=decider_hidden), + ) + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..97dde6e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,44 @@ +"""Shared fixtures for tah-release component tests. + +We intentionally use *tiny* synthetic shapes so each component test runs in +under a second on CPU, while still exercising the >1 batch / >1 layer / >1 +iteration code paths. +""" +from __future__ import annotations + +import random + +import pytest +import torch + + +SEED = 4242 + + +@pytest.fixture(autouse=True) +def _deterministic(): + random.seed(SEED) + torch.manual_seed(SEED) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(SEED) + yield + + +@pytest.fixture +def device() -> str: + import os + return os.environ.get("TAH_TEST_DEVICE", "cuda" if torch.cuda.is_available() else "cpu") + + +@pytest.fixture +def shapes() -> dict: + """Tiny but non-trivial shapes for component tests.""" + return { + "B": 2, + "T": 5, + "V": 64, + "H": 32, + "L": 4, # number of transformer layers exposed for hidden states + "TOPK": 8, + "MAX_ITER": 2, + } diff --git a/tests/test_causal_cache.py b/tests/test_causal_cache.py new file mode 100644 index 0000000..7b50f5f --- /dev/null +++ b/tests/test_causal_cache.py @@ -0,0 +1,150 @@ +"""Acc test for ``TaHCache``. + +The cache holds per-(layer, iteration) KV plus position/valid metadata, and +exposes "view up to iter N" reads that the wrapper uses to build attention +masks. Cache shape and contents are part of the checkpoint contract with +minisgl-tah, so this test pins the *behaviour*, not just the API. +""" +from __future__ import annotations + +import pytest +import torch + +from tests._harness import ( + ACC_TOL, + assert_close, + capture, + have_baseline, + load_baseline, +) + + +def _cache_inputs(shapes, device): + B, T, H, L = shapes["B"], shapes["T"], 16, 2 # heads=2, head_dim=H/heads + n_heads = 2 + head_dim = 8 + g = torch.Generator(device=device).manual_seed(130) + + # Two iterations × two layers, with iter-1 attending to a strict subset. + def kv(seq_len): + k = torch.randn(B, n_heads, seq_len, head_dim, generator=g, device=device, dtype=torch.float32) + v = torch.randn(B, n_heads, seq_len, head_dim, generator=g, device=device, dtype=torch.float32) + return k, v + + iter0_layer = [kv(T) for _ in range(L)] + iter1_layer = [kv(T - 2) for _ in range(L)] + pos_iter0 = torch.arange(T, dtype=torch.long, device=device).unsqueeze(0).expand(B, T) + pos_iter1 = torch.arange(T - 2, dtype=torch.long, device=device).unsqueeze(0).expand(B, T - 2) + valid_iter0 = torch.ones(B, T, dtype=torch.long, device=device) + valid_iter1 = torch.ones(B, T - 2, dtype=torch.long, device=device) + + return { + "iter0_layer": iter0_layer, + "iter1_layer": iter1_layer, + "pos_iter0": pos_iter0, + "pos_iter1": pos_iter1, + "valid_iter0": valid_iter0, + "valid_iter1": valid_iter1, + "L": L, + "B": B, + } + + +CACHE_RUNNER = """ +def run(payload): + from tah.model.causal_cache import TaHCache + cache = TaHCache().to(device='cpu', dtype=torch.float32) + L = payload['L'] + # Iter 0 + cache.current_iter_depth = 0 + cache.position_ids_to_cache = payload['pos_iter0'] + cache.valid_mask_to_cache = payload['valid_iter0'] + for layer_idx in range(L): + k, v = payload['iter0_layer'][layer_idx] + cache.update(k, v, layer_idx) + # Iter 1 — narrower + cache.current_iter_depth = 1 + cache.position_ids_to_cache = payload['pos_iter1'] + cache.valid_mask_to_cache = payload['valid_iter1'] + for layer_idx in range(L): + k, v = payload['iter1_layer'][layer_idx] + cache.update(k, v, layer_idx) + + out = {} + for layer_idx in range(L): + kk0, vv0 = cache.get_cache_upto_iter(layer_idx, 0) + kk1, vv1 = cache.get_cache_upto_iter(layer_idx, 1) + out[f'L{layer_idx}_K_upto0'] = kk0 + out[f'L{layer_idx}_V_upto0'] = vv0 + out[f'L{layer_idx}_K_upto1'] = kk1 + out[f'L{layer_idx}_V_upto1'] = vv1 + out[f'L{layer_idx}_pos_upto1'] = cache.get_position_id_upto_iter(layer_idx, 1, init_batch_size=payload['B']) + out[f'L{layer_idx}_valid_upto1'] = cache.get_valid_mask_upto_iter(layer_idx, 1, init_batch_size=payload['B']) + out[f'L{layer_idx}_iteridx_upto1'] = cache.get_cache_iter_index_upto_iter(layer_idx, 1) + return out +""" + + +@pytest.fixture +def baseline(shapes): + args = _cache_inputs(shapes, "cpu") + name = "causal_cache" + if not have_baseline(name): + capture(name, CACHE_RUNNER, payload=args) + return load_baseline(name) + + +def test_cache_view_acc(baseline, device): + from tah.model.causal_cache import TaHCache + + def _to_dev(x): + if torch.is_tensor(x): + return x.to(device) + if isinstance(x, list): + return [_to_dev(v) for v in x] + if isinstance(x, tuple): + return tuple(_to_dev(v) for v in x) + return x + + args = {k: _to_dev(v) for k, v in baseline["args"].items()} + + cache = TaHCache().to(device=device, dtype=torch.float32) + L = args["L"] + cache.current_iter_depth = 0 + cache.position_ids_to_cache = args["pos_iter0"] + cache.valid_mask_to_cache = args["valid_iter0"] + for layer_idx in range(L): + k, v = args["iter0_layer"][layer_idx] + cache.update(k, v, layer_idx) + cache.current_iter_depth = 1 + cache.position_ids_to_cache = args["pos_iter1"] + cache.valid_mask_to_cache = args["valid_iter1"] + for layer_idx in range(L): + k, v = args["iter1_layer"][layer_idx] + cache.update(k, v, layer_idx) + + for layer_idx in range(L): + kk0, vv0 = cache.get_cache_upto_iter(layer_idx, 0) + kk1, vv1 = cache.get_cache_upto_iter(layer_idx, 1) + assert_close(f"L{layer_idx}_K_upto0", kk0, baseline["out"][f"L{layer_idx}_K_upto0"], atol=ACC_TOL) + assert_close(f"L{layer_idx}_V_upto0", vv0, baseline["out"][f"L{layer_idx}_V_upto0"], atol=ACC_TOL) + assert_close(f"L{layer_idx}_K_upto1", kk1, baseline["out"][f"L{layer_idx}_K_upto1"], atol=ACC_TOL) + assert_close(f"L{layer_idx}_V_upto1", vv1, baseline["out"][f"L{layer_idx}_V_upto1"], atol=ACC_TOL) + assert_close( + f"L{layer_idx}_pos_upto1", + cache.get_position_id_upto_iter(layer_idx, 1, init_batch_size=args["B"]), + baseline["out"][f"L{layer_idx}_pos_upto1"], + atol=ACC_TOL, + ) + assert_close( + f"L{layer_idx}_valid_upto1", + cache.get_valid_mask_upto_iter(layer_idx, 1, init_batch_size=args["B"]), + baseline["out"][f"L{layer_idx}_valid_upto1"], + atol=ACC_TOL, + ) + assert_close( + f"L{layer_idx}_iteridx_upto1", + cache.get_cache_iter_index_upto_iter(layer_idx, 1), + baseline["out"][f"L{layer_idx}_iteridx_upto1"], + atol=ACC_TOL, + ) diff --git a/tests/test_input_updater.py b/tests/test_input_updater.py new file mode 100644 index 0000000..5e6da2e --- /dev/null +++ b/tests/test_input_updater.py @@ -0,0 +1,92 @@ +"""Acc + speed test for the input updater. + +Public TaH ships ``TrivialUpdater`` (one class, three modes selected by ctor +args). The canonical recipe uses ``topk=100, use_hidden_states=False`` (i.e. a +top-k softmax over logits, weighted-sum into embedding rows). The cleaned +version exposes a single module function with that single behaviour, so the +test pins exactly the math we care to preserve. +""" +from __future__ import annotations + +import pytest +import torch + +from tests._harness import ( + ACC_TOL, + DEVICE, + assert_close, + bench, + capture, + have_baseline, + load_baseline, +) + + +def _make_inputs(shapes, device): + B, T, V, H, TOPK = shapes["B"], shapes["T"], shapes["V"], shapes["H"], shapes["TOPK"] + g = torch.Generator(device=device).manual_seed(123) + logits = torch.randn(B, T, V, generator=g, device=device, dtype=torch.float32) + prev_inputs = torch.randn(B, T, H, generator=g, device=device, dtype=torch.float32) + embedding_weight = torch.randn(V, H, generator=g, device=device, dtype=torch.float32) + hidden_states = torch.randn(B, T, 4, H, generator=g, device=device, dtype=torch.float32) + return { + "logits": logits, + "prev_inputs": prev_inputs, + "embedding_weight": embedding_weight, + "hidden_states": hidden_states, + "topk": TOPK, + } + + +PUBLIC_RUNNER = """ +def run(payload): + from tah.model.input_updater import TrivialUpdater + upd = TrivialUpdater(use_hidden_states=False, topk=payload['topk']) + out = upd( + logits=payload['logits'], + prev_inputs=payload['prev_inputs'], + embedding_weight=payload['embedding_weight'], + hidden_states=payload['hidden_states'], + ) + return {'updated': out} +""" + + +@pytest.fixture +def baseline(shapes): + inputs = _make_inputs(shapes, "cpu") # baseline saved on CPU + name = "input_updater_topk" + if not have_baseline(name): + capture(name, PUBLIC_RUNNER, payload=inputs) + return load_baseline(name) + + +def test_input_updater_acc(baseline, shapes, device): + from tah.model.tah_model import topk_softmax_input_update + + inputs = {k: (v.to(device) if torch.is_tensor(v) else v) + for k, v in baseline["args"].items()} + out = topk_softmax_input_update( + logits=inputs["logits"], + embedding_weight=inputs["embedding_weight"], + topk=inputs["topk"], + ) + assert_close("input_updater.updated", out, baseline["out"]["updated"], atol=ACC_TOL) + + +def test_input_updater_speed(baseline, shapes, device): + from tah.model.tah_model import topk_softmax_input_update + + inputs = {k: (v.to(device) if torch.is_tensor(v) else v) + for k, v in baseline["args"].items()} + + def cleaned(): + return topk_softmax_input_update( + logits=inputs["logits"], + embedding_weight=inputs["embedding_weight"], + topk=inputs["topk"], + ) + + res = bench("input_updater_topk", cleaned) + # Sanity floor: inlined helper should always be <1ms on tiny shapes. + assert res["ms"] < 1.0 diff --git a/tests/test_iter_decider.py b/tests/test_iter_decider.py new file mode 100644 index 0000000..12c7581 --- /dev/null +++ b/tests/test_iter_decider.py @@ -0,0 +1,145 @@ +"""Acc + speed tests for the iter deciders. + +Two deciders are kept in the cleaned package — both are used in the canonical +training/eval recipes: + +* ``IterLabelDecider`` — step-1 SFT: continue iff oracle iter_count_labels say so. +* ``MLPIterDecider`` — step-2 SFT + eval: learned classifier over hidden + top-k logits. + +Public TaH also shipped ``TrivialIterDecider``, ``AlwaysWrapperIterDecider``, +and ``OracleDynamicIterDecider`` — none are used by the released checkpoint and +are removed in tah-release. +""" +from __future__ import annotations + +import pytest +import torch + +from tests._harness import ( + ACC_TOL, + assert_close, + bench, + capture, + have_baseline, + load_baseline, +) + + +# ---------------------------------------------------------------- IterLabelDecider + +def _label_inputs(shapes, device): + B, T, V = shapes["B"], shapes["T"], shapes["V"] + g = torch.Generator(device=device).manual_seed(126) + logits = torch.randn(B, T, V, generator=g, device=device, dtype=torch.float32) + labels = torch.randint(0, 3, (B, T), generator=g, device=device, dtype=torch.long) + labels[0, 0] = -100 # padding-like + return {"logits": logits, "labels": labels, "max_iter": shapes["MAX_ITER"]} + + +LABEL_RUNNER = """ +def run(payload): + from tah.model.iter_decider import IterLabelDecider + dec = IterLabelDecider(max_iter=payload['max_iter']) + out0 = dec(logits=payload['logits'], iter_depth=0, iter_count_labels=payload['labels']) + out1 = dec(logits=payload['logits'], iter_depth=1, iter_count_labels=payload['labels']) + return { + 'd0': out0[0], 'l0': out0[1], + 'd1': out1[0], 'l1': out1[1], + } +""" + + +@pytest.fixture +def label_baseline(shapes): + args = _label_inputs(shapes, "cpu") + name = "iter_decider_label" + if not have_baseline(name): + capture(name, LABEL_RUNNER, payload=args) + return load_baseline(name) + + +def test_iter_label_decider_acc(label_baseline, device): + from tah.model.iter_decider import IterLabelDecider + + args = {k: (v.to(device) if torch.is_tensor(v) else v) + for k, v in label_baseline["args"].items()} + dec = IterLabelDecider(max_iter=args["max_iter"]).to(device) + out0 = dec(logits=args["logits"], iter_depth=0, iter_count_labels=args["labels"]) + out1 = dec(logits=args["logits"], iter_depth=1, iter_count_labels=args["labels"]) + assert_close("d0", out0[0], label_baseline["out"]["d0"]) + assert_close("l0", out0[1], label_baseline["out"]["l0"]) + assert_close("d1", out1[0], label_baseline["out"]["d1"]) + assert_close("l1", out1[1], label_baseline["out"]["l1"]) + + +# ---------------------------------------------------------------- MLPIterDecider + +MLP_KWARGS = dict( + topk=8, + hidden_states_size=32, + hidden_states_layer_nums=[0, 1, 2, 3], + hidden_dims=[16, 16, 16], + expansion_factor=2, + dropout_rate=0.0, + normalize_input=False, + threshold=0.5, + max_iter=2, + dtype=torch.float32, +) + + +def _mlp_inputs(shapes, device): + B, T, V, H, L = shapes["B"], shapes["T"], shapes["V"], shapes["H"], shapes["L"] + assert V >= MLP_KWARGS["topk"], "vocab too small for topk" + g = torch.Generator(device=device).manual_seed(127) + logits = torch.randn(B, T, V, generator=g, device=device, dtype=torch.float32) + hidden = torch.randn(B, T, L, MLP_KWARGS["hidden_states_size"], generator=g, device=device, dtype=torch.float32) + return {"logits": logits, "hidden": hidden} + + +MLP_RUNNER = ( + "def run(payload):\n" + " import torch\n" + " from tah.model.iter_decider import MLPIterDecider\n" + " torch.manual_seed(7)\n" + f" dec = MLPIterDecider(**{MLP_KWARGS!r})\n" + " state = {k: v.detach().clone() for k, v in dec.state_dict().items()}\n" + " out = dec(logits=payload['logits'], iter_depth=0, all_hidden_states=payload['hidden'])\n" + " return {'state': state, 'decision': out[0], 'logits': out[1]}\n" +) + + +@pytest.fixture +def mlp_baseline(shapes): + args = _mlp_inputs(shapes, "cpu") + name = "iter_decider_mlp" + if not have_baseline(name): + capture(name, MLP_RUNNER, payload=args) + return load_baseline(name) + + +def test_mlp_iter_decider_acc(mlp_baseline, device): + from tah.model.iter_decider import MLPIterDecider + + torch.manual_seed(7) + dec = MLPIterDecider(**MLP_KWARGS).to(device) + # Sync weights with the baseline init exactly (avoids any seed/init drift). + dec.load_state_dict({k: v.to(device) for k, v in mlp_baseline["out"]["state"].items()}) + + args = {k: v.to(device) for k, v in mlp_baseline["args"].items()} + out = dec(logits=args["logits"], iter_depth=0, all_hidden_states=args["hidden"]) + assert_close("decision", out[0], mlp_baseline["out"]["decision"]) + assert_close("logits", out[1], mlp_baseline["out"]["logits"], atol=1e-5) + + +def test_mlp_iter_decider_speed(mlp_baseline, device): + from tah.model.iter_decider import MLPIterDecider as Cleaned + + torch.manual_seed(7) + dec = Cleaned(**MLP_KWARGS).to(device) + args = {k: v.to(device) for k, v in mlp_baseline["args"].items()} + + def run(): + return dec(logits=args["logits"], iter_depth=0, all_hidden_states=args["hidden"]) + + bench("mlp_iter_decider", run) diff --git a/tests/test_iter_label.py b/tests/test_iter_label.py new file mode 100644 index 0000000..617231f --- /dev/null +++ b/tests/test_iter_label.py @@ -0,0 +1,114 @@ +"""Acc test for the iter-label generator. + +Public TaH had three generators (``Fixed``, ``DynamicMismatch``, ``Max``); only +``FixedIterLabelGenerator`` is used in the canonical recipes. Its only state +is a dense ``(B, T)`` accumulator updated via per-iteration max-merge of the +active slice. + +The cleaned wrapper inlines this directly: a single ``full_labels`` tensor in +the forward, two lines to merge the active proposal. The test pins the merge +math against the public class. +""" +from __future__ import annotations + +import pytest +import torch + +from tests._harness import ( + ACC_TOL, + assert_close, + capture, + have_baseline, + load_baseline, +) + + +def _make_inputs(shapes, device): + B, T = shapes["B"], shapes["T"] + g = torch.Generator(device=device).manual_seed(125) + iter1_labels = torch.randint(0, 3, (B, T), generator=g, device=device, dtype=torch.long) + iter2_labels = torch.randint(0, 3, (B, T), generator=g, device=device, dtype=torch.long) + # Mask: drop a couple of positions at iter2 (shrink active set) + mask = torch.ones(B, T, dtype=torch.bool, device=device) + mask[0, -1] = False + mask[1, 0] = False + return {"iter1_labels": iter1_labels, "iter2_labels": iter2_labels, "mask": mask} + + +PUBLIC_RUNNER = """ +def run(payload): + from tah.model.iter_label import FixedIterLabelGenerator + gen = FixedIterLabelGenerator() + B, T = payload['iter1_labels'].shape + gen.prepare(B, T, payload['iter1_labels'].device, torch.float32) + + # iter 1 — all positions active + active_full = torch.ones(B, T, dtype=torch.bool, device=payload['iter1_labels'].device) + gen.intra_iter_labels( + active_iter_count_labels=payload['iter1_labels'], + current_iter_mask=active_full, + ) + # iter 2 — gather active slice using the user's mask + mask = payload['mask'] + active_per_seq = mask.sum(1) + max_len = int(active_per_seq.max()) + SENTINEL = T + base_idx = torch.arange(T, device=mask.device).expand(B, T).masked_fill(~mask, SENTINEL) + sorted_idx, _ = torch.sort(base_idx, dim=1, stable=True) + gather_idx = sorted_idx[:, :max_len].clamp(max=T - 1) + pad_mask = sorted_idx[:, :max_len].eq(SENTINEL) + active_slice = torch.gather(payload['iter2_labels'], 1, gather_idx).masked_fill(pad_mask, -100) + gen.intra_iter_labels(active_iter_count_labels=active_slice, current_iter_mask=mask) + full = gen.finalize() + return {'full_labels': full, 'active_slice_iter2': active_slice} +""" + + +@pytest.fixture +def baseline(shapes): + inputs = _make_inputs(shapes, "cpu") + name = "iter_label_fixed" + if not have_baseline(name): + capture(name, PUBLIC_RUNNER, payload=inputs) + return load_baseline(name) + + +def test_iter_label_inline(baseline, device): + """The cleaned wrapper inlines: max-merge active labels into the dense view.""" + args = {k: v.to(device) for k, v in baseline["args"].items()} + iter1_labels = args["iter1_labels"] + iter2_labels = args["iter2_labels"] + mask = args["mask"] + B, T = iter1_labels.shape + + # The inlined behaviour the wrapper will use: + full_labels = torch.zeros(B, T, dtype=torch.long, device=device) + + # iter 1 — fully active: merge by max + proposal = iter1_labels.clone() + proposal[iter1_labels == -100] = 0 + full_labels = torch.maximum(full_labels, proposal) + + # iter 2 — partial active: scatter active proposal into dense, then max-merge + active_per_seq = mask.sum(1) + max_len = int(active_per_seq.max()) + SENTINEL = T + base_idx = torch.arange(T, device=device).expand(B, T).masked_fill(~mask, SENTINEL) + sorted_idx, _ = torch.sort(base_idx, dim=1, stable=True) + gather_idx = sorted_idx[:, :max_len].clamp(max=T - 1) + pad_mask = sorted_idx[:, :max_len].eq(SENTINEL) + active_slice = torch.gather(iter2_labels, 1, gather_idx).masked_fill(pad_mask, -100) + + # Scatter back: build a dense tmp from active slice + mask + tmp = torch.zeros_like(full_labels) + valid = (active_slice != -100) + proposal_dense = torch.zeros_like(active_slice) + proposal_dense[valid] = active_slice[valid] + for b in range(B): + n = int(active_per_seq[b].item()) + if n: + tmp[b, mask[b]] = proposal_dense[b, :n] + full_labels = torch.maximum(full_labels, tmp) + + assert_close("full_labels", full_labels, baseline["out"]["full_labels"], atol=ACC_TOL) + assert_close("active_slice_iter2", active_slice, baseline["out"]["active_slice_iter2"], atol=ACC_TOL) diff --git a/tests/test_jobs_runner.py b/tests/test_jobs_runner.py new file mode 100644 index 0000000..0672ce6 --- /dev/null +++ b/tests/test_jobs_runner.py @@ -0,0 +1,143 @@ +"""Unit test for ``tah.evaluate.jobs.run_single_job`` with a stubbed backend. + +Validates the per-job orchestration (output dir layout, prompt assembly, +score-and-save loop, stats write) without loading a real model. Real +end-to-end coverage of the per-backend setup lives in the live +``test_released_checkpoint`` smoke; this test exercises the surrounding +glue, which is the part most commonly broken by surgical refactors. +""" +from __future__ import annotations + +import csv +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + + +@pytest.fixture +def tokenizer_stub(): + """Bare-minimum tokenizer interface that ``run_single_job`` touches. + + * ``apply_chat_template`` — returns the message content (so ``_make_prompt`` + gets a deterministic string). + * ``encode`` — counts characters; good enough for the + input/output token-count fields in the per-row CSV. + * ``pad_token`` — set, so the constructor's ``if pad_token is None`` + branch isn't hit. + """ + + class T: + pad_token = "" + eos_token = "" + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False, **kwargs): + return f"PROMPT::{messages[0]['content']}" + + def encode(self, text, **kwargs): + return list(text) # 1 token per character + + return T() + + +def _fake_backend_factory(echo_prefix: str = "OUT::"): + """Returns a setup_backend stand-in that produces a tokenizer-free model + + an inference fn whose outputs are a deterministic function of the input.""" + + def _setup(backend, config, model_path, tokenizer, tp_size=1): + del backend, config, model_path, tokenizer, tp_size # unused + + def infer(prompts): + return [(f"{echo_prefix}{p} \\boxed{{42}}", 0.001) for p in prompts] + + return object(), infer + + return _setup + + +def _fake_score_one(output_text: str, correct: str, dataset_name: str, *, is_code: bool): + """Mark "boxed{N}" as correct iff N == correct, else wrong.""" + del dataset_name, is_code + if "\\boxed{" in output_text: + ans = output_text.split("\\boxed{")[1].split("}")[0] + return ans, True, ans == correct + return "", False, False + + +def test_run_single_job_writes_expected_files(tmp_path, tokenizer_stub): + """End-to-end: 2 problems × 2 samples = 4 inferences with the fake backend. + Verify CSV row count, per-sample json files, stats summary.""" + from tah.evaluate import jobs + + problems_data = [ + {"id": "math_1", "question": "What's 41+1?", "answer": "42", "_original_id": "1"}, + {"id": "math_2", "question": "What's 100-58?", "answer": "42", "_original_id": "2"}, + ] + field_mapping = { + "id_field": "id", "question_field": "question", "answer_field": "answer", + "answer_type": "boxed", "prompt_template": "{question}", + } + config = { + "repeat_size": 2, "batch_size": 2, "temperature": 0.0, + "max_new_tokens": 16, "top_p": 1.0, + } + + with patch.object(jobs, "setup_backend", _fake_backend_factory()), \ + patch.object(jobs, "_score_one", _fake_score_one), \ + patch.object(jobs, "AutoTokenizer", create=True) if False else patch( + "transformers.AutoTokenizer.from_pretrained", return_value=tokenizer_stub, + ), \ + patch.object(jobs, "cleanup", lambda *a, **kw: None): + jobs.run_single_job( + config=config, combined_dataset_name="math", + output_dir=str(tmp_path), timestamp="20260101", + model_path="dummy", job_id=0, job_nums=1, start_idx=0, end_idx=2, + tp_size=1, backend="hf", data_range=None, + problems_data=problems_data, field_mapping=field_mapping, + unified_code_solutions_file=None, + ) + + job_dir = tmp_path / "math_hf" / "20260101" / "job_0" + assert job_dir.exists(), f"job dir not created: {sorted(tmp_path.rglob('*'))}" + + # detailed_results.csv: header + 4 rows (2 problems × repeat 2). + rows = list(csv.DictReader(open(job_dir / "detailed_results.csv"))) + assert len(rows) == 4, rows + assert all(r["is_correct"] == "True" for r in rows), rows + assert sorted(set(r["problem_id"] for r in rows)) == ["math_1", "math_2"] + assert sorted(set(r["sample_idx"] for r in rows)) == ["0", "1"] + + # Per-sample JSON files written under details//sample_.json. + for pid in ("math_1", "math_2"): + d = job_dir / "details" / pid + files = sorted(d.glob("sample_*.json")) + assert len(files) == 2, files + sample = json.load(open(files[0])) + assert sample["correct_answer"] == "42" + assert sample["predicted_answer"] == "42" + assert sample["is_correct"] + + # evaluation_stats.csv: per-problem rows + a "Total Accuracy" row. + stat_rows = list(csv.reader(open(job_dir / "evaluation_stats.csv"))) + header = stat_rows[0] + body = [r for r in stat_rows[1:] if r and r[0] not in ("", "Total Accuracy")] + total = next(r for r in stat_rows if r and r[0] == "Total Accuracy") + assert "accuracy" in header + assert len(body) == 2 # one row per problem + assert total[1] == "1.000" # everyone correct + assert int(total[3]) == 4 # total samples + + +def test_build_prompts_repeat_and_mapping(tokenizer_stub): + from tah.evaluate.jobs import _build_prompts + + problems = [ + {"problem_text": "Q1", "problem_id": "p1"}, + {"problem_text": "Q2", "problem_id": "p2"}, + ] + prompts, mapping = _build_prompts(problems, tokenizer_stub, repeat_size=3, is_code=False) + assert len(prompts) == 6 + assert prompts[0] == prompts[1] == prompts[2] == "PROMPT::Q1" + assert prompts[3] == prompts[4] == prompts[5] == "PROMPT::Q2" + assert mapping == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)] diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100644 index 0000000..f7ba848 --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,168 @@ +"""Acc tests for the kept loss functions. + +* ``NextTokenPredLoss`` — step-1 SFT objective; standard cross-entropy with an + optional hard/easy reweight knob. +* ``IterDeciderLoss`` — step-2 SFT objective; per-iteration BCE on the iter + decider's continue logits, supervised by ``iter_count_labels``. + +Public TaH also shipped ``ConsistencyLoss`` (unused) which is removed. +""" +from __future__ import annotations + +import pytest +import torch + +from tests._harness import ( + ACC_TOL, + assert_close, + capture, + have_baseline, + load_baseline, +) + + +# ---------------------------------------------------------------- NextTokenPredLoss + +def _ce_inputs(shapes, device): + B, T, V = shapes["B"], shapes["T"], shapes["V"] + g = torch.Generator(device=device).manual_seed(128) + logits = torch.randn(B, T, V, generator=g, device=device, dtype=torch.float32) + labels = torch.randint(0, V, (B, T), generator=g, device=device, dtype=torch.long) + labels[0, 0] = -100 + iter_count = torch.randint(0, 3, (B, T), generator=g, device=device, dtype=torch.long) + return {"logits": logits, "labels": labels, "iter_count": iter_count} + + +CE_RUNNER = """ +def run(payload): + from tah.model.loss import NextTokenPredLoss + loss = NextTokenPredLoss() + loss.prepare_loss(payload['logits'].shape[0], payload['logits'].shape[1], + payload['logits'].device, torch.float32) + val = loss.final_loss_func( + logits=payload['logits'], + labels_shifted=payload['labels'], + iter_count=payload['iter_count'], + training=False, + num_items_in_batch=int((payload['labels'] != -100).sum().item()), + ) + return {'loss': val.detach()} +""" + + +@pytest.fixture +def ce_baseline(shapes): + args = _ce_inputs(shapes, "cpu") + name = "loss_next_token" + if not have_baseline(name): + capture(name, CE_RUNNER, payload=args) + return load_baseline(name) + + +def test_next_token_pred_loss_acc(ce_baseline, device): + from tah.model.loss import NextTokenPredLoss + + args = {k: v.to(device) for k, v in ce_baseline["args"].items()} + loss = NextTokenPredLoss() + loss.prepare_loss(args["logits"].shape[0], args["logits"].shape[1], + args["logits"].device, torch.float32) + val = loss.final_loss_func( + logits=args["logits"], + labels_shifted=args["labels"], + iter_count=args["iter_count"], + training=False, + num_items_in_batch=int((args["labels"] != -100).sum().item()), + ) + assert_close("loss", val, ce_baseline["out"]["loss"], atol=ACC_TOL) + + +# ---------------------------------------------------------------- IterDeciderLoss + +def _bce_inputs(shapes, device): + B, T = shapes["B"], shapes["T"] + g = torch.Generator(device=device).manual_seed(129) + iter_count_labels = torch.randint(0, 3, (B, T), generator=g, device=device, dtype=torch.long) + iter_count_labels[0, 0] = -100 + valid_mask = torch.ones(B, T, dtype=torch.long, device=device) + valid_mask[1, -1] = 0 # one padded position — the wrapper masks this out before + # calling the iter_decider, so continue_logits has shape + # (sum(valid_mask),), not (B*T,). + n_valid = int(valid_mask.sum().item()) + continue_logits = torch.randn(n_valid, generator=g, device=device, dtype=torch.float32) + return { + "continue_logits": continue_logits, + "iter_count_labels": iter_count_labels, + "valid_mask": valid_mask, + "max_iter": shapes["MAX_ITER"], + } + + +BCE_RUNNER = """ +def run(payload): + from tah.model.loss import IterDeciderLoss + B, T = payload['iter_count_labels'].shape + loss = IterDeciderLoss(pos_weight=2.0, skip_last_iter=True, max_iter=payload['max_iter']) + loss.prepare_loss(B, T, payload['iter_count_labels'].device, torch.float32) + + # current_iter_mask: all positions active + current = torch.ones(B, T, dtype=torch.bool, device=payload['iter_count_labels'].device) + val = loss.intra_iter_loss_func( + active_logits=None, + current_iter_mask=current, + active_labels_shifted=None, + active_valid_continue_logits=payload['continue_logits'], + active_valid_mask=payload['valid_mask'], + iter_depth=1, + active_iter_count_labels=payload['iter_count_labels'], + iter_decider_threshold=0.5, + ) + final = loss.final_loss_func( + logits=payload['continue_logits'].view(B, T)[:, :1] if False else payload['continue_logits'].new_zeros(B, T, 1), + labels_shifted=payload['iter_count_labels'], + iter_count=torch.ones_like(payload['iter_count_labels']), + iter_count_labels=payload['iter_count_labels'], + training=True, + num_items_in_batch=int((payload['iter_count_labels'] != -100).sum().item()), + ) + return {'final': final.detach()} +""" + + +@pytest.fixture +def bce_baseline(shapes): + args = _bce_inputs(shapes, "cpu") + name = "loss_iter_decider" + if not have_baseline(name): + capture(name, BCE_RUNNER, payload=args) + return load_baseline(name) + + +def test_iter_decider_loss_acc(bce_baseline, device): + from tah.model.loss import IterDeciderLoss + + args = {k: (v.to(device) if torch.is_tensor(v) else v) + for k, v in bce_baseline["args"].items()} + B, T = args["iter_count_labels"].shape + loss = IterDeciderLoss(pos_weight=2.0, skip_last_iter=True, max_iter=args["max_iter"]) + loss.prepare_loss(B, T, args["iter_count_labels"].device, torch.float32) + + current = torch.ones(B, T, dtype=torch.bool, device=device) + loss.intra_iter_loss_func( + active_logits=None, + current_iter_mask=current, + active_labels_shifted=None, + active_valid_continue_logits=args["continue_logits"], + active_valid_mask=args["valid_mask"], + iter_depth=1, + active_iter_count_labels=args["iter_count_labels"], + iter_decider_threshold=0.5, + ) + final = loss.final_loss_func( + logits=args["continue_logits"].new_zeros(B, T, 1), + labels_shifted=args["iter_count_labels"], + iter_count=torch.ones_like(args["iter_count_labels"]), + iter_count_labels=args["iter_count_labels"], + training=True, + num_items_in_batch=int((args["iter_count_labels"] != -100).sum().item()), + ) + assert_close("final", final, bce_baseline["out"]["final"], atol=ACC_TOL) diff --git a/tests/test_output_updater.py b/tests/test_output_updater.py new file mode 100644 index 0000000..2410e2f --- /dev/null +++ b/tests/test_output_updater.py @@ -0,0 +1,70 @@ +"""Acc + speed test for the output updater. + +Public TaH ships ``AdditiveLogitsUpdater`` (residual accumulation across +iterations) and ``NoneUpdater`` (pass-through). Only the additive form is used +in the canonical recipes, so the cleaned package collapses both to a single +``additive_logits_update`` function. +""" +from __future__ import annotations + +import pytest +import torch + +from tests._harness import ( + ACC_TOL, + assert_close, + bench, + capture, + have_baseline, + load_baseline, +) + + +def _make_inputs(shapes, device): + B, T, V = shapes["B"], shapes["T"], shapes["V"] + g = torch.Generator(device=device).manual_seed(124) + iter0 = torch.randn(B, T, V, generator=g, device=device, dtype=torch.float32) + iter1 = torch.randn(B, T, V, generator=g, device=device, dtype=torch.float32) + return {"iter0": iter0, "iter1": iter1} + + +PUBLIC_RUNNER = """ +def run(payload): + from tah.model.output_updater import AdditiveLogitsUpdater + upd = AdditiveLogitsUpdater() + out0 = upd(logits=payload['iter0'], prev_logits=None, iter_depth=0) + out1 = upd(logits=payload['iter1'], prev_logits=out0, iter_depth=1) + return {'out0': out0, 'out1': out1} +""" + + +@pytest.fixture +def baseline(shapes): + inputs = _make_inputs(shapes, "cpu") + name = "output_updater_additive" + if not have_baseline(name): + capture(name, PUBLIC_RUNNER, payload=inputs) + return load_baseline(name) + + +def test_output_updater_acc(baseline, device): + from tah.model.tah_model import additive_logits_update + + args = {k: v.to(device) for k, v in baseline["args"].items()} + out0 = additive_logits_update(args["iter0"], prev_logits=None) + out1 = additive_logits_update(args["iter1"], prev_logits=out0) + assert_close("out0", out0, baseline["out"]["out0"], atol=ACC_TOL) + assert_close("out1", out1, baseline["out"]["out1"], atol=ACC_TOL) + + +def test_output_updater_speed(baseline, device): + from tah.model.tah_model import additive_logits_update + + args = {k: v.to(device) for k, v in baseline["args"].items()} + + def cleaned(): + a = additive_logits_update(args["iter0"], None) + return additive_logits_update(args["iter1"], a) + + res = bench("output_updater_additive", cleaned) + assert res["ms"] < 1.0 diff --git a/tests/test_released_checkpoint.py b/tests/test_released_checkpoint.py new file mode 100644 index 0000000..be34074 --- /dev/null +++ b/tests/test_released_checkpoint.py @@ -0,0 +1,121 @@ +"""End-to-end smoke test against the real released TaH-plus-1.7B checkpoint. + +Marked ``slow``: requires ~3 GB checkpoint cached + a CUDA device with ~5 GB +free. Skipped when CUDA is unavailable. Validates the load → forward → +generate → save → reload chain that downstream consumers depend on. +""" +from __future__ import annotations + +import json +import os +import tempfile + +import pytest +import torch + + +CHECKPOINT = os.environ.get("TAH_CHECKPOINT", "nics-efc/TaH-plus-1.7B") + + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="released checkpoint smoke test needs CUDA" +) + + +@pytest.fixture(scope="module") +def model_and_tok(): + from transformers import AutoTokenizer + from tah.model.tah_model import TaHForCausalLM + + tok = AutoTokenizer.from_pretrained(CHECKPOINT) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + model = TaHForCausalLM.from_pretrained( + CHECKPOINT, + torch_dtype=torch.bfloat16, + device_map="cuda:0", + attn_implementation="sdpa", + ) + return model, tok + + +def test_forward_basic(model_and_tok): + """A single forward should populate logits + iter_count without raising.""" + from tah.model.causal_cache import TaHCache + + model, tok = model_and_tok + text = "What is 2 + 2? Answer with just the number." + inputs = tok(text, return_tensors="pt").to(model.device) + with torch.no_grad(): + out = model(**inputs, use_cache=False) + assert out.logits.shape[0] == 1 + assert out.logits.shape[1] == inputs["input_ids"].shape[1] + assert out.iter_count.shape == inputs["input_ids"].shape + # iter_count is in [1, max_iter]; max_iter is 2 for this checkpoint. + assert int(out.iter_count.min()) >= 1 + assert int(out.iter_count.max()) <= 2 + + +def test_generate_short(model_and_tok): + """Greedy 16-token generation must produce non-empty text without OOMing.""" + from tah.model.utils import TaHForCasualLM_generate + + model, tok = model_and_tok + msg = [{"role": "user", "content": "Compute 17 + 25. Reply with a single integer."}] + text = tok.apply_chat_template(msg, tokenize=False, add_generation_prompt=True, enable_thinking=False) + inputs = tok(text, return_tensors="pt", padding=True, padding_side="left").to(model.device) + + output_tokens, output_texts = TaHForCasualLM_generate( + tah_model=model, + tokenizer=tok, + model_inputs=dict(inputs), + max_new_tokens=16, + do_sample=False, + verbose=False, + ) + assert len(output_texts) == 1 + assert len(output_tokens[0]) > 0 + assert isinstance(output_texts[0], str) + + +def test_save_load_roundtrip_real(model_and_tok): + """Save the loaded released model to a temp dir, reload it, forward must match.""" + from tah.model.tah_model import TaHForCausalLM + + model, tok = model_and_tok + inputs = tok("The quick brown fox", return_tensors="pt").to(model.device) + with torch.no_grad(): + ref_out = model(**inputs, use_cache=False) + + with tempfile.TemporaryDirectory(prefix="tah_release_smoke_") as tmp: + model.save_pretrained(tmp) + + # Layout sanity: tah_config + iter_decider + lora + base must be on disk. + assert os.path.isfile(os.path.join(tmp, "tah_config.json")) + assert os.path.isfile(os.path.join(tmp, "iter_decider.bin")) + assert os.path.isdir(os.path.join(tmp, "lora")) + assert os.path.isfile(os.path.join(tmp, "lora", "adapter_config.json")) + # Base safetensors (sharded or single-file). + st_files = [f for f in os.listdir(tmp) if f.startswith("model") and f.endswith((".safetensors", ".json"))] + assert any("safetensors" in f for f in st_files), f"no base safetensors in {tmp}" + + # Read back tah_config.json and verify shape. + with open(os.path.join(tmp, "tah_config.json")) as f: + cfg_dict = json.load(f) + assert cfg_dict["iter_decider"] == "MLPIterDecider" + assert cfg_dict["adapter"] == "lora" + assert cfg_dict["max_iter"] == 2 + + reloaded = TaHForCausalLM.from_pretrained( + tmp, torch_dtype=torch.bfloat16, device_map="cuda:0", attn_implementation="sdpa", + ) + with torch.no_grad(): + new_out = reloaded(**inputs, use_cache=False) + + # bf16 + scattered ops give ~1e-2 drift; iter_count must match exactly. + max_logit_diff = (new_out.logits.float() - ref_out.logits.float()).abs().max().item() + assert max_logit_diff < 5e-2, f"reload drift {max_logit_diff:.4e}" + assert torch.equal(new_out.iter_count, ref_out.iter_count), "iter_count drift after reload" + + del reloaded + torch.cuda.empty_cache() diff --git a/tests/test_save_load.py b/tests/test_save_load.py new file mode 100644 index 0000000..1d18466 --- /dev/null +++ b/tests/test_save_load.py @@ -0,0 +1,125 @@ +"""Save → load roundtrip for ``TaHForCausalLM``. + +Pins the on-disk layout that downstream consumers depend on: + +* ``tah_config.json`` — config (component names + kwargs). +* ``lora/`` — PEFT adapter directory (``adapter_model.safetensors``, + ``adapter_config.json``). +* ``iter_decider.bin`` — pickled state dict + class name + init args. +* ``model.safetensors`` — base-model weights with cleaned state-dict keys + (no ``.base_layer`` PEFT prefix; no ``lora_*`` weights). + +A reload of the saved directory must produce a forward pass that matches the +original within fp tolerance. +""" +from __future__ import annotations + +import os +import json +import shutil +import tempfile + +import pytest +import torch + +from tests._harness import assert_close + + +BASE_MODEL = os.environ.get("TAH_TEST_BASE_MODEL", "Qwen/Qwen3-0.6B") + + +def _build_inputs(device): + g = torch.Generator(device=device).manual_seed(151) + B, T = 1, 8 + input_ids = torch.randint(10, 1000, (B, T), generator=g, device=device, dtype=torch.long) + attention_mask = torch.ones(B, T, dtype=torch.long, device=device) + return input_ids, attention_mask + + +@pytest.fixture(scope="module") +def device(): + return "cuda" if torch.cuda.is_available() else "cpu" + + +@pytest.fixture(scope="module") +def initial_model(device): + from transformers import AutoModelForCausalLM + from tah.model.tah_config import TaHConfig + from tah.model.tah_model import TaHForCausalLM + + cfg = TaHConfig( + embedding_key="model.embed_tokens", + max_iter=2, + input_updater_kwargs={"topk": 8}, + iter_decider="MLPIterDecider", + iter_decider_kwargs={ + "topk": 8, + "hidden_states_size": 1024, # matches Qwen3-0.6B hidden_size + "hidden_states_layer_nums": [0, 4, 8, 12], + "hidden_dims": [16, 16, 16, 16, 16, 16], + "expansion_factor": 2, + "dropout_rate": 0.0, + "normalize_input": False, + "threshold": 0.5, + "max_iter": 2, + "dtype": torch.float32, + }, + eval_iter_decider=None, + adapter="lora", + adapter_kwargs={ + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.0, + "target_modules": "all-linear", + "bias": "none", + }, + train_loss="NextTokenPredLoss", + train_loss_kwargs={}, + eval_loss="NextTokenPredLoss", + eval_loss_kwargs={}, + ) + torch.manual_seed(13) + base = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, torch_dtype=torch.float32, attn_implementation="sdpa" + ).to(device).eval() + torch.manual_seed(13) + model = TaHForCausalLM(base_model=base, config=cfg).to(device).eval() + return model, cfg + + +def test_save_load_roundtrip(initial_model, device): + from tah.model.tah_model import TaHForCausalLM + + model, _cfg = initial_model + input_ids, attn = _build_inputs(device) + with torch.no_grad(): + ref_out = model(input_ids=input_ids, attention_mask=attn, use_cache=False) + + with tempfile.TemporaryDirectory(prefix="tah_release_save_") as tmp: + model.save_pretrained(tmp) + # On-disk layout assertions + assert os.path.isfile(os.path.join(tmp, "tah_config.json")), "tah_config.json missing" + assert os.path.isfile(os.path.join(tmp, "iter_decider.bin")), "iter_decider.bin missing" + assert os.path.isdir(os.path.join(tmp, "lora")), "lora/ missing" + assert os.path.isfile(os.path.join(tmp, "lora", "adapter_config.json")), "adapter_config.json missing" + # Base model files + st = [f for f in os.listdir(tmp) if f.startswith("model") and (f.endswith(".safetensors") or f == "model.safetensors.index.json")] + assert st, f"no base-model safetensors in {tmp}: {os.listdir(tmp)}" + + # tah_config.json shape + with open(os.path.join(tmp, "tah_config.json")) as f: + cfg_json = json.load(f) + assert cfg_json["iter_decider"] == "MLPIterDecider" + assert cfg_json["adapter"] == "lora" + + # Reload + torch.manual_seed(0) + reloaded = TaHForCausalLM.from_pretrained( + tmp, torch_dtype=torch.float32, attn_implementation="sdpa" + ).to(device).eval() + + with torch.no_grad(): + new_out = reloaded(input_ids=input_ids, attention_mask=attn, use_cache=False) + + assert_close("logits", new_out.logits, ref_out.logits, atol=1e-4, rtol=1e-3) + assert_close("iter_count", new_out.iter_count, ref_out.iter_count, atol=0) diff --git a/tests/test_sft_smoke.py b/tests/test_sft_smoke.py new file mode 100644 index 0000000..8b4de96 --- /dev/null +++ b/tests/test_sft_smoke.py @@ -0,0 +1,175 @@ +"""Smoke tests for the SFT pipeline. + +Exercises ``CustomTaHDataCollator`` on a synthetic batch and runs 2 training +steps of ``CustomTaHTrainer`` against a 4-example synthetic dataset. Doesn't +load a real labelled dataset (avoids the multi-GB download); every other +production code path through ``tah/train`` and the wrapper's training-mode +forward IS exercised. +""" +from __future__ import annotations + +import os +import tempfile + +import pytest +import torch + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="SFT smoke needs CUDA") +BASE_MODEL = os.environ.get("TAH_TEST_BASE_MODEL", "Qwen/Qwen3-0.6B") + + +def test_data_collator_pads_iter_count_labels(): + """Pad iter_count_labels to the same length as input_ids, on the same side.""" + from transformers import AutoTokenizer + from tah.train import CustomTaHDataCollator + + tok = AutoTokenizer.from_pretrained(BASE_MODEL) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + tok.padding_side = "right" + + features = [ + {"input_ids": [1, 2, 3], "labels": [1, 2, 3], "iter_count_labels": [1, 2, 1]}, + {"input_ids": [1, 2, 3, 4, 5], "labels": [1, 2, 3, 4, 5], "iter_count_labels": [1, 1, 2, 2, 1]}, + ] + coll = CustomTaHDataCollator(tokenizer=tok, padding=True) + batch = coll(features) + assert batch["input_ids"].shape == batch["iter_count_labels"].shape, "iter_count_labels not aligned to input_ids" + # First row should be padded with -100 in iter_count_labels + assert int(batch["iter_count_labels"][0, -1].item()) == -100, "first row didn't pad with ignore index" + # Second row should be unchanged + assert int(batch["iter_count_labels"][1, 0].item()) == 1 + + +def test_trainer_runs_two_steps_on_synthetic_dataset(tmp_path): + """Build a 4-example synthetic dataset, run 2 SFT steps, verify a loss is logged.""" + from datasets import Dataset + from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments + from tah.model.tah_config import TaHConfig + from tah.model.tah_model import TaHForCausalLM + from tah.train import CustomTaHDataCollator, CustomTaHTrainer, LoggerCallback + + tok = AutoTokenizer.from_pretrained(BASE_MODEL) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + tok.padding_side = "right" + + rng = torch.Generator().manual_seed(311) + examples = [] + for _ in range(4): + T = int(torch.randint(8, 24, (1,), generator=rng).item()) + ids = torch.randint(10, 1000, (T,), generator=rng).tolist() + labels = list(ids) + labels[0] = -100 # mimic prompt-mask + iter_labels = torch.randint(1, 3, (T,), generator=rng).tolist() + examples.append({"input_ids": ids, "labels": labels, "iter_count_labels": iter_labels}) + ds = Dataset.from_list(examples) + + cfg = TaHConfig( + embedding_key="model.embed_tokens", + max_iter=2, + input_updater_kwargs={"topk": 8}, + iter_decider="IterLabelDecider", + iter_decider_kwargs={"max_iter": 2}, + eval_iter_decider=None, + adapter="lora", + adapter_kwargs={"r": 8, "lora_alpha": 16, "lora_dropout": 0.0, + "target_modules": "all-linear", "bias": "none"}, + train_loss="NextTokenPredLoss", + eval_loss="NextTokenPredLoss", + ) + base = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, torch_dtype=torch.float32, attn_implementation="sdpa", + ).to("cuda:0") + model = TaHForCausalLM(base_model=base, config=cfg).to("cuda:0") + + args = TrainingArguments( + output_dir=str(tmp_path), + per_device_train_batch_size=1, + gradient_accumulation_steps=1, + max_steps=2, + learning_rate=1e-5, + logging_steps=1, + save_strategy="no", + report_to="none", + remove_unused_columns=False, + bf16=False, + ) + trainer = CustomTaHTrainer( + model=model, args=args, train_dataset=ds, processing_class=tok, + ) + trainer.data_collator = CustomTaHDataCollator(tokenizer=tok, padding=True) + callback = LoggerCallback() + model.logger_callback = callback + trainer.callback_handler.callbacks.insert(0, callback) + + trainer.train() + + # Verify a loss appeared and is finite. + losses = [e.get("loss") for e in trainer.state.log_history if "loss" in e] + assert losses, f"no train loss in log_history: {trainer.state.log_history}" + assert all(loss is not None and not (loss != loss) and loss < float("inf") for loss in losses), losses + + +def test_trainer_save_load_checkpoint(tmp_path): + """After a step, _save must write a TaH-layout checkpoint we can reload.""" + from datasets import Dataset + from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments + from tah.model.tah_config import TaHConfig + from tah.model.tah_model import TaHForCausalLM + from tah.train import CustomTaHDataCollator, CustomTaHTrainer + + tok = AutoTokenizer.from_pretrained(BASE_MODEL) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + tok.padding_side = "right" + + cfg = TaHConfig( + embedding_key="model.embed_tokens", + max_iter=2, + input_updater_kwargs={"topk": 8}, + iter_decider="IterLabelDecider", + iter_decider_kwargs={"max_iter": 2}, + adapter="lora", + adapter_kwargs={"r": 4, "lora_alpha": 8, "lora_dropout": 0.0, + "target_modules": "all-linear", "bias": "none"}, + train_loss="NextTokenPredLoss", + eval_loss="NextTokenPredLoss", + ) + base = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, torch_dtype=torch.float32, attn_implementation="sdpa" + ).to("cuda:0") + model = TaHForCausalLM(base_model=base, config=cfg).to("cuda:0") + + ds = Dataset.from_list([ + {"input_ids": [10, 20, 30, 40], "labels": [-100, 20, 30, 40], "iter_count_labels": [1, 2, 1, 1]}, + ]) + args = TrainingArguments( + output_dir=str(tmp_path / "run"), + per_device_train_batch_size=1, max_steps=1, learning_rate=1e-5, + save_strategy="steps", save_steps=1, save_total_limit=1, + logging_steps=1, report_to="none", remove_unused_columns=False, bf16=False, + ) + trainer = CustomTaHTrainer( + model=model, args=args, train_dataset=ds, processing_class=tok, + ) + trainer.data_collator = CustomTaHDataCollator(tokenizer=tok, padding=True) + trainer.train() + + # find the checkpoint dir + ckpts = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint-")] + assert ckpts, f"no checkpoint dir under {args.output_dir}: {os.listdir(args.output_dir)}" + ckpt = os.path.join(args.output_dir, ckpts[0]) + assert os.path.isfile(os.path.join(ckpt, "tah_config.json")), "tah_config.json missing" + assert os.path.isfile(os.path.join(ckpt, "iter_decider.bin")), "iter_decider.bin missing" + assert os.path.isdir(os.path.join(ckpt, "lora")), "lora/ missing" + + # Reload and forward — proves the saved layout is valid. + reloaded = TaHForCausalLM.from_pretrained( + ckpt, torch_dtype=torch.float32, attn_implementation="sdpa" + ).to("cuda:0") + inputs = tok("hello world", return_tensors="pt").to("cuda:0") + with torch.no_grad(): + out = reloaded(**inputs, use_cache=False) + assert out.logits.shape[0] == 1 diff --git a/tests/test_wrapper_forward.py b/tests/test_wrapper_forward.py new file mode 100644 index 0000000..69f2fe2 --- /dev/null +++ b/tests/test_wrapper_forward.py @@ -0,0 +1,157 @@ +"""End-to-end acc test for ``TaHForCausalLM.forward``. + +Drops a small real Qwen3-0.6B-Base into the wrapper, runs one forward in +training mode (labels + iter_count_labels supplied), and asserts the cleaned +package produces the same loss / logits / iter_count tensors as public TaH. + +This test is the single most important regression gate — every refactor +inside the wrapper has to leave this green. + +Cost: ~3-5s per run on a B200 once Qwen3-0.6B is in HF cache. +""" +from __future__ import annotations + +import os + +import pytest +import torch + +from tests._harness import ( + assert_close, + capture, + have_baseline, + load_baseline, +) + + +BASE_MODEL = os.environ.get("TAH_TEST_BASE_MODEL", "Qwen/Qwen3-0.6B") +DTYPE_STR = "float32" # CPU-friendly + numerically tight + + +def _wrapper_inputs(device): + g = torch.Generator(device=device).manual_seed(131) + # Small batch to keep the test fast. + B, T = 2, 16 + input_ids = torch.randint(10, 1000, (B, T), generator=g, device=device, dtype=torch.long) + attention_mask = torch.ones(B, T, dtype=torch.long, device=device) + attention_mask[1, :2] = 0 # left pad on row 1 + labels = input_ids.clone() + labels[attention_mask == 0] = -100 + iter_count_labels = torch.randint(0, 3, (B, T), generator=g, device=device, dtype=torch.long) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "iter_count_labels": iter_count_labels, + } + + +CONFIG_DICT = { + "embedding_key": "model.embed_tokens", + "max_iter": 2, + "input_updater_kwargs": {"topk": 8}, + "iter_decider": "IterLabelDecider", + "iter_decider_kwargs": {"max_iter": 2}, + "eval_iter_decider": None, + "eval_iter_decider_kwargs": {}, + "adapter": "lora", + "adapter_kwargs": { + "r": 8, + "lora_alpha": 16, + "lora_dropout": 0.0, + "target_modules": "all-linear", + "bias": "none", + }, + "train_loss": "NextTokenPredLoss", + "train_loss_kwargs": {}, + "eval_loss": "NextTokenPredLoss", + "eval_loss_kwargs": {}, +} + + +# Public-TaH config: superset of the cleaned config dict above. Includes the +# inert fields (output_updater, iter_label_generator, iter_attention_mode, +# input_updater) that the public dataclass requires. Behavior must match the +# cleaned package since the dropped slots have only one used implementation. +_PUBLIC_CONFIG_DICT = { + **CONFIG_DICT, + "input_updater": "TrivialUpdater", + "output_updater": "AdditiveLogitsUpdater", + "output_updater_kwargs": {}, + "iter_label_generator": "FixedIterLabelGenerator", + "iter_label_generator_kwargs": {}, + "iter_attention_mode": "duo", +} + +WRAPPER_RUNNER = ( + "def run(payload):\n" + " import torch\n" + " from transformers import AutoModelForCausalLM\n" + " from tah.model.tah_config import TaHConfig\n" + " from tah.model.recurrent_transformer import TaHForCausalLM\n" + " device = payload['device']\n" + " torch.manual_seed(11)\n" + f" base = AutoModelForCausalLM.from_pretrained({BASE_MODEL!r}, torch_dtype=torch.{DTYPE_STR}, attn_implementation='sdpa')\n" + " base = base.to(device).eval()\n" + f" cfg = TaHConfig(**{_PUBLIC_CONFIG_DICT!r})\n" + " torch.manual_seed(11)\n" + " model = TaHForCausalLM(base_model=base, config=cfg).to(device).eval()\n" + " inputs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in payload.items()}\n" + " out = model(\n" + " input_ids=inputs['input_ids'],\n" + " attention_mask=inputs['attention_mask'],\n" + " labels=inputs['labels'],\n" + " iter_count_labels=inputs['iter_count_labels'],\n" + " use_cache=False,\n" + " )\n" + " return {\n" + " 'loss': out.loss.detach() if out.loss is not None else None,\n" + " 'logits': out.logits.detach(),\n" + " 'iter_count': out.iter_count,\n" + " }\n" +) + + +@pytest.fixture(scope="module") +def baseline(): + device = "cuda" if torch.cuda.is_available() else "cpu" + inputs = _wrapper_inputs(device) + inputs = {k: (v.cpu() if torch.is_tensor(v) else v) for k, v in inputs.items()} + inputs["device"] = device + name = "wrapper_forward_qwen3_0.6b" + if not have_baseline(name): + capture(name, WRAPPER_RUNNER, payload=inputs) + return load_baseline(name) + + +def test_wrapper_forward_acc(baseline, device): + from transformers import AutoModelForCausalLM + from tah.model.tah_config import TaHConfig + from tah.model.tah_model import TaHForCausalLM + + args = {k: (v.to(device) if torch.is_tensor(v) else v) + for k, v in baseline["args"].items()} + + torch.manual_seed(11) + base = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, torch_dtype=getattr(torch, DTYPE_STR), attn_implementation="sdpa" + ) + base = base.to(device).eval() + cfg = TaHConfig(**CONFIG_DICT) + torch.manual_seed(11) + model = TaHForCausalLM(base_model=base, config=cfg).to(device).eval() + out = model( + input_ids=args["input_ids"], + attention_mask=args["attention_mask"], + labels=args["labels"], + iter_count_labels=args["iter_count_labels"], + use_cache=False, + ) + + # Tolerances: cleaned version may reorder mathematically equivalent ops + # (e.g., gather/scatter) which can introduce ~1e-4 drift in fp32. Tighten + # later if specific paths warrant it. + if baseline["out"]["loss"] is not None: + assert_close("loss", out.loss, baseline["out"]["loss"], atol=1e-4, rtol=1e-3) + assert_close("logits", out.logits, baseline["out"]["logits"], atol=2e-3, rtol=1e-3) + assert_close("iter_count", out.iter_count, baseline["out"]["iter_count"], atol=0) From 88badeb0a41a27b6cb5fc751beb7c82017a6d958 Mon Sep 17 00:00:00 2001 From: Tianyu Fu Date: Sat, 25 Apr 2026 23:15:34 -0400 Subject: [PATCH 2/5] [cleanup] simplify repo and README MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - README: trim Future Work, fix bibtex arxiv ID, dedupe step headers, compress training note, replace bench table with pointer to tests/README - CLAUDE.md: drop bash/ and eval_unified rows from layout tree, point to tests/README for snapshot harness details - bash/: delete redundant shell wrappers (commands already in README, and hard-coded HF mirror was region-specific) - tah/evaluate/eval_unified.py: delete backwards-compat shim (no callers; same names already exported through tah/evaluate/__init__.py) - pyproject.toml: drop unused torchvision dep - .gitignore: add .pytest_cache/, __pycache__/, .coverage - tests/README.md: new — run instructions, _harness.py baseline-snapshot system, benchmark scripts, component-baseline timings (moved from project README) - script/recipes/README.md: new — directory layout, step1/step2/eval semantics, supported model sizes (Qwen3-0.6B / 1.7B) Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 3 ++ CLAUDE.md | 9 ++---- README.md | 34 +++------------------ bash/eval_tah.sh | 12 -------- bash/pre_data.sh | 15 ---------- bash/sft_tah.sh | 19 ------------ pyproject.toml | 1 - script/recipes/README.md | 48 ++++++++++++++++++++++++++++++ tah/evaluate/__init__.py | 1 - tah/evaluate/eval_unified.py | 12 -------- tests/README.md | 57 ++++++++++++++++++++++++++++++++++++ 11 files changed, 114 insertions(+), 97 deletions(-) delete mode 100644 bash/eval_tah.sh delete mode 100644 bash/pre_data.sh delete mode 100644 bash/sft_tah.sh create mode 100644 script/recipes/README.md delete mode 100644 tah/evaluate/eval_unified.py create mode 100644 tests/README.md diff --git a/.gitignore b/.gitignore index 094b9c3..1d6aa86 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.pyc *.png *.log +.coverage # Folders reference/ @@ -12,6 +13,8 @@ local/ output/ wandb/ build/ +.pytest_cache/ +__pycache__/ test/ data/ diff --git a/CLAUDE.md b/CLAUDE.md index b74d913..3a5a690 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -67,10 +67,7 @@ TAH_TEST_DEVICE=cpu pytest tests/ # run on CPU (skip needs no flag) python tests/bench.py components # per-helper microbench (B200 baseline in README) python tests/bench.py e2e # forward + generate on TaH-plus-1.7B ``` -Snapshot baselines are captured by spawning a subprocess scoped to -`/tmp/TaH-pub` (public TaH); cleaned outputs are diffed against the -recorded snapshots in `tests/baselines/`. Snapshots are gitignored — -they regenerate on first run. +See `tests/README.md` for the snapshot/baseline harness details. ### Training (3-stage) ```bash @@ -118,8 +115,7 @@ tah/ │ ├── backends.py # sglang / hf / tah model + inference fn │ ├── jobs.py # job-sharded runner + result aggregation │ ├── matheval.py # math benchmark graders (math_verify) -│ ├── codeeval.py # humaneval / mbpp via evalplus -│ └── eval_unified.py # backwards-compat shim re-exporting the above +│ └── codeeval.py # humaneval / mbpp via evalplus └── utils/data_prepare.py # SFT preprocessing script/ ├── preparation/ # download.py, label.py, prune.py, filter_split.py @@ -128,7 +124,6 @@ script/ ├── playground/inference_example.py └── recipes/ # qwen3_{0.6,1.7}/sft_tah_step{1,2}.yaml + eval_tah.yaml tests/ # _harness.py + per-component test_*.py + baselines/ (gitignored) -bash/ # sft_tah.sh, eval_tah.sh, pre_data.sh ``` ## Conventions diff --git a/README.md b/README.md index ac6d3cf..3f0867b 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Feel free to star the repo or cite the paper if you find it interesting. @article{fu2025tah, title={Think-at-Hard: Selective Latent Iterations to Improve Reasoning Language Models}, author={Tianyu Fu and Yichen You and Zekai Chen and Guohao Dai and Huazhong Yang and Yu Wang}, - journal={arXiv preprint arXiv:2510.08577}, + journal={arXiv preprint arXiv:2511.08577}, year={2025}, } ``` @@ -104,7 +104,6 @@ Training a TaH model consists of three stages: Use a reference model to generate hard token labels for the training and validation data: ```bash -### step 0 # download the default subset of OpenR1-Math-220k python script/preparation/download.py # filter and split @@ -129,7 +128,6 @@ python script/preparation/label.py \ For the TaH version, prune one layer from the base model to match the parameter count of the standard baseline (skip this step for TaH+ version): ```bash -### step 0 python script/preparation/prune.py \ --model Qwen/Qwen3-1.7B-Base \ --dataset ./data/processed_data/openr1-math/1_7/eval \ @@ -142,7 +140,6 @@ python script/preparation/prune.py \ The first stage uses fixed iteration labels for training: ```bash -### step 1 python -m accelerate.commands.launch \ --config_file ./script/recipes/accelerate_configs/zero2.yaml \ --num_processes 8 \ @@ -158,10 +155,7 @@ Key configurations in Step1 (`sft_tah_step1.yaml`): - `adapter: "lora"` — only LoRA is supported in tah-release. - `train_loss: "NextTokenPredLoss"` — standard causal-LM cross-entropy. -Note: the input updater (top-k softmax over logits → embedding mix), the -output updater (residual additive accumulation), the iter-label generator -(dense max-merge of dataset labels), and the adapter setup are all inlined -into the wrapper, so there's no separate config field for them anymore. +Single-implementation hooks (input/output updaters, iter labels, adapter) are inlined into the wrapper — only `iter_decider` and `train_loss` are config-selectable. ### Step2: Train Iteration Decider @@ -169,7 +163,6 @@ The second stage trains the iteration decider: ```bash -### step 2 python -m accelerate.commands.launch \ --config_file ./script/recipes/accelerate_configs/zero2.yaml \ --num_processes 8 \ @@ -205,8 +198,7 @@ TaH/ │ │ ├── backends.py # sglang / hf / tah model + inference fn │ │ ├── jobs.py # job-sharded runner + result aggregation │ │ ├── matheval.py # math benchmark graders (math_verify) -│ │ ├── codeeval.py # humaneval / mbpp via evalplus -│ │ └── eval_unified.py # backwards-compat shim +│ │ └── codeeval.py # humaneval / mbpp via evalplus │ └── utils/ # SFT preprocessing ├── script/ │ ├── preparation/ # download.py, label.py, prune.py, filter_split.py @@ -227,25 +219,7 @@ python tests/bench.py e2e # forward + 32-token generate on TaH-plus-1.7B python tests/bench_compile.py # one-off torch.compile vs eager experiment ``` -Component baselines (single B200, torch 2.11+cu128, bf16): - -| helper | ms | -|---|---| -| topk_softmax_input_update | 0.48 | -| additive_logits_update | 0.03 | -| gather_active | 0.19 | -| scatter_back | 0.12 | -| MLPIterDecider.forward | 0.86 | -| NextTokenPredLoss.final | 0.23 | -| IterDeciderLoss.intra | 0.55 | -| **TaHForCausalLM.forward** (TaH-plus-1.7B, T=15) | **18.0** | -| **TaHForCasualLM_generate(32)** | **691** (~21.6 ms / token) | - -## Future Work - -- [ ] Optimize iteration decision strategies -- [ ] Integrate TaH with online distillation or RL -- [ ] Support training for larger models +See [`tests/README.md`](tests/README.md) for component-level baselines and CPU-mode (`TAH_TEST_DEVICE=cpu`) instructions. ## Related Projects diff --git a/bash/eval_tah.sh b/bash/eval_tah.sh deleted file mode 100644 index 7f68cf9..0000000 --- a/bash/eval_tah.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -export HF_ENDPOINT="https://hf-mirror.com" -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - -python script/evaluation/eval.py \ - --eval_config ./script/recipes/qwen3_1.7/eval_tah.yaml \ - --model_path nics-efc/TaH-plus-1.7B \ - --dataset_name gsm8k \ - --backend tah \ - --job_nums 8 \ - --tp_size_per_job 1 \ No newline at end of file diff --git a/bash/pre_data.sh b/bash/pre_data.sh deleted file mode 100644 index 04765a3..0000000 --- a/bash/pre_data.sh +++ /dev/null @@ -1,15 +0,0 @@ -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - -python script/preparation/label.py \ - --num_gpu 8 \ - --dataset_path ./data/initial_data/openr1-math/train.jsonl \ - --test_model_list Qwen/Qwen3-1.7B \ - --output_path ./data/processed_data/openr1-math/1_7/train \ - --max_input_length 10000 - -python script/preparation/label.py \ - --num_gpu 8 \ - --dataset_path ./data/initial_data/openr1-math/eval.jsonl \ - --test_model_list Qwen/Qwen3-1.7B \ - --output_path ./data/processed_data/openr1-math/1_7/eval \ - --max_input_length 10000 \ \ No newline at end of file diff --git a/bash/sft_tah.sh b/bash/sft_tah.sh deleted file mode 100644 index 9a274f8..0000000 --- a/bash/sft_tah.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export TORCH_NCCL_AVOID_RECORD_STREAMS=1 -export HF_ENDPOINT="https://hf-mirror.com" -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - -### step 1 -python -m accelerate.commands.launch \ - --config_file ./script/recipes/accelerate_configs/zero2.yaml \ - --num_processes 8 \ - ./script/train/SFT_TaH.py \ - --config ./script/recipes/qwen3_1.7/sft_tah_step1.yaml - -### step 2 -# python -m accelerate.commands.launch \ -# --config_file ./script/recipes/accelerate_configs/zero2.yaml \ -# --num_processes 8 \ -# ./script/train/SFT_TaH.py \ -# --config ./script/recipes/qwen3_1.7/sft_tah_step2.yaml \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2e04599..1b3698b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ classifiers = [ dependencies = [ "transformers==4.52.4", "torch==2.6.0", - "torchvision==0.21.0", "accelerate", "datasets", "peft==0.15.2", diff --git a/script/recipes/README.md b/script/recipes/README.md new file mode 100644 index 0000000..671f4eb --- /dev/null +++ b/script/recipes/README.md @@ -0,0 +1,48 @@ +# Recipes + +YAML configs for training and evaluating TaH models. Pick the directory that +matches your base-model size, then point the relevant entrypoint at the YAML. + +## Layout + +``` +script/recipes/ +├── accelerate_configs/ # multi-GPU launch configs for `accelerate launch` +│ ├── zero2.yaml # DeepSpeed ZeRO stage 2 +│ └── zero3.yaml # DeepSpeed ZeRO stage 3 +├── qwen3_0.6/ # recipes for Qwen3-0.6B-based TaH (hidden_size=1024) +│ ├── sft_tah_step1.yaml +│ ├── sft_tah_step2.yaml +│ └── eval_tah.yaml +└── qwen3_1.7/ # recipes for Qwen3-1.7B-based TaH (hidden_size=2048) + ├── sft_tah_step1.yaml + ├── sft_tah_step2.yaml + └── eval_tah.yaml +``` + +## What each file does + +| File | Purpose | +|---|---| +| `accelerate_configs/zero2.yaml` | DeepSpeed ZeRO-2 launch config; lower memory savings, less comm overhead. | +| `accelerate_configs/zero3.yaml` | DeepSpeed ZeRO-3 launch config; max memory savings, more comm. Pick whichever fits your GPU/memory budget. | +| `qwen3_*/sft_tah_step1.yaml` | Step 1 SFT — `iter_decider: IterLabelDecider` (oracle hard-token labels) + `train_loss: NextTokenPredLoss`. Teaches the LoRA adapter on tokens marked "hard" by the labeller. | +| `qwen3_*/sft_tah_step2.yaml` | Step 2 SFT — loads the Step 1 checkpoint, switches to `iter_decider: MLPIterDecider` + `train_loss: IterDeciderLoss`. Trains the iter-decider so the model predicts its own hard tokens at inference. | +| `qwen3_*/eval_tah.yaml` | Eval config consumed by `script/evaluation/eval.py`; controls dataset list, generation params, max-new-tokens. | + +## Model sizes supported + +This release ships recipes for two base-model sizes — Qwen3-0.6B and +Qwen3-1.7B. The hidden size differs (1024 for 0.6B, 2048 for 1.7B), and the +`iter_decider_kwargs.hidden_states_size` field in each YAML is set +accordingly. If you adapt these recipes to a new base model, match its +hidden size. + +## How they're used + +- `script/train/SFT_TaH.py --config ` consumes the Step 1 / Step 2 YAMLs. +- `script/evaluation/eval.py --eval_config ` consumes the eval YAML. +- Accelerate configs are passed via `accelerate launch --config_file `. + +See the project README's *Train your own TaH model* and *Run evaluation* +sections for full command examples. diff --git a/tah/evaluate/__init__.py b/tah/evaluate/__init__.py index 6c156e0..b306928 100644 --- a/tah/evaluate/__init__.py +++ b/tah/evaluate/__init__.py @@ -8,7 +8,6 @@ * ``jobs`` — per-job runner, process orchestration, result aggregation. * ``matheval`` — math benchmark graders (rule-based via ``math_verify``). * ``codeeval`` — humaneval/mbpp grading via ``evalplus``. -* ``eval_unified`` — backwards-compat shim re-exporting the above entry points. """ from tah.evaluate.datasets import load_combined_dataset from tah.evaluate.jobs import ( diff --git a/tah/evaluate/eval_unified.py b/tah/evaluate/eval_unified.py deleted file mode 100644 index b82870c..0000000 --- a/tah/evaluate/eval_unified.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Backwards-compat shim: ``tah.evaluate.eval_unified`` was the original -single-file driver. The driver has been split into ``datasets.py``, -``backends.py``, and ``jobs.py``; this module re-exports the public entry -points so existing callers keep working. -""" -from tah.evaluate.datasets import load_combined_dataset as load_datasets_with_config # noqa: F401 -from tah.evaluate.jobs import ( # noqa: F401 - allocate_gpus_and_run_jobs, - combine_job_results, - parse_data_range, - run_single_job, -) diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..e04524e --- /dev/null +++ b/tests/README.md @@ -0,0 +1,57 @@ +## What's here + +This directory holds the TaH test suite: 11 pytest files (`test_*.py`) covering individual +components (iter label, loss, input/output updaters, iter decider, causal cache), +wrapper-level forward/save-load/sft smoke checks, the released-checkpoint sanity test, +and the eval jobs runner; a shared `_harness.py` that captures and diffs per-component +baselines against upstream public TaH; a `conftest.py` with device + seed fixtures; and +two standalone benchmark scripts (`bench.py`, `bench_compile.py`) sitting alongside the +tests rather than under pytest. + +## Run the tests + +```bash +pytest tests/ -q # all tests +pytest tests/test_.py -v # single file +TAH_TEST_DEVICE=cpu pytest tests/ # force CPU (auto-detects CUDA otherwise) +``` + +`TAH_TEST_DEVICE` is read by `conftest.py` and `_harness.py` to pick the device for +both fixtures and baseline subprocesses. The simple component tests +(`test_iter_label.py`, `test_loss.py`, `test_input_updater.py`, +`test_output_updater.py`) run cleanly on CPU; tests that touch the released checkpoint +or full wrapper (e.g. `test_released_checkpoint.py`) require a GPU and will download +`nics-efc/TaH-plus-1.7B` from Hugging Face on first run. + +## Baseline-snapshot harness + +On first run, `_harness.py` spawns a subprocess scoped to `/tmp/TaH-pub` (a checkout +of the public upstream TaH at [thu-nics/TaH](https://github.com/thu-nics/TaH)) to +capture per-component `.pt` snapshots into `tests/baselines/`. Subsequent runs in +this repo diff cleaned tah-release outputs against the recorded snapshots, giving the +suite drift detection between this cleaned fork and upstream. The `tests/baselines/` +directory is gitignored — snapshots regenerate on first run. + +## Benchmarks + +```bash +python tests/bench.py components # microbenchmarks for the wrapper's hot helpers +python tests/bench.py e2e # forward + 32-token generate on TaH-plus-1.7B +python tests/bench_compile.py # one-off torch.compile vs eager experiment +``` + +## Component baselines + +Single B200, torch 2.11+cu128, bf16: + +| helper | ms | +|---|---| +| topk_softmax_input_update | 0.48 | +| additive_logits_update | 0.03 | +| gather_active | 0.19 | +| scatter_back | 0.12 | +| MLPIterDecider.forward | 0.86 | +| NextTokenPredLoss.final | 0.23 | +| IterDeciderLoss.intra | 0.55 | +| **TaHForCausalLM.forward** (TaH-plus-1.7B, T=15) | **18.0** | +| **TaHForCasualLM_generate(32)** | **691** (~21.6 ms / token) | From 1d8611deabf589be4d71783acadc676486c3fcee Mon Sep 17 00:00:00 2001 From: Tianyu Fu Date: Fri, 1 May 2026 10:05:00 -0400 Subject: [PATCH 3/5] [fix] tah_model: 3 correctness bugs surfaced by end-to-end SFT MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (1) topk_softmax_input_update autocast dtype mismatch Trainer's eval-on-start runs the wrapper inside accelerate's autocast. softmax/sum can promote bf16 inputs to fp32, then the index-put scatter ``active_input_embeds[active_next] = …`` errors with "got BFloat16 for the destination and Float for the source". Cast the helper's return to ``embedding_weight.dtype``. (2) _set_lora_grad_flags early-return left base model frozen PEFT's ``get_peft_model`` unconditionally freezes every non-lora param. The prior early-return at ``base_grad=True, adapter_grad=True`` (step-1 SFT default) was based on a wrong assumption that HF Trainer would re-enable them. Result: step-1 silently trained only the LoRA adapter (5M of 601M params on Qwen3-0.6B; 0.13 GB instead of the intended 3.4 GB on Qwen3-1.7B). Always reapply the flags so step-1 trains base + LoRA and step-2's ``base_grad=False, adapter_grad=False`` still freezes them. (3) in-place index_put broke autograd once base embeds carried grad Phase-8 wrote next-iter embeddings via ``active_input_embeds[active_next] = topk_softmax_input_update(…)``. The same tensor was already saved by autograd through the ``simple_base_model(inputs_embeds=active_input_embeds)`` call earlier in the iteration; mutating it in place tripped ``one of the variables needed for gradient computation has been modified by an inplace operation``. Latent until fix (2) above made base embeds trainable. Clone before the index_put. 13/13 acc tests still pass on CPU; smoke train (5 steps, Qwen3-1.7B-Base) shows trainable=3.335 GB, grad_norm=7.13, loss converging. --- tah/model/tah_model.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tah/model/tah_model.py b/tah/model/tah_model.py index d4a87cd..c19f92a 100644 --- a/tah/model/tah_model.py +++ b/tah/model/tah_model.py @@ -69,7 +69,10 @@ def topk_softmax_input_update( topk_values, topk_indices = torch.topk(logits, k=k, dim=-1) topk_probs = torch.softmax(topk_values, dim=-1) topk_embeds = embedding_weight[topk_indices] # (..., k, H) - return torch.sum(topk_probs.unsqueeze(-1) * topk_embeds, dim=-2) + out = torch.sum(topk_probs.unsqueeze(-1) * topk_embeds, dim=-2) + # Trainer eval runs under autocast which can promote softmax/sum to fp32; + # the caller scatters the result into a bf16 buffer so cast back. + return out.to(embedding_weight.dtype) def additive_logits_update( @@ -377,13 +380,9 @@ def _setup_lora(self, config: TaHConfig) -> None: self._set_lora_grad_flags(base_grad, adapter_grad) def _set_lora_grad_flags(self, base_grad: bool, adapter_grad: bool) -> None: - """Enable/disable gradients on lora-* params vs everything else. - - No-op when both flags default to True (the common case at training - time, where HF Trainer manages requires_grad per parameter group). - """ - if base_grad is True and adapter_grad is True: - return + """Always reapplied: PEFT freezes all non-LoRA params in + ``get_peft_model``, so even ``(True, True)`` needs us to re-enable the + base. An earlier no-op early-return silently broke step-1 SFT.""" for name, p in self.simple_base_model.base_model.named_parameters(): p.requires_grad = adapter_grad if "lora" in name.lower() else base_grad @@ -673,6 +672,11 @@ def forward( next_iter_mask = (~finished_mask) & current_iter_mask & (valid_mask == 1) if next_iter_mask.any(): active_next = (~active_finished_mask) & valid_active + # Clone before the in-place index_put: when base embeddings are + # trainable (step-1 default), autograd has saved active_input_embeds + # for backward through the simple_base_model call above; mutating + # it in place would trip the saved-tensor version check. + active_input_embeds = active_input_embeds.clone() active_input_embeds[active_next] = topk_softmax_input_update( logits=updated_active_logits[active_next], embedding_weight=self.embed_tokens.weight, From 8ddfa75ff369a94adf68bc6a186932e70e79d86a Mon Sep 17 00:00:00 2001 From: Tianyu Fu Date: Fri, 1 May 2026 10:05:20 -0400 Subject: [PATCH 4/5] [ergonomics] newcomer-friction fixes from a 1-GPU end-to-end run MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hit while reproducing the README's three-stage pipeline (data prep → step1 SFT → step2 SFT → eval) on a single B200. - script/playground/inference_example.py: add --max-new-tokens flag and drop the default from 16384 to 512. The hardcoded 16384 made the "demo" take ~15-20 min; a newcomer would think it was hung. Pass --max-new-tokens 16384 to see a full reasoning chain. - script/preparation/download.py: gate ``HF_ENDPOINT='https://hf-mirror.com'`` behind ``HF_MIRROR=1``. The hardcoded mirror silently failed (or was very slow) for users outside China. - tah/evaluate/jobs.py: install ``prctl(PR_SET_PDEATHSIG, SIGTERM)`` in each multiprocessing worker. Killing the eval driver previously left orphan workers reparented to init still pinning the GPU; saw this concretely during a smoke-test cancel. - README.md: install-section note about ``__editable___tah_*_finder.py`` going stale when ``__init__.py`` is added/removed (re-run ``pip install -e .``); --max-new-tokens example for the playground; a Single-GPU smoke subsection under Run evaluation showing how to slice the dataset + shrink ``max_new_tokens`` for a few-minute sanity check. - script/recipes/qwen3_1.7_1gpu/: new 1-GPU SFT recipes — grad-accum 4 (vs 16), max_length 4096, gradient_checkpointing false, report_to none, outputs under /tmp/tah_run/. Step-2 starts from nics-efc/TaH-plus-1.7B by default (override with your local step-1 final_model). Use with plain ``python script/train/SFT_TaH.py`` (no accelerate launch). - script/recipes/README.md: list the new qwen3_1.7_1gpu/ dir + describe the scaling deltas vs the 8-GPU originals. --- README.md | 32 ++++++- script/playground/inference_example.py | 22 +++-- script/preparation/download.py | 7 +- script/recipes/README.md | 16 +++- .../recipes/qwen3_1.7_1gpu/sft_tah_step1.yaml | 51 ++++++++++++ .../recipes/qwen3_1.7_1gpu/sft_tah_step2.yaml | 83 +++++++++++++++++++ tah/evaluate/jobs.py | 15 ++++ 7 files changed, 213 insertions(+), 13 deletions(-) create mode 100644 script/recipes/qwen3_1.7_1gpu/sft_tah_step1.yaml create mode 100644 script/recipes/qwen3_1.7_1gpu/sft_tah_step2.yaml diff --git a/README.md b/README.md index 3f0867b..44f05b0 100644 --- a/README.md +++ b/README.md @@ -55,10 +55,17 @@ pip install -e ".[training,evaluation]" For code generation evaluation, install [evalplus](https://github.com/evalplus/evalplus) +> **Note** if you ``git pull`` and the top-level package layout changes +> (e.g. ``__init__.py`` is added or removed), re-run ``pip install -e .`` +> — the editable install caches the layout in +> ``site-packages/__editable___tah_*_finder.py`` and stale state will +> silently drop ``tah/__init__.py``'s re-exports. + ## Run an example for TaH ```bash -python script/playground/inference_example.py +python script/playground/inference_example.py # quick demo (~1 min) +python script/playground/inference_example.py --max-new-tokens 16384 # full reasoning chain ``` This script demonstrates TaH's selective latent iteration mechanism, with color-coded output showing the iteration count for each token. @@ -82,8 +89,29 @@ Key parameters: - `--model_path`: Path to the model - `--dataset_name`: Dataset name (supports gsm8k, math500, aime24, etc. Detailed configs can be found in `tah/evaluate/eval_configs/dataset_configs.json`) - `--backend`: Inference backend (`tah` for TaH) -- `--job_nums`: Number of parallel jobs +- `--job_nums`: Number of parallel jobs (one job pins `tp_size_per_job` GPUs) - `--tp_size_per_job`: Tensor parallel size per job +- `--data_range N` / `--data_range start end`: subset slice — handy for smoke tests +- `--data_ids gsm8k_0,gsm8k_5`: run only specific problem ids + +#### Single-GPU smoke +The default recipe targets 8 GPUs (`--job_nums 8`). To sanity-check the pipeline on +one GPU in a couple of minutes, slice the dataset and shrink `max_new_tokens`: +```bash +# clone the recipe and shrink generation length +sed 's/max_new_tokens: 4096/max_new_tokens: 512/' \ + script/recipes/qwen3_1.7/eval_tah.yaml > /tmp/eval_tah_smoke.yaml + +CUDA_VISIBLE_DEVICES=0 python script/evaluation/eval.py \ + --eval_config /tmp/eval_tah_smoke.yaml \ + --model_path nics-efc/TaH-plus-1.7B \ + --dataset_name gsm8k --backend tah \ + --job_nums 1 --tp_size_per_job 1 \ + --data_range 5 \ + --output_dir /tmp/tah_eval_smoke +``` +The TaH backend is a token-by-token Python loop intended for research; for serving +throughput, use `--backend sglang` or the dedicated `minisgl-tah` server. ### Evaluate with a different backend diff --git a/script/playground/inference_example.py b/script/playground/inference_example.py index 3d2bbc0..a02af23 100644 --- a/script/playground/inference_example.py +++ b/script/playground/inference_example.py @@ -2,8 +2,11 @@ with per-token iter-count colouring. Run: - python script/playground/inference_example.py + python script/playground/inference_example.py # quick demo (512 tokens) + python script/playground/inference_example.py --max-new-tokens 16384 # full reasoning chain """ +import argparse + import torch from transformers import AutoTokenizer @@ -12,17 +15,22 @@ def main(): - model_name = "nics-efc/TaH-plus-1.7B" - device_map = "cuda:0" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default="nics-efc/TaH-plus-1.7B", help="HF id or local path") + parser.add_argument("--device", default="cuda:0") + parser.add_argument("--max-new-tokens", type=int, default=512, + help="Cap on generated tokens (default 512 ≈ 1 min on a B200; " + "raise to 16384+ to see a full reasoning chain).") + args = parser.parse_args() - tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(args.model) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tah_model = TaHForCausalLM.from_pretrained( - model_name, + args.model, torch_dtype=torch.bfloat16, - device_map=device_map, + device_map=args.device, attn_implementation="sdpa", ) print(f"Device: {tah_model.device}, Dtype: {tah_model.dtype}") @@ -54,7 +62,7 @@ def main(): tah_model=tah_model, tokenizer=tokenizer, model_inputs=dict(model_inputs), - max_new_tokens=16384, + max_new_tokens=args.max_new_tokens, do_sample=True, temperature=0.6, top_p=0.95, diff --git a/script/preparation/download.py b/script/preparation/download.py index 840da1e..1e7fb45 100644 --- a/script/preparation/download.py +++ b/script/preparation/download.py @@ -1,8 +1,13 @@ import os -os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' import time from huggingface_hub import snapshot_download +# Opt-in mirror for users in regions where huggingface.co is unreachable. +# Set ``HF_MIRROR=1`` to route downloads through hf-mirror.com; otherwise the +# default huggingface.co endpoint is used. +if os.environ.get("HF_MIRROR") == "1": + os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com") + def download_with_retry(repo_id, repo_type, local_dir, max_retries=5, split=None, allow_patterns=None): """Download with retry mechanism""" for attempt in range(max_retries): diff --git a/script/recipes/README.md b/script/recipes/README.md index 671f4eb..24e7c3a 100644 --- a/script/recipes/README.md +++ b/script/recipes/README.md @@ -14,12 +14,22 @@ script/recipes/ │ ├── sft_tah_step1.yaml │ ├── sft_tah_step2.yaml │ └── eval_tah.yaml -└── qwen3_1.7/ # recipes for Qwen3-1.7B-based TaH (hidden_size=2048) +├── qwen3_1.7/ # recipes for Qwen3-1.7B-based TaH (hidden_size=2048) +│ ├── sft_tah_step1.yaml +│ ├── sft_tah_step2.yaml +│ └── eval_tah.yaml +└── qwen3_1.7_1gpu/ # 1-GPU variants of the qwen3_1.7 SFT recipes ├── sft_tah_step1.yaml - ├── sft_tah_step2.yaml - └── eval_tah.yaml + └── sft_tah_step2.yaml ``` +The `qwen3_1.7_1gpu/` recipes shrink `gradient_accumulation_steps` to 4 (vs +16 in the 8-GPU originals), drop `max_length` to 4096, set `report_to: none`, +and write checkpoints under `/tmp/tah_run/` so a 3-stage reproduction fits +on a single B200. Use them with plain `python script/train/SFT_TaH.py` (no +`accelerate launch` needed); the step-2 recipe expects you to fill in +`tah_model_path` with the step-1 `final_model` path before launching. + ## What each file does | File | Purpose | diff --git a/script/recipes/qwen3_1.7_1gpu/sft_tah_step1.yaml b/script/recipes/qwen3_1.7_1gpu/sft_tah_step1.yaml new file mode 100644 index 0000000..d427a32 --- /dev/null +++ b/script/recipes/qwen3_1.7_1gpu/sft_tah_step1.yaml @@ -0,0 +1,51 @@ +# Inline parameter docs live in ../qwen3_1.7/sft_tah_step1.yaml — this file +# diverges only in the training-budget knobs (output_dir, max_length, +# num_train_epochs, gradient_accumulation_steps, save_total_limit, report_to). +model: + name: Qwen/Qwen3-1.7B-Base + torch_dtype: bfloat16 + device_map: auto + trust_remote_code: true + attn_implementation: sdpa + embedding_key: model.embed_tokens + max_iter: 2 + iter_decider: IterLabelDecider + iter_decider_kwargs: + max_iter: 2 + input_updater_kwargs: + topk: 100 + adapter: lora + adapter_kwargs: + r: 32 + lora_alpha: 64 + lora_dropout: 0.1 + target_modules: all-linear + bias: none + train_loss: NextTokenPredLoss + eval_loss: NextTokenPredLoss +data: + train_data_path: data/processed_data/openr1_math/1_7/train + eval_data_path: data/processed_data/openr1_math/1_7/eval + output_dir: /tmp/tah_run/step1/ + max_length: 4096 +training: + num_train_epochs: 2 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 4 + gradient_checkpointing: false + learning_rate: 4.0e-05 + weight_decay: 0.01 + warmup_ratio: 0.03 + max_grad_norm: 0.2 + lr_scheduler_type: cosine_with_min_lr + lr_scheduler_kwargs: + min_lr_rate: 0.1 + logging_steps: 5 + save_strategy: epoch + save_only_model: true + save_total_limit: 2 + bf16: true + eval_strategy: epoch + eval_on_start: true + per_device_eval_batch_size: 1 + report_to: none diff --git a/script/recipes/qwen3_1.7_1gpu/sft_tah_step2.yaml b/script/recipes/qwen3_1.7_1gpu/sft_tah_step2.yaml new file mode 100644 index 0000000..56f7721 --- /dev/null +++ b/script/recipes/qwen3_1.7_1gpu/sft_tah_step2.yaml @@ -0,0 +1,83 @@ +# Inline parameter docs live in ../qwen3_1.7/sft_tah_step2.yaml — this file +# diverges only in the training-budget knobs (output_dir, max_length, +# num_train_epochs, gradient_accumulation_steps, save_total_limit, report_to). +model: + name: Qwen/Qwen3-1.7B-Base + # HF hub id of the step-1 starting point. Use the released TaH-plus-1.7B + # checkpoint by default; substitute /tmp/tah_run/step1/.../final_model if + # you ran step 1 yourself. + tah_model_path: nics-efc/TaH-plus-1.7B + torch_dtype: bfloat16 + device_map: auto + trust_remote_code: true + attn_implementation: sdpa + embedding_key: model.embed_tokens + max_iter: 2 + iter_decider: MLPIterDecider + iter_decider_kwargs: + topk: 100 + hidden_states_size: 2048 + hidden_states_layer_nums: + - 2 + - 10 + - 18 + - 26 + hidden_dims: + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + expansion_factor: 4 + dropout_rate: 0.1 + normalize_input: false + threshold: 0.8 + eval_iter_decider: iter_decider + eval_iter_decider_kwargs: {} + input_updater_kwargs: + topk: 100 + adapter: lora + adapter_kwargs: + r: 32 + lora_alpha: 64 + lora_dropout: 0.1 + target_modules: all-linear + base_grad: false + adapter_grad: false + bias: none + train_loss: IterDeciderLoss + train_loss_kwargs: + pos_weight: 5.4 + skip_last_iter: true + max_iter: 2 + eval_loss: NextTokenPredLoss + eval_loss_kwargs: {} +data: + train_data_path: data/processed_data/openr1_math/1_7/train + eval_data_path: data/processed_data/openr1_math/1_7/eval + output_dir: /tmp/tah_run/step2/ + max_length: 4096 +training: + freeze_component: + - model.simple_base_model + num_train_epochs: 2 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 4 + gradient_checkpointing: false + learning_rate: 3.0e-05 + weight_decay: 0.01 + warmup_ratio: 0.03 + max_grad_norm: 0.2 + lr_scheduler_type: cosine_with_min_lr + lr_scheduler_kwargs: + min_lr_rate: 0.1 + logging_steps: 5 + save_strategy: epoch + save_only_model: true + save_total_limit: 2 + bf16: true + eval_strategy: epoch + eval_on_start: true + per_device_eval_batch_size: 1 + report_to: none diff --git a/tah/evaluate/jobs.py b/tah/evaluate/jobs.py index 4ef0724..71ffd4c 100644 --- a/tah/evaluate/jobs.py +++ b/tah/evaluate/jobs.py @@ -18,12 +18,14 @@ from __future__ import annotations import csv +import ctypes import fcntl import json import logging as pylog import math import os import shutil +import signal import socket import time import traceback @@ -308,8 +310,21 @@ def _setup_logging(level_name: str) -> None: ) +_PR_SET_PDEATHSIG = 1 # no Python constant; + + +def _install_pdeathsig() -> None: + """SIGTERM the worker if its parent dies (Linux). Without this, killing the + eval driver leaves orphan workers reparented to init still pinning the GPU.""" + try: + ctypes.CDLL("libc.so.6", use_errno=True).prctl(_PR_SET_PDEATHSIG, signal.SIGTERM, 0, 0, 0) + except Exception: + pass + + def _run_job_process(job_args: Tuple, result_queue: Queue) -> None: """One job per Process; pins CUDA_VISIBLE_DEVICES + NCCL port for the worker.""" + _install_pdeathsig() (job_id, config, combined_dataset_name, output_dir, timestamp, model_path, job_nums, start_idx, end_idx, tp_size, backend, data_range, gpu_devices, problems_data, field_mapping, unified_code_solutions_file) = job_args From cf04d4254f594994f16d0adfa38594037acbe459 Mon Sep 17 00:00:00 2001 From: Tianyu Fu Date: Fri, 29 May 2026 21:18:05 +0800 Subject: [PATCH 5/5] [update] minor modifications and updates --- .gitignore | 2 ++ CLAUDE.md | 4 ++-- script/preparation/download.py | 2 +- script/train/SFT_TaH.py | 3 +-- tah/evaluate/jobs.py | 3 --- tah/evaluate/matheval.py | 12 +++++++++--- tah/model/causal_cache.py | 12 ++++++++++-- tah/model/tah_model.py | 10 +++++++++- tests/_harness.py | 9 --------- tests/test_jobs_runner.py | 2 +- tests/test_loss.py | 2 +- 11 files changed, 36 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 1d6aa86..12d5065 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ *.png *.log .coverage +*.prof +memory_snapshot.html # Folders reference/ diff --git a/CLAUDE.md b/CLAUDE.md index 3a5a690..0b9d3ff 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -55,13 +55,13 @@ see prior ones without disturbing iter-0. ### Install ```bash -conda activate release +conda activate tah uv pip install -e ".[dev,training,evaluation]" ``` ### Tests + benchmarks ```bash -pytest tests/ # 21 component + wrapper + roundtrip tests +pytest tests/ # 11 component + wrapper + roundtrip tests pytest tests/test_.py -v # one file TAH_TEST_DEVICE=cpu pytest tests/ # run on CPU (skip needs no flag) python tests/bench.py components # per-helper microbench (B200 baseline in README) diff --git a/script/preparation/download.py b/script/preparation/download.py index 1e7fb45..617260e 100644 --- a/script/preparation/download.py +++ b/script/preparation/download.py @@ -17,7 +17,7 @@ def download_with_retry(repo_id, repo_type, local_dir, max_retries=5, split=None repo_id=repo_id, repo_type=repo_type, local_dir=local_dir, - allow_patterns=allow_patterns # 启用文件过滤 + allow_patterns=allow_patterns # restrict to the requested file patterns ) print("Download completed successfully!") return diff --git a/script/train/SFT_TaH.py b/script/train/SFT_TaH.py index 6b5785d..b984a44 100644 --- a/script/train/SFT_TaH.py +++ b/script/train/SFT_TaH.py @@ -65,17 +65,16 @@ def _build_model_and_tokenizer(model_config: Dict, accelerator: Accelerator): ) tokenizer.pad_token = tokenizer.eos_token + valid = {f.name for f in fields(TaHConfig)} if "tah_model_path" in model_config: accelerator.print(f"Resuming from TaH checkpoint: {model_config['tah_model_path']}") # Override only the fields explicitly set in the new YAML. - valid = {f.name for f in fields(TaHConfig)} override = TaHConfig(**{k: v for k, v in model_config.items() if k in valid}) model = TaHForCausalLM.from_pretrained( model_config["tah_model_path"], tah_config=override, ).to(dtype=torch_dtype) else: # Construct fresh from a base model + recipe-specified components. - valid = {f.name for f in fields(TaHConfig)} cfg = TaHConfig(**{k: v for k, v in model_config.items() if k in valid}) base = AutoModelForCausalLM.from_pretrained( model_config["name"], torch_dtype=torch_dtype, device_map=device_map, diff --git a/tah/evaluate/jobs.py b/tah/evaluate/jobs.py index 71ffd4c..2b4d057 100644 --- a/tah/evaluate/jobs.py +++ b/tah/evaluate/jobs.py @@ -62,9 +62,6 @@ def _build_problem(item: dict, idx: int, field_mapping: Dict, detail_dir: Path) """Pull standard fields out of an item; also create the per-problem dir.""" pid = str(item.get(field_mapping["id_field"]) or f"problem_{idx}") text = str(item.get(field_mapping["question_field"], "")).strip() - template = field_mapping.get("prompt_template") - if template and "{question}" in template: - text = template.replace("{question}", text) answer = str(item.get(field_mapping["answer_field"], "")).strip() problem_dir = detail_dir / pid problem_dir.mkdir(parents=True, exist_ok=True) diff --git a/tah/evaluate/matheval.py b/tah/evaluate/matheval.py index f5b6a31..7f26dad 100644 --- a/tah/evaluate/matheval.py +++ b/tah/evaluate/matheval.py @@ -63,8 +63,7 @@ def __init__(self, mode: str): raise ValueError(f"unsupported grader mode {mode!r}") self.mode = mode - def rule_judge(self, solution: str, ground_truth: str, finish_generation: bool = True) -> Tuple[bool, str]: - del finish_generation # accepted for caller-protocol uniformity + def rule_judge(self, solution: str, ground_truth: str) -> Tuple[bool, str]: if self.mode == "expr": gold_cfg = [ExprExtractionConfig()] answer_cfg = _OUTPUT_EXTRACTION @@ -78,7 +77,14 @@ def rule_judge(self, solution: str, ground_truth: str, finish_generation: bool = answer_cfg = [StringExtractionConfig()] gold = parse(ground_truth, extraction_config=gold_cfg) - answer = parse(solution, extraction_config=answer_cfg, extraction_mode="first_match") + # Match main's per-mode behavior: math (expr/latex) takes the first match; + # string/multiple-choice uses math_verify's default extraction_mode + # ("any_match"), exactly as main's GPQAEvaluator did (it passed no + # extraction_mode kwarg). Using first_match for string mode would diverge. + if self.mode == "string": + answer = parse(solution, extraction_config=answer_cfg) + else: + answer = parse(solution, extraction_config=answer_cfg, extraction_mode="first_match") if not answer: return False, "No extracted answer" return bool(verify(gold, answer)), str(answer) diff --git a/tah/model/causal_cache.py b/tah/model/causal_cache.py index 688dbca..3173c8b 100644 --- a/tah/model/causal_cache.py +++ b/tah/model/causal_cache.py @@ -166,10 +166,18 @@ def get_cache_length(self, layer_idx: Optional[int] = 0, iter_idx: Optional[int] return t.shape[-2] if t is not None else 0 return sum(t.shape[-2] for (l, _), t in self._k.items() if l == layer_idx) - def get_cache_length_upto_iter(self, layer_idx: Optional[int] = 0, iter_depth: int = 0) -> int: + def get_kv_length_upto_iter(self, layer_idx: int = 0, upto_iter_idx: int = 0) -> int: + """Total KV sequence length across iterations ``0..upto_iter_idx`` (inclusive). + + This is the key/value length a query at iteration ``upto_iter_idx`` attends + to under the duo-mode mask; it equals ``get_cache_upto_iter(...)[0].shape[-2]`` + without materialising the concatenation. Contrast :meth:`get_cache_length`, + which returns a *single* iteration's length (or *all* iterations when + ``iter_idx is None``). + """ return sum( self._k[(layer_idx, i)].shape[-2] - for i in range(iter_depth + 1) if (layer_idx, i) in self._k + for i in range(upto_iter_idx + 1) if (layer_idx, i) in self._k ) def get_seq_length(self, layer_idx: Optional[int] = 0, iter_idx: Optional[int] = 0) -> int: diff --git a/tah/model/tah_model.py b/tah/model/tah_model.py index c19f92a..ae02cf2 100644 --- a/tah/model/tah_model.py +++ b/tah/model/tah_model.py @@ -319,6 +319,14 @@ class TaHForCausalLM(PreTrainedModel): canonical recipes); other modes from public TaH have been removed. """ + # transformers >=4.57 validates attn support against the *wrapper* class + # (not the inner base model) in ``PreTrainedModel.__init__``. Without this + # class-level flag, ``attn_implementation="sdpa"`` — used by the inference + # demo, the bench harness, and the ``tah`` eval backend — raises + # "TaHForCausalLM does not support ... sdpa". Declaring support here keeps + # those entry points working out of the box. + _supports_sdpa = True + def __init__(self, base_model: PreTrainedModel, config: Optional[TaHConfig] = None): # SDPA is the only attention impl we exercise in the recurrent loop. base_model._supports_sdpa = True @@ -480,7 +488,7 @@ def _max_merge_iter_labels( def _forward_kwargs(kwargs: dict) -> dict: """Filter caller-supplied forward kwargs down to those the loss / callback plumbing actually reads.""" - return {k: v for k, v in kwargs.items() if k in ("global_step", "num_items_in_batch")} + return {k: v for k, v in kwargs.items() if k in ("num_items_in_batch",)} # ── forward ─────────────────────────────────────────────────────────── diff --git a/tests/_harness.py b/tests/_harness.py index d6a7d25..8eec5b7 100644 --- a/tests/_harness.py +++ b/tests/_harness.py @@ -80,15 +80,6 @@ def capture(name: str, code: str, payload: dict | None = None) -> dict: payload = payload or {} payload_pkl = pickle.dumps(payload) out_path = baseline_path(name) - runner = ( - "import os, sys, pickle\n" - f"sys.path.insert(0, {str(PUBLIC_ROOT)!r})\n" - "import torch\n" - f"{code}\n" - "payload = pickle.loads(sys.stdin.buffer.read())\n" - "out = run(payload)\n" - "from tests._harness import _to_cpu # noqa: E402 -- imported lazily\n" - ) # We can't import tests._harness inside the subprocess because the public # TaH path is first on sys.path. Inline a CPU mover instead. runner = ( diff --git a/tests/test_jobs_runner.py b/tests/test_jobs_runner.py index 0672ce6..ae9def8 100644 --- a/tests/test_jobs_runner.py +++ b/tests/test_jobs_runner.py @@ -85,7 +85,7 @@ def test_run_single_job_writes_expected_files(tmp_path, tokenizer_stub): with patch.object(jobs, "setup_backend", _fake_backend_factory()), \ patch.object(jobs, "_score_one", _fake_score_one), \ - patch.object(jobs, "AutoTokenizer", create=True) if False else patch( + patch( "transformers.AutoTokenizer.from_pretrained", return_value=tokenizer_stub, ), \ patch.object(jobs, "cleanup", lambda *a, **kw: None): diff --git a/tests/test_loss.py b/tests/test_loss.py index f7ba848..0d03c97 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -117,7 +117,7 @@ def run(payload): iter_decider_threshold=0.5, ) final = loss.final_loss_func( - logits=payload['continue_logits'].view(B, T)[:, :1] if False else payload['continue_logits'].new_zeros(B, T, 1), + logits=payload['continue_logits'].new_zeros(B, T, 1), labels_shifted=payload['iter_count_labels'], iter_count=torch.ones_like(payload['iter_count_labels']), iter_count_labels=payload['iter_count_labels'],