[major] code simplify #6
Open
fuvty wants to merge 5 commits into
… test+bench harness

Squashes 38 iterative commits on the simplify branch into one cohesive cleanup of public TaH (forked from thu-nics/TaH) targeting nics-efc/TaH-plus-1.7B only.

Model wrapper (tah/model/):
- Drop the by-name registry; inline the single-impl slots (input_updater, output_updater, iter_label_generator, adapter) directly into the wrapper.
- Rename recurrent_transformer.py -> tah_model.py; expose iter-loop phases as named helpers (Phase 1-8); extract _resolve_tah_config, _build_loss / _build_iter_decider, _setup_lora, _set_lora_grad_flags helpers.
- Vectorise scatter_back; flatten TaHCache dict-of-dicts; trim dead helpers in utils.py and the playground.
- IterDeciderLoss: extract _record_threshold_accuracy and _scatter_per_token for readability without changing semantics.
- MLPIterDecider: extract _select_hidden_features; thread dtype through _ClassifierBackbone / _ClassifierBlock.
- Inline TaHConfig (de)serialiser; round-trip torch.dtype via type_to_dict_string / dict_string_to_type. from_pretrained now resolves HF Hub IDs via snapshot_download and places iter_decider on the model device.

Eval driver (tah/evaluate/):
- Split eval_unified into datasets.py / backends.py / jobs.py; keep eval_unified.py as a backwards-compat shim.
- backends.setup_backend collapsed to a table dispatch.
- jobs.run_single_job split into _job_output_dir / _build_prompts / _process_batch with a small unit test that stubs the backend (no GPU/model).
- Trim utils.py + matheval.py + codeeval.py; drop tracker integration.

Train + scripts:
- HF Trainer subclass + collator + iter-aware callback kept; trim dead branches in tah/train and script/train/SFT_TaH.py.
- script/preparation/label.py trimmed to a Qwen3-only labeller (drop MobileLLM/R1 templates and the verbose statistics report).
- script/preparation/prune.py: drop env override + arrow fallback.
- script/evaluation/eval.py: tighten CLI.
- Recipes: drop inert fields, fix qwen3_0.6 step1 decider, drop dead variants.

Tests + bench:
- tests/_harness.py + per-component test_*.py: snapshot-based differential testing against /tmp/TaH-pub via PYTHONPATH-isolated subprocess; baselines in tests/baselines/ are gitignored and regenerate on first run.
- tests/test_sft_smoke.py: trainer + collator end-to-end on synthetic data.
- tests/bench.py: component + e2e microbenchmarks (B200 baselines documented in README); tests/bench_compile.py: torch.compile vs eager experiment.

Docs:
- README + CLAUDE.md refreshed for the cleaned layout, helper reference, bench harness and B200 baselines.
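The "regenerate on first run" behaviour of the baseline snapshots can be illustrated with a minimal sketch. This is not the code in tests/_harness.py: the function name `check_against_baseline`, the JSON file layout, and the `SnapshotMismatch` exception are assumptions made for illustration, and the subprocess isolation against /tmp/TaH-pub is left out entirely.

```python
import json
import tempfile
from pathlib import Path


class SnapshotMismatch(AssertionError):
    """Raised when a value differs from its stored baseline (hypothetical name)."""


def check_against_baseline(name, value, baseline_dir):
    """Compare `value` with the stored snapshot, creating the snapshot if absent."""
    path = Path(baseline_dir) / f"{name}.json"
    if not path.exists():
        # First run: no baseline yet, so record the current value as the reference.
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(value, sort_keys=True))
        return "created"
    expected = json.loads(path.read_text())
    if expected != value:
        raise SnapshotMismatch(f"{name}: expected {expected!r}, got {value!r}")
    return "matched"


# Usage: the first call writes the baseline, the second compares against it.
with tempfile.TemporaryDirectory() as tmp:
    print(check_against_baseline("scatter_back", {"sum": 1.5}, tmp))
    print(check_against_baseline("scatter_back", {"sum": 1.5}, tmp))
```

Because the baseline is created whenever it is missing, a gitignored baselines directory only protects against regressions relative to the first local run, which is presumably why the harness generates it from the reference checkout.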
- README: trim Future Work, fix bibtex arxiv ID, dedupe step headers, compress training note, replace bench table with pointer to tests/README
- CLAUDE.md: drop bash/ and eval_unified rows from layout tree, point to tests/README for snapshot harness details
- bash/: delete redundant shell wrappers (commands already in README, and hard-coded HF mirror was region-specific)
- tah/evaluate/eval_unified.py: delete backwards-compat shim (no callers; same names already exported through tah/evaluate/__init__.py)
- pyproject.toml: drop unused torchvision dep
- .gitignore: add .pytest_cache/, __pycache__/, .coverage
- tests/README.md: new file covering run instructions, the _harness.py baseline-snapshot system, benchmark scripts, and component-baseline timings (moved from project README)
- script/recipes/README.md: new file covering directory layout, step1/step2/eval semantics, and supported model sizes (Qwen3-0.6B / 1.7B)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
(1) topk_softmax_input_update autocast dtype mismatch

Trainer's eval-on-start runs the wrapper inside accelerate's autocast. softmax/sum can promote bf16 inputs to fp32, then the index-put scatter ``active_input_embeds[active_next] = …`` errors with "got BFloat16 for the destination and Float for the source". Cast the helper's return to ``embedding_weight.dtype``.

(2) _set_lora_grad_flags early-return left base model frozen

PEFT's ``get_peft_model`` unconditionally freezes every non-lora param. The prior early-return at ``base_grad=True, adapter_grad=True`` (step-1 SFT default) was based on a wrong assumption that HF Trainer would re-enable them. Result: step-1 silently trained only the LoRA adapter (5M of 601M params on Qwen3-0.6B; 0.13 GB instead of the intended 3.4 GB on Qwen3-1.7B). Always reapply the flags so step-1 trains base + LoRA and step-2's ``base_grad=False, adapter_grad=False`` still freezes them.

(3) in-place index_put broke autograd once base embeds carried grad

Phase-8 wrote next-iter embeddings via ``active_input_embeds[active_next] = topk_softmax_input_update(…)``. The same tensor was already saved by autograd through the ``simple_base_model(inputs_embeds=active_input_embeds)`` call earlier in the iteration; mutating it in place tripped ``one of the variables needed for gradient computation has been modified by an inplace operation``. Latent until fix (2) above made base embeds trainable. Clone before the index_put.

13/13 acc tests still pass on CPU; smoke train (5 steps, Qwen3-1.7B-Base) shows trainable=3.335 GB, grad_norm=7.13, loss converging.
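Fixes (1) and (3) can be reproduced in isolation with a small sketch. The helper name `next_iter_embeds` is hypothetical, and `embeds.pow(2)` only stands in for the base-model forward pass (any op that saves its input for backward behaves the same way); the names `active_input_embeds` and `active_next` follow the commit message.

```python
import torch


def next_iter_embeds(active_input_embeds, active_next, update):
    """Hypothetical helper: write `update` into the rows selected by `active_next`."""
    # Clone so the tensor autograd saved during the earlier forward pass is not mutated.
    out = active_input_embeds.clone()
    # Under autocast, softmax/sum may return fp32; cast to the destination dtype
    # before the index-put, which the commit reports rejecting mismatched dtypes.
    out[active_next] = update.to(out.dtype)
    return out


weight = torch.randn(4, 3, requires_grad=True)
embeds = weight * 2.0      # non-leaf tensor that carries grad
hidden = embeds.pow(2)     # stand-in forward pass; saves `embeds` for backward
active_next = torch.tensor([True, False, True, False])
update = torch.ones(2, 3, dtype=torch.float32)

new_embeds = next_iter_embeds(embeds, active_next, update)
# Backward succeeds because `embeds` itself was never modified in place.
(hidden.sum() + new_embeds.sum()).backward()
```

Writing `embeds[active_next] = update` directly instead of going through the clone would bump the version counter of `embeds`, and the backward pass through `pow` would then raise the in-place modification error quoted above.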
Hit while reproducing the README's three-stage pipeline (data prep → step1 SFT → step2 SFT → eval) on a single B200.

- script/playground/inference_example.py: add --max-new-tokens flag and drop the default from 16384 to 512. The hardcoded 16384 made the "demo" take ~15-20 min; a newcomer would think it was hung. Pass --max-new-tokens 16384 to see a full reasoning chain.
- script/preparation/download.py: gate ``HF_ENDPOINT='https://hf-mirror.com'`` behind ``HF_MIRROR=1``. The hardcoded mirror silently failed (or was very slow) for users outside China.
- tah/evaluate/jobs.py: install ``prctl(PR_SET_PDEATHSIG, SIGTERM)`` in each multiprocessing worker. Killing the eval driver previously left orphan workers reparented to init still pinning the GPU; saw this concretely during a smoke-test cancel.
- README.md: install-section note about ``__editable___tah_*_finder.py`` going stale when ``__init__.py`` is added/removed (re-run ``pip install -e .``); --max-new-tokens example for the playground; a Single-GPU smoke subsection under Run evaluation showing how to slice the dataset + shrink ``max_new_tokens`` for a few-minute sanity check.
- script/recipes/qwen3_1.7_1gpu/: new 1-GPU SFT recipes: grad-accum 4 (vs 16), max_length 4096, gradient_checkpointing false, report_to none, outputs under /tmp/tah_run/. Step-2 starts from nics-efc/TaH-plus-1.7B by default (override with your local step-1 final_model). Use with plain ``python script/train/SFT_TaH.py`` (no accelerate launch).
- script/recipes/README.md: list the new qwen3_1.7_1gpu/ dir + describe the scaling deltas vs the 8-GPU originals.
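The ``HF_MIRROR=1`` gate described for script/preparation/download.py might look roughly like the following. This is a sketch under assumptions, not the script's actual code: the function name `configure_hf_endpoint` and the injectable `env` mapping are hypothetical, added here so the behaviour can be checked without touching the real environment.

```python
import os

HF_MIRROR_URL = "https://hf-mirror.com"


def configure_hf_endpoint(env=None):
    """Point HF_ENDPOINT at the mirror only when HF_MIRROR=1 is set (hypothetical helper)."""
    env = os.environ if env is None else env
    if env.get("HF_MIRROR") == "1":
        env["HF_ENDPOINT"] = HF_MIRROR_URL
        return True
    # Leave any existing HF_ENDPOINT untouched so the default hub stays in effect.
    return False


# Usage with explicit mappings standing in for the process environment.
print(configure_hf_endpoint({}))                   # mirror not requested
print(configure_hf_endpoint({"HF_MIRROR": "1"}))   # mirror requested
```

A gate like this most likely has to run before the Hugging Face client library is imported, since the endpoint is typically read from the environment at import time.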
No description provided.