nanochat-jax
JAX / Flax NNX port of Andrej Karpathy's nanochat, optimized for Google Cloud TPU.
nanochat-jax is a TPU-optimized JAX/NNX port of karpathy/nanochat: same d24 baseline model shape, same 2048-token context, same 524K-token batch, same 16,704-step train horizon, and the same CORE eval protocol.
On v6e-8 spot, nanochat-jax reached CORE=0.274 within a $100 TPU budget. The reference path is runs/speedrun.sh: d24 base training followed by CORE eval.
nanochat-jax uses a standard Python venv. For TPU:
python3.11 -m venv ~/venv
source ~/venv/bin/activate
pip install -U pip
pip install -e ".[tpu,dev]" \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
For Mac CPU or plain CPU/GPU without TPU:
pip install -e ".[dev]"
runs/speedrun.sh is the public run script. It uses the d24 TPU train shape used by the current result: seq_len=2048, total batch 524288, Splash Attention, onehot value-embedding gradients, and model-tag checkpoint/eval continuity.
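The train shape quoted above implies a few derived quantities that are handy when sanity-checking a config. The sketch below only does arithmetic on the numbers stated in this README (seq_len, total batch, step count). The per-chip split is an assumption: it supposes the batch is divided evenly across the 8 chips of a v6e-8 with no gradient accumulation, which this README does not state.

```python
# Figures quoted in this README.
SEQ_LEN = 2048                 # tokens per sequence
TOTAL_BATCH_TOKENS = 524_288   # tokens per optimizer step
NUM_STEPS = 16_704             # train horizon in steps

# Assumption (not stated in the README): even data-parallel split over
# the 8 chips of a v6e-8, with no gradient accumulation.
NUM_CHIPS = 8

sequences_per_step = TOTAL_BATCH_TOKENS // SEQ_LEN
sequences_per_chip = sequences_per_step // NUM_CHIPS
total_train_tokens = TOTAL_BATCH_TOKENS * NUM_STEPS

print(sequences_per_step)   # 256
print(sequences_per_chip)   # 32
print(total_train_tokens)   # 8757706752
```

So the full horizon covers roughly 8.76B training tokens, at 256 sequences per step.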
On a v6e-8 TPU host:
bash runs/speedrun.sh
CORE eval is the default. For a quick smoke run:
CORE_MAX_PER_TASK=50 NUM_ITERATIONS=250 bash runs/speedrun.sh
The v6e-8 spot result fits within the $100 TPU budget. Pricing changes; see dev/LEADERBOARD.md for the result row, spot price basis, and hardware comparison.
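Because spot pricing changes, it can help to recompute how many wall-clock hours a fixed budget buys before launching. The helper below is a hypothetical sketch, not part of this repository, and the price in the example is a placeholder rather than a real quote; take the actual figure from dev/LEADERBOARD.md or current Cloud TPU pricing.

```python
def max_run_hours(budget_usd: float, usd_per_chip_hour: float, num_chips: int = 8) -> float:
    """Wall-clock hours a budget covers when all chips are billed together."""
    if usd_per_chip_hour <= 0 or num_chips <= 0:
        raise ValueError("price and chip count must be positive")
    return budget_usd / (usd_per_chip_hour * num_chips)


# PLACEHOLDER price for illustration only; NOT an actual v6e spot quote.
PLACEHOLDER_USD_PER_CHIP_HOUR = 1.00

hours = max_run_hours(100.0, PLACEHOLDER_USD_PER_CHIP_HOUR)
print(f"{hours:.1f} hours on 8 chips")
```

If the run's expected duration exceeds the printed figure at the current price, the run will not fit the budget.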
python -m scripts.chat_cli --model-tag d24_speedrun # interactive
python -m scripts.chat_cli --model-tag d24_speedrun --prompt "Hello" -t 0 # single-prompt, greedy
The speedrun checkpoint is a base model. Instruction following is limited until an SFT checkpoint is trained.
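The `-t 0` example above is labeled greedy. As a rough illustration of what that usually means, here is a minimal pure-Python sketch of temperature sampling in which a temperature of 0 falls back to argmax. It assumes `-t` is a sampling temperature, and `sample_token` is a hypothetical name; it is not the code in nanochat_jax/engine.py.

```python
import math
import random


def sample_token(logits, temperature=1.0, rng=None):
    """Pick a token index from raw logits.

    temperature == 0 is treated as greedy decoding (argmax). This mirrors the
    "-t 0 ... greedy" comment in the README but is an assumption about the CLI.
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0:
        # Greedy: always return the highest-scoring token.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    # Subtract the max before exponentiating for numerical stability.
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


print(sample_token([0.1, 2.0, -1.0], temperature=0))  # 1
```

Greedy decoding is deterministic, which makes it convenient for comparing checkpoints on a fixed prompt.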
.
|-- LICENSE
|-- README.md
|-- pyproject.toml
|-- nanochat_jax/
| |-- gpt.py # GPT model: XLA + Splash Attention paths
| |-- optim.py # AdamW + Muon
| |-- base_train.py # train_step, schedules, sharded factory
| |-- engine.py # KV cache + sampling
| |-- checkpoint_manager.py # save / load + optimizer state
| |-- common.py # cache dir, distributed info, file lock
| |-- core_eval.py # DCLM CORE tasks
| |-- loss_eval.py # bits-per-byte
| |-- dataloader.py # BOS-aligned best-fit packing
| |-- dataset.py # parquet shard download / iteration
| |-- tokenizer.py # RustBPE + tiktoken + HF wrapper
| |-- layers.py # NanochatLinear
| |-- sharding.py # Mesh + Muon ZeRO-2 spec
| `-- tasks/ # ARC, MMLU, GSM8K, HumanEval, SmolTalk, ...
|-- scripts/
| |-- base_train.py # base training
| |-- base_eval.py # BPB, samples, CORE
| |-- chat_sft.py # SFT
| |-- chat_eval.py # ChatCORE
| |-- chat_cli.py # interactive chat
| |-- tok_train.py # tokenizer train
| `-- tok_eval.py # tokenizer eval
|-- runs/
| |-- speedrun.sh # public d24 reference run
| `-- README.md
|-- tests/
`-- dev/
    `-- LEADERBOARD.md
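The tree lists loss_eval.py as the bits-per-byte (BPB) evaluator. For readers unfamiliar with the metric, BPB is the total cross-entropy in nats divided by ln 2 times the number of UTF-8 bytes scored, which makes it comparable across tokenizers. The sketch below illustrates that definition only; the function name and the rule of skipping zero-byte tokens are assumptions, not a description of loss_eval.py.

```python
import math


def bits_per_byte(token_nll_nats, token_byte_lengths):
    """BPB = sum(nll in nats) / (ln 2 * total bytes).

    token_nll_nats: per-token negative log-likelihoods, in nats.
    token_byte_lengths: UTF-8 byte count of each target token.
    """
    total_nats = 0.0
    total_bytes = 0
    for nll, nbytes in zip(token_nll_nats, token_byte_lengths):
        # Assumption: tokens that carry no text bytes (for example special
        # tokens) are excluded from both the numerator and the denominator.
        if nbytes > 0:
            total_nats += nll
            total_bytes += nbytes
    if total_bytes == 0:
        raise ValueError("no bytes to score")
    return total_nats / (math.log(2) * total_bytes)


# Four one-byte tokens, each costing exactly one bit (ln 2 nats).
print(bits_per_byte([math.log(2)] * 4, [1, 1, 1, 1]))
```

Normalizing by bytes rather than tokens means a tokenizer change does not by itself shift the reported number.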
- Use a v6e-8 TPU for the d24 reference run. Capacity and pricing can vary.
- Use --dry-run on scripts/base_train.py for cheap preflight checks.
- Keep the GCS bucket in the same region as the TPU.
- Copy expensive checkpoints to GCS before deleting the TPU VM.
- Delete TPU VMs and sweep candidate zones after each session.
Done:
- d24 v6e-8 TPU/JAX train result with BPB and CORE artifacts.
- Public speedrun path from base train to BPB/CORE eval.
- Pallas Splash Attention path.
- Muon optimizer support.
Not in the default path yet:
- SFT and chat eval publish-path validation.
- Upstream 8xH100 Time-to-GPT-2 leaderboard claim.
- v5p rows; retired until a separate parity audit.
- RL fine-tuning, FP8 training, and web chat UI.
- Andrej Karpathy / nanochat - the upstream PyTorch implementation.
- Keller Jordan / modded-nanogpt - origin of the Muon optimizer.
- google/jax and google/flax - the framework.
- AI-Hypercomputer/maxtext - Splash Attention integration patterns.
MIT. See LICENSE.
This is a personal project. The views, code, and opinions expressed here are my own and do not represent those of my current or past employers.