Skip to content

[major] code simplify#6

Open
fuvty wants to merge 5 commits into
thu-nics:mainfrom
fuvty:simplify
Open

[major] code simplify#6
fuvty wants to merge 5 commits into
thu-nics:mainfrom
fuvty:simplify

Conversation

@fuvty

@fuvty fuvty commented Apr 25, 2026

Copy link
Copy Markdown
Member

No description provided.

fuvty and others added 2 commits April 25, 2026 22:40
… test+bench harness

Squashes 38 iterative commits on the simplify branch into one cohesive cleanup
of public TaH (forked from thu-nics/TaH) targeting nics-efc/TaH-plus-1.7B only.

Model wrapper (tah/model/):
- Drop the by-name registry; inline the single-impl slots (input_updater,
  output_updater, iter_label_generator, adapter) directly into the wrapper.
- Rename recurrent_transformer.py -> tah_model.py; expose iter-loop phases as
  named helpers (Phase 1-8); extract _resolve_tah_config, _build_loss /
  _build_iter_decider, _setup_lora, _set_lora_grad_flags helpers.
- Vectorise scatter_back; flatten TaHCache dict-of-dicts; trim dead helpers
  in utils.py and the playground.
- IterDeciderLoss: extract _record_threshold_accuracy and _scatter_per_token
  for readability without changing semantics.
- MLPIterDecider: extract _select_hidden_features; thread dtype through
  _ClassifierBackbone / _ClassifierBlock.
- Inline TaHConfig (de)serialiser; round-trip torch.dtype via
  type_to_dict_string / dict_string_to_type. from_pretrained now resolves HF
  Hub IDs via snapshot_download and places iter_decider on the model device.

Eval driver (tah/evaluate/):
- Split eval_unified into datasets.py / backends.py / jobs.py; keep
  eval_unified.py as a backwards-compat shim.
- backends.setup_backend collapsed to a table dispatch.
- jobs.run_single_job split into _job_output_dir / _build_prompts /
  _process_batch with a small unit test that stubs the backend (no GPU/model).
- Trim utils.py + matheval.py + codeeval.py; drop tracker integration.

Train + scripts:
- HF Trainer subclass + collator + iter-aware callback kept; trim dead
  branches in tah/train and script/train/SFT_TaH.py.
- script/preparation/label.py trimmed to a Qwen3-only labeller (drop
  MobileLLM/R1 templates and the verbose statistics report).
- script/preparation/prune.py: drop env override + arrow fallback.
- script/evaluation/eval.py: tighten CLI.
- Recipes: drop inert fields, fix qwen3_0.6 step1 decider, drop dead variants.

Tests + bench:
- tests/_harness.py + per-component test_*.py: snapshot-based differential
  testing against /tmp/TaH-pub via PYTHONPATH-isolated subprocess; baselines
  in tests/baselines/ are gitignored and regenerate on first run.
- tests/test_sft_smoke.py: trainer + collator end-to-end on synthetic data.
- tests/bench.py: component + e2e microbenchmarks (B200 baselines documented
  in README); tests/bench_compile.py: torch.compile vs eager experiment.

Docs:
- README + CLAUDE.md refreshed for the cleaned layout, helper reference,
  bench harness and B200 baselines.
- README: trim Future Work, fix bibtex arxiv ID, dedupe step headers,
  compress training note, replace bench table with pointer to tests/README
- CLAUDE.md: drop bash/ and eval_unified rows from layout tree, point to
  tests/README for snapshot harness details
- bash/: delete redundant shell wrappers (commands already in README, and
  hard-coded HF mirror was region-specific)
- tah/evaluate/eval_unified.py: delete backwards-compat shim (no callers;
  same names already exported through tah/evaluate/__init__.py)
- pyproject.toml: drop unused torchvision dep
- .gitignore: add .pytest_cache/, __pycache__/, .coverage
- tests/README.md: new — run instructions, _harness.py baseline-snapshot
  system, benchmark scripts, component-baseline timings (moved from
  project README)
