diff --git a/graph_weather/utils/input_validation.py b/graph_weather/utils/input_validation.py new file mode 100644 index 00000000..65b7dcce --- /dev/null +++ b/graph_weather/utils/input_validation.py @@ -0,0 +1,12 @@ +import torch + + +def validate_model_input(x): + if not isinstance(x, torch.Tensor): + raise TypeError(f"Expected input to be torch.Tensor, got {type(x)}") + + if x.ndim != 3: + raise ValueError(f"Expected input shape [batch, nodes, features], got {x.shape}") + + if x.size(1) <= 0 or x.size(2) <= 0: + raise ValueError(f"Input tensor must have non-zero nodes and features, got {x.shape}") diff --git a/tests/test_input_validation.py b/tests/test_input_validation.py new file mode 100644 index 00000000..e1746d35 --- /dev/null +++ b/tests/test_input_validation.py @@ -0,0 +1,28 @@ +import sys +from pathlib import Path + +import pytest +import torch + +PROJECT_ROOT = Path(__file__).resolve().parents[1] + +UTILS_PATH = PROJECT_ROOT / "graph_weather" / "utils" +sys.path.insert(0, str(UTILS_PATH)) + +from input_validation import validate_model_input + + +def test_rejects_invalid_shape(): + x = torch.randn(100, 64) + with pytest.raises(ValueError, match="Expected input shape"): + validate_model_input(x) + + +def test_rejects_non_tensor(): + with pytest.raises(TypeError): + validate_model_input([1, 2, 3]) + + +def test_accepts_valid_input(): + x = torch.randn(2, 10, 5) + validate_model_input(x)