Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions graph_weather/models/aurora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,19 @@


class PointEncoder(nn.Module):
"""Encodes 2D point coordinates and features into a latent representation.

The encoder processes the input points and features separately and then combines them.
"""

def __init__(self, input_features: int, embed_dim: int, max_seq_len: int = 1024):
"""Initialize the PointEncoder.

Args:
input_features: Number of input features per point.
embed_dim: Dimension of the latent embedding.
max_seq_len: Maximum sequence length (number of points).
"""
super().__init__()
self.input_dim = input_features + 2 # Account for lat/lon coordinates
self.max_seq_len = max_seq_len
Expand All @@ -36,6 +48,15 @@ def __init__(self, input_features: int, embed_dim: int, max_seq_len: int = 1024)
self.norm = nn.LayerNorm(embed_dim)

def forward(self, points: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
"""Forward pass of the PointEncoder.

Args:
points: Tensor of point coordinates (lat/lon).
features: Tensor of point features.

Returns:
Encoded latent representation.
"""
num_points = points.shape[1]
if num_points > self.max_seq_len:
points = points[:, : self.max_seq_len, :]
Expand Down Expand Up @@ -64,13 +85,20 @@ class PointDecoder(nn.Module):
"""Decodes latent representations back to point features."""

def __init__(self, embed_dim: int, output_features: int):
"""Initialize the PointDecoder.

Args:
embed_dim: Dimension of the latent embedding.
output_features: Number of output features per point.
"""
super().__init__()
self.decoder = nn.Sequential(
nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, output_features)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
"""Forward pass of the PointDecoder.

Args:
x: (batch_size, num_points, embed_dim) tensor
Returns:
Expand All @@ -83,11 +111,18 @@ class PointCloudProcessor(nn.Module):
"""Processes point cloud data using self-attention layers."""

def __init__(self, embed_dim: int, num_layers: int = 4):
"""Initialize the PointCloudProcessor.

Args:
embed_dim: Dimension of the latent embedding.
num_layers: Number of self-attention layers.
"""
super().__init__()
self.layers = nn.ModuleList([SelfAttentionLayer(embed_dim) for _ in range(num_layers)])

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
"""Forward pass of the PointCloudProcessor.

Args:
x: (batch_size, num_points, embed_dim) tensor
Returns:
Expand All @@ -99,7 +134,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class SelfAttentionLayer(nn.Module):
"""Single layer of self-attention with residual connections and layer normalization."""

def __init__(self, embed_dim: int):
"""Initialize the SelfAttentionLayer.

Args:
embed_dim: Dimension of the latent embedding.
"""
super().__init__()
self.attention = nn.MultiheadAttention(embed_dim, num_heads=8)
self.norm1 = nn.LayerNorm(embed_dim)
Expand All @@ -109,6 +151,14 @@ def __init__(self, embed_dim: int):
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the SelfAttentionLayer.

Args:
x: Input tensor.

Returns:
Processed tensor.
"""
# First attention block with residual
x_t = x.transpose(0, 1)
attended, _ = self.attention(x_t, x_t, x_t)
Expand All @@ -121,7 +171,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class EarthSystemLoss(nn.Module):
"""Custom loss function incorporating physical constraints for earth system modeling."""

def __init__(self, alpha: float = 0.5, beta: float = 0.3, gamma: float = 0.2):
"""Initialize the EarthSystemLoss.

Args:
alpha: Weight for MSE loss.
beta: Weight for spatial correlation loss.
gamma: Weight for physical constraint loss.
"""
super().__init__()
self.alpha = alpha
self.beta = beta
Expand All @@ -130,6 +189,7 @@ def __init__(self, alpha: float = 0.5, beta: float = 0.3, gamma: float = 0.2):
def spatial_correlation_loss(
self, pred: torch.Tensor, target: torch.Tensor, points: torch.Tensor
) -> torch.Tensor:
"""Calculate spatial correlation loss."""
batch_size, num_points, _ = points.shape
points_flat = points.view(-1, 2)

Expand All @@ -150,7 +210,15 @@ def spatial_correlation_loss(
return correlation_loss

def physical_loss(self, pred: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
"""Calculate physical consistency loss - ensures predictions follow basic physical laws"""
"""Calculate physical consistency loss - ensures predictions follow basic physical laws.

Args:
pred: Predicted tensor.
points: Points tensor.

Returns:
Computed physical loss.
"""
# Ensure non-negative values for physical quantities (e.g., temperature in Kelvin)
min_value_loss = torch.nn.functional.relu(-pred).mean()

Expand All @@ -169,6 +237,16 @@ def physical_loss(self, pred: torch.Tensor, points: torch.Tensor) -> torch.Tenso
return physical_loss

def forward(self, pred: torch.Tensor, target: torch.Tensor, points: torch.Tensor) -> dict:
"""Calculate the total loss.

Args:
pred: Predicted tensor.
target: Target tensor.
points: Points tensor.

Returns:
Dictionary containing total loss and individual components.
"""
mse_loss = torch.nn.functional.mse_loss(pred, target)
spatial_loss = self.spatial_correlation_loss(pred, target, points)
physical_loss = self.physical_loss(pred, points)
Expand All @@ -185,6 +263,8 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor, points: torch.Tensor


class AuroraModel(nn.Module):
"""Aurora model implementation for unstructured weather data processing."""

def __init__(
self,
input_features: int,
Expand All @@ -195,6 +275,17 @@ def __init__(
max_seq_len: int = 1024,
use_checkpointing: bool = False,
):
"""Initialize the AuroraModel.

Args:
input_features: Number of input features per point.
output_features: Number of output features per point.
latent_dim: Dimension of the latent space.
num_layers: Number of processor layers.
max_points: Maximum number of points allowed.
max_seq_len: Maximum sequence length.
use_checkpointing: Whether to use gradient checkpointing.
"""
super().__init__()

self.max_points = max_points
Expand Down
54 changes: 54 additions & 0 deletions tests/test_input_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import torch
from graph_weather.models import GraphWeatherForecaster, GraphWeatherAssimilator
from graph_weather.models.aurora.model import AuroraModel


def test_forecaster_input_shape_validation():
    """Test that GraphWeatherForecaster validates the shape of its input.

    Two things are checked:

    * a 2-D tensor (no feature axis) is rejected with a ``ValueError`` whose
      message contains "Expected input shape";
    * a 3-D tensor is *not* rejected by that same validation. Any other error
      raised further down the forward pass is outside the scope of this test
      and is deliberately tolerated.
    """
    lat_lons = [(0, 0), (0, 1)]
    model = GraphWeatherForecaster(lat_lons=lat_lons, feature_dim=10)

    # Invalid shape: [batch, nodes] (missing features)
    x = torch.randn(10, 2)
    with pytest.raises(ValueError, match="Expected input shape"):
        model(x)

    # Valid shape: the shape check must let this through.
    x = torch.randn(10, 2, 10)
    try:
        model(x)
    except Exception as err:
        # Only the shape-validation error is a failure here; a blanket
        # ``pass`` would hide it and make this half of the test vacuous.
        if isinstance(err, ValueError) and "Expected input shape" in str(err):
            pytest.fail(f"Shape validation rejected a valid input: {err}")
        # Anything else is unrelated to input validation and is ignored on purpose.


def test_assimilator_input_shape_validation():
    """Check that GraphWeatherAssimilator raises ValueError for a 2-D input tensor."""
    assimilator = GraphWeatherAssimilator(
        output_lat_lons=[(0, 0), (0, 1)],
        analysis_dim=10,
    )

    # Two-dimensional input: the shape the model is expected to reject.
    bad_input = torch.randn(10, 2)
    # Placeholder second argument; only the first argument's shape is under test.
    placeholder_obs = torch.randn(10, 2)

    expect_shape_error = pytest.raises(ValueError, match="Expected input shape")
    with expect_shape_error:
        assimilator(bad_input, placeholder_obs)


def test_aurora_input_shape_validation():
    """Check that AuroraModel raises ValueError when the features tensor is 2-D."""
    model_kwargs = {
        "input_features": 2,
        "output_features": 2,
        "latent_dim": 8,
        "max_points": 50,
        "max_seq_len": 128,
    }
    aurora = AuroraModel(**model_kwargs)

    # Well-formed (1, 10, 2) coordinates tensor.
    coords = torch.randn(1, 10, 2)
    # Features without a trailing feature axis: the invalid input under test.
    flat_features = torch.randn(1, 10)

    expect_shape_error = pytest.raises(ValueError, match="Expected input shape")
    with expect_shape_error:
        aurora(coords, flat_features)