- script/recipes/README.md: new — directory layout, step1/step2/eval
  semantics, supported model sizes (Qwen3-0.6B / 1.7B)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@fuvty fuvty self-assigned this Apr 27, 2026
@fuvty fuvty added the enhancement New feature or request label Apr 27, 2026
fuvty added 2 commits May 1, 2026 10:05
(1) topk_softmax_input_update autocast dtype mismatch
Trainer's eval-on-start runs the wrapper inside accelerate's autocast.
softmax/sum can promote bf16 inputs to fp32, then the index-put scatter
``active_input_embeds[active_next] = …`` errors with "got BFloat16 for the
destination and Float for the source". Cast the helper's return to
``embedding_weight.dtype``.

(2) _set_lora_grad_flags early-return left base model frozen
PEFT's ``get_peft_model`` unconditionally freezes every non-lora param. The
prior early-return at ``base_grad=True, adapter_grad=True`` (step-1 SFT
default) was based on a wrong assumption that HF Trainer would re-enable
them. Result: step-1 silently trained only the LoRA adapter (5M of 601M
params on Qwen3-0.6B; 0.13 GB instead of the intended 3.4 GB on
Qwen3-1.7B). Always reapply the flags so step-1 trains base + LoRA and
step-2's ``base_grad=False, adapter_grad=False`` still freezes them.

(3) in-place index_put broke autograd once base embeds carried grad
Phase-8 wrote next-iter embeddings via
``active_input_embeds[active_next] = topk_softmax_input_update(…)``. The
same tensor was already saved by autograd through the
``simple_base_model(inputs_embeds=active_input_embeds)`` call earlier in
the iteration; mutating it in place tripped
``one of the variables needed for gradient computation has been modified
by an inplace operation``. Latent until fix (2) above made base embeds
trainable. Clone before the index_put.

13/13 acc tests still pass on CPU; smoke train (5 steps, Qwen3-1.7B-Base)
shows trainable=3.335 GB, grad_norm=7.13, loss converging.
Hit while reproducing the README's three-stage pipeline (data prep → step1
SFT → step2 SFT → eval) on a single B200.

- script/playground/inference_example.py: add --max-new-tokens flag and drop
  the default from 16384 to 512. The hardcoded 16384 made the "demo" take
  ~15-20 min; a newcomer would think it was hung. Pass --max-new-tokens 16384
  to see a full reasoning chain.
- script/preparation/download.py: gate ``HF_ENDPOINT='https://hf-mirror.com'``
  behind ``HF_MIRROR=1``. The hardcoded mirror silently failed (or was very
  slow) for users outside China.
- tah/evaluate/jobs.py: install ``prctl(PR_SET_PDEATHSIG, SIGTERM)`` in each
  multiprocessing worker. Killing the eval driver previously left orphan
  workers reparented to init still pinning the GPU; saw this concretely
  during a smoke-test cancel.
- README.md: install-section note about ``__editable___tah_*_finder.py`` going
  stale when ``__init__.py`` is added/removed (re-run ``pip install -e .``);
  --max-new-tokens example for the playground; a Single-GPU smoke subsection
  under Run evaluation showing how to slice the dataset + shrink
  ``max_new_tokens`` for a few-minute sanity check.
- script/recipes/qwen3_1.7_1gpu/: new 1-GPU SFT recipes — grad-accum 4 (vs
  16), max_length 4096, gradient_checkpointing false, report_to none,
  outputs under /tmp/tah_run/. Step-2 starts from nics-efc/TaH-plus-1.7B by
  default (override with your local step-1 final_model). Use with plain
  ``python script/train/SFT_TaH.py`` (no accelerate launch).
- script/recipes/README.md: list the new qwen3_1.7_1gpu/ dir + describe the
  scaling deltas vs the 8-GPU originals.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant