Minimal PyTorch prototype of a hierarchical byte-level language model with adaptive border selection, plus a residual byte-level GRU baseline.
- Level 0 models raw UTF-8 bytes and promotes positions whose next-byte prediction is wrong.
- Higher levels model only the promoted border embeddings and promote positions whose next-embedding prediction MSE is above a threshold.
- Each sequence always starts with a border at position 0, so every compressed token owns a span in the level below.
- The decoder mirrors the compressor and reconstructs lower-level sequences by broadcast-adding each parent embedding over its child span.
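The span ownership and broadcast-add described above can be sketched in plain Python. This is an illustrative sketch only, not the repository's code: `span_ids` and `broadcast_add` are hypothetical names, and embeddings are plain lists of floats here rather than tensors.

```python
from itertools import accumulate


def span_ids(borders):
    """Map each position to the index of the parent token that owns it.

    `borders[t]` is True when position t starts a new span. Position 0 must
    be a border so that every position belongs to some span.
    """
    if not borders or not borders[0]:
        raise ValueError("position 0 must be a border")
    # A running count of borders, minus one, gives the owning span index.
    return [count - 1 for count in accumulate(int(b) for b in borders)]


def broadcast_add(child, parent, borders):
    """Add each parent embedding to every child embedding inside its span."""
    ids = span_ids(borders)
    return [
        [c + p for c, p in zip(child_vec, parent[i])]
        for child_vec, i in zip(child, ids)
    ]


borders = [True, False, False, True, False]
child = [[0.0, 0.0] for _ in range(5)]
parent = [[1.0, 10.0], [2.0, 20.0]]  # one embedding per border

print(span_ids(borders))  # [0, 0, 0, 1, 1]
print(broadcast_add(child, parent, borders))
```

Because position 0 is always a border, the number of parent embeddings equals the number of `True` entries in `borders`, and no child position is left without a parent.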
By default, border selection is causal:
- Level 0 accumulates byte-prediction entropy within the current span and promotes a border once the cumulative sum exceeds --byte-entropy-threshold.
- Higher levels accumulate predicted next-embedding uncertainty within the current span and promote a border once the cumulative sum exceeds --meta-uncertainty-threshold.
- The original teacher-forced border rule is still available with --border-mode teacher_forced, but it leaks target information and should not be used for fair LM evaluation. --threshold is kept for the legacy teacher-forced meta-border MSE threshold.
- A small entropy regularizer can keep the byte model from becoming too overconfident everywhere without directly changing the routing rule.
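A minimal sketch of the cumulative border rule, under stated assumptions: `uncertainties[t]` is the entropy (or other uncertainty score) of the prediction made at position t from the prefix alone, the accumulator resets to zero after each promotion, and position 0 is always a border. The function names are hypothetical, and the actual routing code may differ in these details.

```python
import math


def entropy(probs):
    """Shannon entropy in nats of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def cumulative_borders(uncertainties, threshold):
    """Promote a border once the uncertainty accumulated in a span exceeds threshold.

    Only predicted uncertainty is used, never the true next byte, so the
    decision at position t does not depend on any target.
    """
    borders = []
    running = 0.0
    for t, u in enumerate(uncertainties):
        if t == 0:
            borders.append(True)  # every sequence starts with a border
            continue
        running += u
        if running > threshold:
            borders.append(True)
            running = 0.0  # assumption: start a fresh span after promotion
        else:
            borders.append(False)
    return borders


scores = [9.0, 4.0, 4.0, 4.0, 1.0, 8.0]
print(cumulative_borders(scores, threshold=10.0))
# [True, False, False, True, False, False]
```

Raising the threshold yields longer spans and fewer promoted tokens; lowering it pushes the hierarchy toward one border per position.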
Use --model-type baseline for a non-hierarchical byte model built from residual 2-layer GRU blocks. It uses only the final next-byte loss, but matches the adaptive model's trainable parameter count exactly.
python -m venv .venv
source .venv/bin/activate
pip install torch datasets
pip install wandb

python -m adaptive_compressor.train \
--model-type adaptive \
--border-mode uncertainty \
--sequence-length 128 \
--batch-size 8 \
--hidden-size 128 \
--num-levels 3 \
--threshold 0.1 \
--byte-entropy-threshold 20.0 \
--meta-uncertainty-threshold 1.0 \
--entropy-floor 0.0 \
--entropy-reg-weight 0.001 \
--eval-every 20 \
--max-steps 200

Or use the helper scripts:
CUDA_VISIBLE_DEVICES=0 scripts/run_adaptive.sh --max-steps 200 --batch-size 16
CUDA_VISIBLE_DEVICES=1 scripts/run_baseline.sh --max-steps 200 --batch-size 16

The scripts accept common hyperparameters through environment variables such as SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, NUM_LEVELS, and also forward any extra CLI flags you append.
python -m adaptive_compressor.train \
--model-type baseline \
--sequence-length 128 \
--batch-size 8 \
--hidden-size 128 \
--num-levels 3 \
--threshold 0.1 \
--eval-every 20 \
--max-steps 200

The script uses Salesforce/wikitext with config wikitext-103-raw-v1 by default, logs to the Weights & Biases project adaptive_compressor, evaluates on the validation split every --eval-every steps using a small subset, and writes a checkpoint to checkpoints/adaptive_compressor.pt.
Pass --disable-wandb to run without online logging.
If --wandb-run-name is not provided, the default run name is ${model_type}_L${sequence_length}_B${batch_tokens // 1000}k.
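As a rough illustration of the default run name pattern, the sketch below assumes that `batch_tokens` is the batch size multiplied by the sequence length. That definition is a guess, not something stated here, and `default_run_name` is a hypothetical helper.

```python
def default_run_name(model_type, sequence_length, batch_size):
    """Build a run name following the ${model_type}_L..._B...k pattern."""
    # Assumption: batch_tokens counts the bytes in one batch.
    batch_tokens = batch_size * sequence_length
    return f"{model_type}_L{sequence_length}_B{batch_tokens // 1000}k"


print(default_run_name("adaptive", sequence_length=128, batch_size=8))
# adaptive_L128_B1k
```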
If you pass --border-mode teacher_forced, training will warn that the adaptive routing uses target-dependent borders and therefore leaks future information.
For comparison to the residual baseline, prefer byte_encoder_bpb rather than byte_decoder_bpb, because the adaptive decoder adds extra depth even when the hierarchy collapses.
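For reference, bits per byte is conventionally the average next-byte cross-entropy converted from nats to bits. The sketch below assumes that `byte_encoder_bpb` and `byte_decoder_bpb` follow this convention and that the loss is summed in nats (the natural-log convention used by PyTorch's cross-entropy). The helper name is hypothetical.

```python
import math


def bits_per_byte(total_nll_nats, num_bytes):
    """Convert a summed negative log-likelihood in nats to bits per byte."""
    return total_nll_nats / (num_bytes * math.log(2))


# A model that is uniform over all 256 byte values pays ln(256) nats per
# byte, which is exactly 8 bits per byte.
uniform_nll = 100 * math.log(256)
print(round(bits_per_byte(uniform_nll, 100), 6))  # 8.0
```

Lower is better, and 8.0 is the score of a model that has learned nothing about the byte distribution.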
The current inference path is intentionally simple and causal: it recomputes the hierarchy from the current prefix at every generation step. This is slower than a cached scheduler, but it is the cleanest way to verify causal behavior.
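The recompute-from-prefix loop, and the prefix-versus-full comparison it makes possible, can be sketched as follows. Everything here is a stand-in: `toy_logits` is a hypothetical causal model that returns one row of 256 logits per position, decoding is greedy rather than using temperature or top-k, and the real inference module may be organized differently.

```python
def toy_logits(seq):
    """Hypothetical causal model: row t depends only on seq[: t + 1]."""
    rows = []
    running = 0
    for b in seq:
        running = (running * 31 + b) % 256
        rows.append([1.0 if v == running else 0.0 for v in range(256)])
    return rows


def generate_greedy(model, prompt, max_new_bytes):
    """Append bytes one at a time, rerunning the model on the whole prefix."""
    seq = list(prompt)
    if not seq:
        raise ValueError("prompt must contain at least one byte")
    for _ in range(max_new_bytes):
        last_row = model(seq)[-1]  # no cache: full recomputation each step
        seq.append(max(range(256), key=last_row.__getitem__))
    return bytes(seq)


def check_causality(model, data, max_positions, tol=1e-6):
    """Return True if prefix-only logits match full-sequence logits."""
    seq = list(data)
    full = model(seq)
    for t in range(min(len(seq), max_positions)):
        prefix_row = model(seq[: t + 1])[-1]
        if any(abs(a - b) > tol for a, b in zip(prefix_row, full[t])):
            return False  # position t saw information from beyond the prefix
    return True


prompt = b"The meaning of compression is"
print(len(generate_greedy(toy_logits, prompt, max_new_bytes=16)))
print(check_causality(toy_logits, prompt, max_positions=64))  # True
```

The check costs one extra forward pass per tested position. A model whose borders depend on future bytes would fail it, because the logits at position t would change as the sequence grows.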
python -m adaptive_compressor.infer checkpoints/adaptive_compressor.pt \
--prompt "The meaning of compression is" \
--max-new-bytes 128 \
--temperature 1.0 \
--top-k 0

To check whether prefix-only logits match full-sequence logits on a prompt prefix:
python -m adaptive_compressor.infer checkpoints/adaptive_compressor.pt \
--prompt "The meaning of compression is" \
--check-causality \
--causality-max-positions 64 \
--max-new-bytes 16

- adaptive_compressor/models/ - split model package with shared modules, adaptive model, and baseline.
- adaptive_compressor/routing.py - border selection and span routing helpers.
- adaptive_compressor/data.py - WikiText byte dataset.
- adaptive_compressor/train.py - small training entry point.
Open docs/model_visualization.html for a colleague-facing explanation of the current architecture, the cumulative border rule, and the exact cached inference loop.