
Cold-start training bug on NVIDIA L4: first training process on fresh instance produces degraded output #872

@harrism

Summary

On a freshly launched NVIDIA L4 GPU instance, the very first gsplat training process produces severely degraded output (PSNR ~14, loss ~0.18) compared to all subsequent processes on the same instance (PSNR ~27, loss ~0.04). The bug affects whichever optimizer strategy (DefaultStrategy or MCMCStrategy) runs first; it is not strategy-specific. A short throwaway training run before real work fully mitigates the issue.

Notably, a pure forward rasterization test (no training loop, no backward pass) does not reproduce the bug on L4, suggesting the issue is in the training/backward/optimization path rather than the forward rasterization kernel.

We have not been able to reproduce this on an NVIDIA RTX 6000 Ada GPU, which is also Ada Lovelace (sm_89). So far the bug has only been observed on AWS g6.xlarge instances (NVIDIA L4).

This may be related to #864 (RTX 4090D rendering quality worse than A800), since both L4 and RTX 4090D are Ada Lovelace architecture (sm_89).

Observed Behavior

Experiment 1: DefaultStrategy first, 5 repetitions on fresh L4

(CI run)

All 10 training runs use simple_trainer.py on the mipnerf360 bonsai scene at data_factor=4, ~2600 steps per run:

  Run        Strategy   PSNR (dB)   SSIM    Final Loss   Num Gaussians
  --------------------------------------------------------------------
  1 (cold)   Default    14.23       0.554   0.181        197,721
  2          Default    27.28       0.894   0.043        239,954
  3          Default    27.21       0.894   0.044        239,825
  4          Default    27.28       0.894   0.043        239,911
  5          Default    27.24       0.894   0.042        239,708
  6          MCMC       25.23       0.880   0.057        --
  7          MCMC       25.14       0.880   0.056        --
  8          MCMC       25.15       0.880   0.054        --
  9          MCMC       25.04       0.879   0.054        --
  10         MCMC       25.12       0.879   0.055        --

Run 1 (the very first process on the fresh GPU) has PSNR 13 dB lower than runs 2-5 and produces ~42K fewer Gaussians, suggesting densification/refinement is also affected.
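For reference, the headline numbers in the previous paragraph can be recomputed from the DefaultStrategy rows of the table. This is plain arithmetic on the reported values, comparing the cold run against the mean of the four warm DefaultStrategy runs (MCMC rows are excluded because they report no Gaussian count).

```python
from statistics import mean

# Values copied from the DefaultStrategy rows (runs 1-5) of the Experiment 1 table.
cold_psnr = 14.23
warm_psnr = [27.28, 27.21, 27.28, 27.24]
cold_gaussians = 197_721
warm_gaussians = [239_954, 239_825, 239_911, 239_708]

warm_psnr_mean = mean(warm_psnr)
psnr_gap_db = warm_psnr_mean - cold_psnr

warm_gaussians_mean = mean(warm_gaussians)
gaussian_deficit = warm_gaussians_mean - cold_gaussians
deficit_pct = 100.0 * gaussian_deficit / warm_gaussians_mean

print(f"warm PSNR mean:            {warm_psnr_mean:.3f} dB")
print(f"cold-run PSNR gap:         {psnr_gap_db:.2f} dB")
print(f"cold-run Gaussian deficit: {gaussian_deficit:,.0f} ({deficit_pct:.1f}%)")
```

The gap works out to about 13 dB and the deficit to roughly 42K Gaussians, i.e. the cold run ends with about 17.6% fewer Gaussians than the warm average.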

Experiment 2: MCMCStrategy first, then DefaultStrategy

(CI run)

To confirm the bug is "first process" not "first strategy," we reversed the order:

  Run        Strategy   Variant        PSNR (dB)   SSIM    Final Loss
  -------------------------------------------------------------------
  1 (cold)   MCMC       warmup         14.25       0.555   0.192
  2          Default    after_warmup   27.23       0.894   0.042

Whichever strategy runs first gets degraded output. The second process works correctly.
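Because the signature is "first process degraded, later processes fine" regardless of strategy, a CI job could flag it automatically. The following is a hypothetical heuristic (the name cold_start_suspected and the 3 dB threshold are assumptions, not part of gsplat or of the CI setup used here): it compares the first run's PSNR with the median PSNR of the later runs. The threshold is picked to exceed the ~2 dB gap between warm DefaultStrategy (~27.2) and warm MCMCStrategy (~25.1) runs, so mixed-strategy sequences like the ones above do not produce false positives.

```python
from statistics import median


def cold_start_suspected(psnrs_in_launch_order, threshold_db=3.0):
    """Hypothetical CI check: True if the first run's PSNR is more than
    threshold_db below the median PSNR of all later runs."""
    if len(psnrs_in_launch_order) < 2:
        raise ValueError("need at least two runs to compare")
    first, rest = psnrs_in_launch_order[0], psnrs_in_launch_order[1:]
    return median(rest) - first > threshold_db


if __name__ == "__main__":
    # PSNRs in process-launch order, copied from Experiments 1 and 2.
    experiment_1 = [14.23, 27.28, 27.21, 27.28, 27.24,
                    25.23, 25.14, 25.15, 25.04, 25.12]
    experiment_2 = [14.25, 27.23]
    print("experiment 1 flagged:", cold_start_suspected(experiment_1))
    print("experiment 2 flagged:", cold_start_suspected(experiment_2))
```

Using the median of the later runs keeps a single outlier among the warm runs from hiding (or faking) a cold-start drop.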

Experiment 3: Synthetic forward rasterization (no training)

(CI run, artifact: gsplat-coldstart-repro)

A deterministic synthetic test (200K Gaussians, seed=42, two separate subprocesses, forward pass only) produces identical output for both processes on the same fresh L4:

GPU: NVIDIA L4 (sm_89)

                Run 1 (cold)    Run 2 (warm)   Match
  -----------------------------------------------------
      mean_r        0.642637        0.642637      OK
      mean_g        0.638566        0.638566      OK
      mean_b        0.649349        0.649349      OK
         std        0.035132        0.035132      OK
    checksum   506082.340411   506082.340411      OK
  alpha_mean        0.999770        0.999770      OK

PASS: Both runs produced identical output.

This indicates that the core forward rasterization kernel behaves identically in cold and warm processes, so the bug most likely manifests somewhere in the training pipeline (backward pass, gradient computation, optimizer, or densification).

Not reproducible on RTX 6000 Ada

We ran the same simple_trainer.py configuration on an NVIDIA RTX 6000 Ada GPU (also Ada Lovelace, sm_89) and did not observe the cold-start degradation. Both the first and subsequent runs produced consistent, correct results. The bug appears to be specific to the L4, possibly related to its driver configuration or memory subsystem on AWS instances.

Environment

  • GPU: NVIDIA L4 (Ada Lovelace, sm_89, 22 GB VRAM)
  • Instance: AWS g6.xlarge (fresh instance per CI run)
  • gsplat: commit b60e917 (post-v1.5.3, "set default adam to fused", Nov 2025)
  • PyTorch: 2.8.0 with CUDA 12.9
  • Python: 3.12
  • OS: Ubuntu 22.04 (in aswf/ci-openvdb:2024-clang17.2 container)

Steps to Reproduce

  1. Launch a fresh NVIDIA L4 instance (AWS g6.xlarge or equivalent)
  2. Install gsplat and dependencies
  3. Run simple_trainer.py twice in sequence:
# Run 1 (cold): will produce degraded output (~14 PSNR)
python simple_trainer.py default \
  --data_dir /path/to/360_v2/bonsai \
  --data_factor 4 \
  --result_dir /tmp/run1 \
  --max_steps 2000 \
  --disable_viewer

# Run 2 (warm): will produce correct output (~27 PSNR)
python simple_trainer.py default \
  --data_dir /path/to/360_v2/bonsai \
  --data_factor 4 \
  --result_dir /tmp/run2 \
  --max_steps 2000 \
  --disable_viewer

The key requirement is that run 1 must be the very first gsplat training process on the GPU since instance launch. If any prior gsplat training has occurred (even a short 500-step warmup), the bug does not manifest.
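Since the trigger is "first gsplat training process since instance launch", a long-lived CI runner that is not recreated per job could run the warmup only once per boot. The sketch below keys a marker file on the Linux kernel's per-boot UUID (/proc/sys/kernel/random/boot_id). The marker directory, the file naming, and the assumption that a reboot (rather than, say, a driver reload) is the right reset boundary are all hypothetical. With a fresh instance per CI run, as in the environment described here, always running the warmup is simpler.

```python
import tempfile
from pathlib import Path

# Linux-only kernel interface; a new UUID is generated on every boot.
BOOT_ID_PATH = Path("/proc/sys/kernel/random/boot_id")


def read_boot_id(path=BOOT_ID_PATH):
    """Return the per-boot UUID, or None if it cannot be read (e.g. non-Linux)."""
    try:
        return path.read_text().strip()
    except OSError:
        return None


def _marker(marker_dir, boot_id):
    # Hypothetical naming scheme: one marker file per boot.
    return Path(marker_dir) / f"gsplat-warmup-{boot_id}"


def needs_warmup(marker_dir, boot_id):
    """True if no warmup has been recorded for this boot.
    An unknown boot ID is treated conservatively as 'needs warmup'."""
    if boot_id is None:
        return True
    return not _marker(marker_dir, boot_id).exists()


def mark_warmed_up(marker_dir, boot_id):
    """Record that the warmup ran during this boot (no-op if boot ID unknown)."""
    if boot_id is None:
        return
    Path(marker_dir).mkdir(parents=True, exist_ok=True)
    _marker(marker_dir, boot_id).touch()


if __name__ == "__main__":
    # Demo in a throwaway directory; in CI, marker_dir would need to persist
    # for the lifetime of the instance (not inside a per-job container).
    with tempfile.TemporaryDirectory() as d:
        boot = read_boot_id()
        print("boot id:", boot)
        print("needs warmup:", needs_warmup(d, boot))
        mark_warmed_up(d, boot)
        print("needs warmup after marking:", needs_warmup(d, boot))
```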

Synthetic rasterization test script (does NOT reproduce the bug, but useful for diagnosis)
"""
Minimal forward-only rasterization test. Spawns two subprocesses with
identical deterministic inputs and compares rendered image checksums.
PASSES on L4, confirming the bug is training-path specific.
"""
import json
import subprocess
import sys
import textwrap

SYNTHETIC_WORKER = textwrap.dedent(r'''
import torch
import json

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

device = "cuda"
N = 200000
H, W = 512, 512

means = torch.randn(N, 3, device=device) * 2.0
quats = torch.randn(N, 4, device=device)
quats = quats / quats.norm(dim=-1, keepdim=True)
scales = torch.rand(N, 3, device=device) * 0.05
opacities = torch.sigmoid(torch.randn(N, device=device))
sh0 = torch.rand(N, 1, 3, device=device)

viewmat = torch.eye(4, device=device).unsqueeze(0)
viewmat[0, 2, 3] = 5.0

K = torch.zeros(1, 3, 3, device=device)
K[0, 0, 0] = 500.0
K[0, 1, 1] = 500.0
K[0, 0, 2] = W / 2
K[0, 1, 2] = H / 2
K[0, 2, 2] = 1.0

from gsplat import rasterization

with torch.no_grad():
    renders, alphas, info = rasterization(
        means=means, quats=quats, scales=scales, opacities=opacities,
        colors=sh0, viewmats=viewmat, Ks=K, width=W, height=H,
        sh_degree=0, render_mode="RGB", packed=False,
    )

img = renders[0]
result = {
    "mean_r": float(img[:, :, 0].mean()),
    "mean_g": float(img[:, :, 1].mean()),
    "mean_b": float(img[:, :, 2].mean()),
    "std": float(img.std()),
    "checksum": float(img.double().sum()),
    "alpha_mean": float(alphas[0].mean()),
    "gpu": torch.cuda.get_device_name(0),
    "compute_cap": list(torch.cuda.get_device_capability(0)),
}
print("RESULT:" + json.dumps(result))
''')


def run_subprocess(label, script):
    result = subprocess.run(
        [sys.executable, "-c", script], capture_output=True, text=True,
    )
    if result.returncode != 0:
        print(f"  [{label}] FAILED (exit code {result.returncode})")
        print(f"  {(result.stderr or '')[-2000:]}")
        return None
    for line in result.stdout.strip().split("\n"):
        if line.startswith("RESULT:"):
            return json.loads(line[7:])
    print(f"  [{label}] No RESULT line in output")
    return None


r1 = run_subprocess("run1-cold", SYNTHETIC_WORKER)
r2 = run_subprocess("run2-warm", SYNTHETIC_WORKER)

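if not (r1 and r2):
    # Exit non-zero so CI does not treat a crashed worker as a pass.
    sys.exit("FAIL: one or both subprocesses did not produce a result")

print(f"GPU: {r1['gpu']} (sm_{r1['compute_cap'][0]}{r1['compute_cap'][1]})")
print(f"  {'':>12s}  {'Run 1 (cold)':>14s}  {'Run 2 (warm)':>14s}  Match")
print("  " + "-" * 53)
all_match = True
for key in ["mean_r", "mean_g", "mean_b", "std", "checksum", "alpha_mean"]:
    v1, v2 = r1[key], r2[key]
    # Absolute tolerance; inputs are deterministic, so a real divergence
    # should be far larger than this.
    match = abs(v1 - v2) < 1e-4
    all_match = all_match and match
    print(f"  {key:>12s}  {v1:>14.6f}  {v2:>14.6f}  {'OK' if match else 'DIFF'}")

if all_match:
    print("PASS: Both runs produced identical output.")
else:
    sys.exit("FAIL: cold process produced different output")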

Workaround

A short throwaway training run (~500 steps) before real work fully mitigates the issue:

python simple_trainer.py default \
  --data_dir /path/to/any/scene \
  --data_factor 4 \
  --result_dir /tmp/warmup \
  --max_steps 500 \
  --batch_size 1 \
  --disable_viewer --disable_video
rm -rf /tmp/warmup
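The same workaround can be driven from Python, for example from a CI harness that launches the real training afterwards. This sketch passes exactly the flags shown above; the helper names warmup_command and run_warmup are hypothetical, and trainer is assumed to be the path to simple_trainer.py.

```python
import shutil
import subprocess
import sys


def warmup_command(trainer, data_dir, result_dir, max_steps=500):
    """Build the throwaway warmup invocation (same flags as the shell workaround)."""
    return [
        sys.executable, str(trainer), "default",
        "--data_dir", str(data_dir),
        "--data_factor", "4",
        "--result_dir", str(result_dir),
        "--max_steps", str(max_steps),
        "--batch_size", "1",
        "--disable_viewer", "--disable_video",
    ]


def run_warmup(trainer, data_dir, result_dir="/tmp/warmup", max_steps=500):
    """Run the warmup and always delete its outputs afterwards."""
    cmd = warmup_command(trainer, data_dir, result_dir, max_steps)
    try:
        # check=True raises CalledProcessError if the warmup fails, so the
        # real training is not silently started on a still-cold GPU.
        subprocess.run(cmd, check=True)
    finally:
        # Equivalent of `rm -rf /tmp/warmup`; ignore_errors mirrors -f.
        shutil.rmtree(result_dir, ignore_errors=True)
```

Raising on a failed warmup is a deliberate choice: per the observations above, a real run that starts without a completed warmup is likely to be degraded.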

Questions

  1. Could this be related to uninitialized GPU/CUDA state on the first training process (e.g., cuBLAS workspace, cuDNN autotuning, or TF32 precision settings)?
  2. Is there anything in the backward pass or rasterize_to_pixels backward kernel that might behave differently on the first invocation?
  3. The reduced Gaussian count (197K vs 240K) on the cold run suggests densification criteria are affected; could gradient accumulation or the densification thresholds be impacted by the same root cause?
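To make question 3 concrete: DefaultStrategy grows Gaussians whose accumulated 2D positional gradient, roughly averaged over the steps in which each Gaussian was visible, exceeds a threshold grow_grad2d, which we believe defaults to 0.0002 (worth checking against the installed version). The toy below is not gsplat code. It uses synthetic lognormal gradient norms to show that if a cold-start effect scaled gradients down even modestly (the 20% reduction is an arbitrary assumption), fewer Gaussians would clear the threshold, which would be consistent with the lower Gaussian count on the cold run.

```python
import random

# Assumed default of gsplat's DefaultStrategy.grow_grad2d; verify locally.
GROW_GRAD2D = 0.0002


def densify_candidates(grad_norms, threshold=GROW_GRAD2D):
    """Count Gaussians whose averaged 2D gradient norm exceeds the growth threshold."""
    return sum(1 for g in grad_norms if g > threshold)


rng = random.Random(0)  # fixed seed so the toy is reproducible
# Synthetic averaged gradient norms with median exp(-8.5) ~= 2.0e-4,
# i.e. near the threshold, where small scale changes matter most.
grads = [rng.lognormvariate(-8.5, 1.0) for _ in range(100_000)]

# Hypothetical cold-start effect: every gradient 20% smaller.
scaled = [0.8 * g for g in grads]

baseline = densify_candidates(grads)
degraded = densify_candidates(scaled)
print(f"densify candidates, normal gradients: {baseline}")
print(f"densify candidates, gradients x0.8:   {degraded}")
```

Under these assumptions roughly half of the synthetic Gaussians clear the threshold normally versus about 42% after the reduction. Repeated over many refinement steps, a gap like this could compound into a substantially smaller final count.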
