[vla-fine-tuning] perf: ~5× per-step speedup; zero data spillage #705
Open
irradiantlife wants to merge 2 commits into
## Summary
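Engineering-level performance cleanup of the VLA fine-tuning template.
No model or numerical changes: this removes synchronous host stalls in
the training loop and the producer overruns they caused. The primary
motivation is removing CPU-GPU synchronization with an async helper; the
helper costs one batch of reserved GPU memory.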
Smoke benchmark, Anyscale workspace, 4× L4, MAX_TRAIN_STEPS=100,
freshly-restarted cluster:
| Metric | Before | After | Delta |
|--------------------------------|----------|---------|--------|
| Per-step training body | 3.17 s | 0.60 s | -81 % |
| Dataset producer time | 316.87 s | 60.37 s | -81 % |
| Object-store spillage (peak) | 262 GB | 0 GB | gone |
The per-step figure is computed as `dataset_exec_time / num_steps`.
Under this template's overlapped producer/consumer pipeline, the dataset
producer runs for as long as consumers are pulling, so the quotient
captures the steady-state training-body cost and excludes cluster setup
and checkpoint upload.
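As a quick check, the per-step rows of the table can be recomputed from the dataset producer times. The snippet below is only arithmetic on the figures reported above; the function name is illustrative and not part of the template.

```python
# Figures copied from the benchmark table in this description.
NUM_STEPS = 100  # MAX_TRAIN_STEPS used for the smoke run


def per_step_seconds(dataset_exec_time_s: float, num_steps: int) -> float:
    """Steady-state per-step cost: producer wall time divided by step count."""
    return dataset_exec_time_s / num_steps


before = per_step_seconds(316.87, NUM_STEPS)
after = per_step_seconds(60.37, NUM_STEPS)
speedup = before / after
delta_pct = (after - before) / before * 100

print(f"before={before:.2f} s  after={after:.2f} s")
print(f"speedup={speedup:.2f}x  delta={delta_pct:.0f} %")
```

This reproduces the 3.17 s and 0.60 s per-step figures and the -81 % delta, and gives a ratio slightly above 5×.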
Total wall-clock improvement is workload-dependent: the ~5× per-step
speedup dominates once setup and teardown amortize. On the 100-step
smoke run the wall-clock improvement is ~40 %; longer runs converge
toward the 5× ratio. These numbers were reproduced on a fresh workspace,
because back-to-back runs on the same Ray cluster are contaminated by
per-node spill files that persist between runs.
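To illustrate why longer runs converge toward the per-step ratio, here is a minimal model that treats setup and teardown as a fixed cost. The fixed overhead value is a made-up placeholder, not a measurement from this PR; only the two per-step times come from the table.

```python
BEFORE_STEP_S = 3.17  # per-step time before, from the table
AFTER_STEP_S = 0.60   # per-step time after, from the table

# ASSUMPTION: hypothetical fixed setup/teardown cost, chosen only for illustration.
HYPOTHETICAL_OVERHEAD_S = 300.0


def wall_clock_speedup(num_steps: int, fixed_overhead_s: float) -> float:
    """Ratio of total wall-clock times under a fixed-overhead model."""
    before_total = fixed_overhead_s + num_steps * BEFORE_STEP_S
    after_total = fixed_overhead_s + num_steps * AFTER_STEP_S
    return before_total / after_total


for steps in (100, 1_000, 10_000, 100_000):
    ratio = wall_clock_speedup(steps, HYPOTHETICAL_OVERHEAD_S)
    print(f"{steps:>7} steps -> {ratio:.2f}x")
```

Under this model the ratio rises with the step count and is bounded above by `BEFORE_STEP_S / AFTER_STEP_S`, whatever the overhead is.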
## Why
The template's GPU consumers were never starved, but the consumer-side
plumbing forced repeated host syncs:
1. The collate function issued its H2D copies on the default stream, so
even with `non_blocking=True` each copy serialized against compute on
the device.
2. Per-step `.item()` and `.as_py()` calls forced host syncs and Arrow
scalar boxing inside the hot loops.
3. There was no GPU prefetch: compute and H2D copies contended for the
default stream, so nothing overlapped.
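Items 1 and 3 are addressed by copying the next batch on a side stream. The sketch below shows the general pattern only; it is not the `cuda_prefetcher` from this PR. The name `prefetch_to_device` and the batch layout (a dict of tensors) are assumptions made for illustration, and the function falls back to a plain copy when no GPU is present.

```python
from typing import Dict, Iterable, Iterator

import torch

Batch = Dict[str, torch.Tensor]


def prefetch_to_device(batches: Iterable[Batch], device: torch.device) -> Iterator[Batch]:
    """Yield batches on `device`, copying batch N+1 while batch N is in use."""
    if device.type != "cuda":
        # No CUDA streams on CPU: move each batch and yield it directly.
        for batch in batches:
            yield {k: v.to(device) for k, v in batch.items()}
        return

    copy_stream = torch.cuda.Stream(device=device)  # dedicated H2D stream

    def start_copy(batch: Batch) -> Batch:
        # Enqueue copies on the side stream. With pinned source tensors and
        # non_blocking=True the host does not wait for them to finish.
        with torch.cuda.stream(copy_stream):
            return {k: v.to(device, non_blocking=True) for k, v in batch.items()}

    it = iter(batches)
    first = next(it, None)
    staged = start_copy(first) if first is not None else None
    while staged is not None:
        compute_stream = torch.cuda.current_stream(device)
        # Device-side wait: compute must not read the batch before the copy lands.
        compute_stream.wait_stream(copy_stream)
        ready = staged
        for tensor in ready.values():
            # Tell the caching allocator the tensor is now used on compute_stream.
            tensor.record_stream(compute_stream)
        upcoming = next(it, None)
        staged = start_copy(upcoming) if upcoming is not None else None
        yield ready


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu_batches = [{"x": torch.full((4, 3), float(i))} for i in range(3)]
if device.type == "cuda":
    # Pinned host memory is what lets the H2D copy run asynchronously.
    cpu_batches = [{k: v.pin_memory() for k, v in b.items()} for b in cpu_batches]

for step, batch in enumerate(prefetch_to_device(cpu_batches, device)):
    print(step, batch["x"].device, batch["x"].mean().item())
```

The copy for the next batch is enqueued before the current batch is yielded, so it can run while the consumer executes forward and backward on the current one. The cost is one extra batch resident in GPU memory.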
The 262 GB of object-store spillage in the baseline was a symptom of
the slow consumer giving Ray Data producers a wide window to overrun.
Once the consumer-side stalls are removed, Ray's backpressure system
reaches equilibrium on its own; no producer concurrency cap needed.
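One of the consumer-side stalls is the per-step `.item()` call. A minimal sketch of the alternative is shown below: the running loss stays on the device as a 0-d tensor and is read back only at a logging boundary. The model, data, and `LOG_EVERY` value are placeholders for illustration, not the template's training loop.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 1).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
LOG_EVERY = 10  # hypothetical logging interval

running_loss = torch.zeros((), device=device)  # 0-d accumulator on the device
logged = []

for step in range(1, 31):
    x = torch.randn(16, 8, device=device)
    y = torch.randn(16, 1, device=device)
    loss = torch.nn.functional.mse_loss(model(x), y)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Accumulate without leaving the device: no host sync on this line.
    running_loss += loss.detach()

    if step % LOG_EVERY == 0:
        # The only host sync: one .item() per logging window.
        mean_loss = (running_loss / LOG_EVERY).item()
        logged.append(mean_loss)
        print(f"step {step}: mean loss {mean_loss:.4f}")
        running_loss.zero_()
```

With this structure the host can keep queuing work for the GPU between log lines instead of waiting on every step.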
## Changes
- **`util.py`**
  - `NumpyToTorchCollate`: produce pinned-CPU tensors (no H2D); pair
    with new `cuda_prefetcher` for device-level overlap.
  - `cuda_prefetcher`: 1-batch GPU prefetch on a dedicated CUDA stream
    so batch N+1's H2D copy overlaps with batch N's fwd/bwd.
  - `enable_gpu_perf_flags`: TF32 + cudnn.benchmark, called from the
    worker.
  - `make_trainable_optimizer`: fused-CUDA AdamW + caches trainable
    param list so `clip_grad_norm_` doesn't re-walk every step.
  - `train_step`: return loss as a 0-d device tensor (was `float()`),
    eliminating the per-step host sync.
- **`vla.py`** (mirrored in `README.ipynb`)
  - Fuse `transpose_images` (stack + transpose + astype) into a single
    pre-allocated float32 buffer.
  - Loss accumulator stays on-device as a 0-d tensor; `.item()` only
    at log boundary + end-of-epoch.
- **`lerobot_datasource.py`**
  - Hoist Arrow `.as_py()` boxing out of the per-row read loop --
    convert columns to python lists once per parquet table.
## Test plan
- [x] `pre-commit run --files <changed>` clean
- [x] `python ci/validate_build_yaml.py --no-network` passes
- [x] Anyscale workspace smoke (4× L4) reproduced cleanly.
- [ ] `/test-template vla-fine-tuning` on the PR for the Buildkite
smoke (currently `tests/vla-fine-tuning/tests.sh` is fully
commented out, so this only verifies the workspace+notebook
pipeline path).
To reproduce on a freshly-restarted Anyscale workspace:
```bash
export HF_TOKEN=hf_... # needs "Read access to gated repos"
export MAX_TRAIN_STEPS=100
time uv run papermill README.ipynb /tmp/out.ipynb -k python3 --log-output
```
Look for "Dataset train_<id> execution finished in N.NN seconds" -- divide
by MAX_TRAIN_STEPS for an apples-to-apples per-step number. Object-store
spillage shows up as "Spilled N MiB" log lines (should be absent on the
perf branch).
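If it helps, the two log checks above can be scripted. The sketch below assumes the log lines match the quoted wording exactly; the regular expressions, the `summarize` helper, and the sample log text (including the dataset id) are illustrative and may need adjusting to the real output.

```python
import re

# ASSUMPTION: patterns derived from the log wording quoted in this description.
FINISHED_RE = re.compile(
    r"Dataset (train_\S+) execution finished in (\d+(?:\.\d+)?) seconds"
)
SPILLED_RE = re.compile(r"Spilled (\d+(?:\.\d+)?) MiB")


def summarize(log_text: str, num_steps: int) -> dict:
    """Extract per-step time and the largest reported spill from a run log."""
    finished = FINISHED_RE.search(log_text)
    per_step = float(finished.group(2)) / num_steps if finished else None
    spilled = [float(value) for value in SPILLED_RE.findall(log_text)]
    return {"per_step_s": per_step, "max_spilled_mib": max(spilled, default=0.0)}


# Made-up sample log for demonstration only.
sample_log = """\
Dataset train_example execution finished in 60.37 seconds
"""

print(summarize(sample_log, num_steps=100))
```

A run with no "Spilled" lines reports `max_spilled_mib` as 0.0, which is the expected result on the perf branch.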
Contributor
/test-template vla-fine-tuning
Contributor
Thank you for the contribution. Hi @shorbaji, I believe you wrote the template; can you give it a review?
Contributor
/test-template vla-fine-tuning