Skip to content

poc: reject bad batches with os._exit(1)#35

Open
IgnatovFedor wants to merge 1 commit into
tg/scratchpad_for_modefrom
fi/nan-poc-rejection
Open

poc: reject bad batches with os._exit(1)#35
IgnatovFedor wants to merge 1 commit into
tg/scratchpad_for_modefrom
fi/nan-poc-rejection

Conversation

@IgnatovFedor

@IgnatovFedor IgnatovFedor commented Apr 17, 2026

Copy link
Copy Markdown

Purpose

Replace per-sequence NaN filtering with whole-batch rejection — sequences share one forward pass so partial corruption means the entire batch is suspect.
Add Inf detection (isfinite instead of isnan) to match Go's ValidateFP16Vector which rejects exponent==0x1f.

Testing

Checked following code:

import numpy as np
import torch


def test_float32_large_passes_check1_fails_check2():
    """float32 value > float16 max (65504) is finite in float32 but overflows to Inf in float16.
    check1 (torch.isfinite on float32) misses it; check2 (np.isfinite on float16) catches it."""
    x = torch.tensor([[70000.0, 1.0]])  # 70000 > 65504
    assert torch.isfinite(x).all()
    assert not np.isfinite(x.half().cpu().numpy()).all()


def test_float16_nan_0x7fff_caught_by_numpy():
    """0x7fff is the float16 bit pattern for NaN (exp=0x1f, mantissa≠0).
    np.isfinite on a float16 array correctly returns False for it."""
    f16_nan = np.frombuffer(bytes([0xff, 0x7f]), dtype=np.float16)
    assert not np.isfinite(f16_nan).all()


def test_float16_inf_0x7c00_caught_by_numpy():
    """+Inf in float16 (exp=0x1f, mantissa=0)."""
    f16_inf = np.frombuffer(bytes([0x00, 0x7c]), dtype=np.float16)
    assert not np.isfinite(f16_inf).all()


def test_after_normalization_float16_overflow_impossible():
    """Normalization to unit sphere constrains all components to [-1, 1],
    making float16 overflow structurally impossible (float16 max is 65504)."""
    x = torch.tensor([[70000.0, -500.0, 1.0]])
    x_normed = x / (x.norm(dim=-1, keepdim=True) + 1e-8)
    assert x_normed.abs().max() <= 1.0
    assert np.isfinite(x_normed.half().cpu().numpy()).all()


def test_model_nan_caught_by_check1():
    """NaN from the model forward pass is caught by torch.isfinite before any further processing."""
    x = torch.tensor([[float("nan"), 1.0, 2.0],
                      [1.0, 2.0, 3.0]])
    assert not torch.isfinite(x).all()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant