From 82da575760d936ae492384a87a4934fdbb39bb31 Mon Sep 17 00:00:00 2001 From: aryarathoree Date: Mon, 5 Jan 2026 22:39:10 +0530 Subject: [PATCH 1/2] Fail fast on invalid model input shapes with centralized validation --- graph_weather/utils/input_validation.py | 18 ++++++++++++++++ tests/test_input_validation.py | 28 +++++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 graph_weather/utils/input_validation.py create mode 100644 tests/test_input_validation.py diff --git a/graph_weather/utils/input_validation.py b/graph_weather/utils/input_validation.py new file mode 100644 index 00000000..2b25a735 --- /dev/null +++ b/graph_weather/utils/input_validation.py @@ -0,0 +1,18 @@ +import torch + + +def validate_model_input(x): + if not isinstance(x, torch.Tensor): + raise TypeError( + f"Expected input to be torch.Tensor, got {type(x)}" + ) + + if x.ndim != 3: + raise ValueError( + f"Expected input shape [batch, nodes, features], got {x.shape}" + ) + + if x.size(1) <= 0 or x.size(2) <= 0: + raise ValueError( + f"Input tensor must have non-zero nodes and features, got {x.shape}" + ) diff --git a/tests/test_input_validation.py b/tests/test_input_validation.py new file mode 100644 index 00000000..e1746d35 --- /dev/null +++ b/tests/test_input_validation.py @@ -0,0 +1,28 @@ +import sys +from pathlib import Path + +import pytest +import torch + +PROJECT_ROOT = Path(__file__).resolve().parents[1] + +UTILS_PATH = PROJECT_ROOT / "graph_weather" / "utils" +sys.path.insert(0, str(UTILS_PATH)) + +from input_validation import validate_model_input + + +def test_rejects_invalid_shape(): + x = torch.randn(100, 64) + with pytest.raises(ValueError, match="Expected input shape"): + validate_model_input(x) + + +def test_rejects_non_tensor(): + with pytest.raises(TypeError): + validate_model_input([1, 2, 3]) + + +def test_accepts_valid_input(): + x = torch.randn(2, 10, 5) + validate_model_input(x) From b180a3b10b2ffa5fb49d5a3ee557988ee64cc9df Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Jan 2026 17:17:59 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/utils/input_validation.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/graph_weather/utils/input_validation.py b/graph_weather/utils/input_validation.py index 2b25a735..65b7dcce 100644 --- a/graph_weather/utils/input_validation.py +++ b/graph_weather/utils/input_validation.py @@ -3,16 +3,10 @@ def validate_model_input(x): if not isinstance(x, torch.Tensor): - raise TypeError( - f"Expected input to be torch.Tensor, got {type(x)}" - ) + raise TypeError(f"Expected input to be torch.Tensor, got {type(x)}") if x.ndim != 3: - raise ValueError( - f"Expected input shape [batch, nodes, features], got {x.shape}" - ) + raise ValueError(f"Expected input shape [batch, nodes, features], got {x.shape}") if x.size(1) <= 0 or x.size(2) <= 0: - raise ValueError( - f"Input tensor must have non-zero nodes and features, got {x.shape}" - ) + raise ValueError(f"Input tensor must have non-zero nodes and features, got {x.shape}")