Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions graph_weather/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,29 @@
"""Main import for the complete models"""

from .data.nnja_ai import SensorDataset
from .data.weather_station_reader import WeatherStationReader
from .models.analysis import GraphWeatherAssimilator
from .models.forecast import GraphWeatherForecaster
# Using lazy loading to avoid dependency conflicts
def __getattr__(name):
"""Lazy loading for all modules to avoid dependency conflicts."""
if name == "GraphWeatherAssimilator":
from .models.analysis import GraphWeatherAssimilator as GWA
globals()[name] = GWA
return GWA
elif name == "GraphWeatherForecaster":
from .models.forecast import GraphWeatherForecaster as GWF
globals()[name] = GWF
return GWF
elif name == "SensorDataset":
from .data.nnja_ai import SensorDataset as SD
globals()[name] = SD
return SD
elif name == "WeatherStationReader":
from .data.weather_station_reader import WeatherStationReader as WSR
globals()[name] = WSR
return WSR
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

__all__ = [
"GraphWeatherAssimilator",
"GraphWeatherForecaster",
"SensorDataset",
"WeatherStationReader",
]
62 changes: 49 additions & 13 deletions graph_weather/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,51 @@
"""Models"""

from .fengwu_ghr.layers import (
ImageMetaModel,
LoRAModule,
MetaModel,
WrapperImageModel,
WrapperMetaModel,
)
from .layers.assimilator_decoder import AssimilatorDecoder
from .layers.assimilator_encoder import AssimilatorEncoder
from .layers.decoder import Decoder
from .layers.encoder import Encoder
from .layers.processor import Processor
from .layers.stochastic_decomposition import StochasticDecompositionLayer
# Using lazy loading to avoid dependency conflicts

__all__ = []

def __getattr__(name):
"""Lazy loading for models to avoid dependency conflicts."""
if name in ['ImageMetaModel', 'LoRAModule', 'MetaModel', 'WrapperImageModel', 'WrapperMetaModel']:
from .fengwu_ghr.layers import (
ImageMetaModel as IM,
LoRAModule as LM,
MetaModel as MM,
WrapperImageModel as WIM,
WrapperMetaModel as WMM,
)
result = {
'ImageMetaModel': IM,
'LoRAModule': LM,
'MetaModel': MM,
'WrapperImageModel': WIM,
'WrapperMetaModel': WMM,
}[name]
globals()[name] = result
return result
elif name == 'AssimilatorDecoder':
from .layers.assimilator_decoder import AssimilatorDecoder as AD
globals()[name] = AD
return AD
elif name == 'AssimilatorEncoder':
from .layers.assimilator_encoder import AssimilatorEncoder as AE
globals()[name] = AE
return AE
elif name == 'Decoder':
from .layers.decoder import Decoder as D
globals()[name] = D
return D
elif name == 'Encoder':
from .layers.encoder import Encoder as E
globals()[name] = E
return E
elif name == 'Processor':
from .layers.processor import Processor as P
globals()[name] = P
return P
elif name == 'StochasticDecompositionLayer':
from .layers.stochastic_decomposition import StochasticDecompositionLayer as SDL
globals()[name] = SDL
return SDL
else:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
8 changes: 8 additions & 0 deletions graph_weather/models/data_assimilation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Data assimilation module initialization."""

from .data_assimilation_base import DataAssimilationBase, EnsembleGenerator

__all__ = [
"DataAssimilationBase",
"EnsembleGenerator",
]
110 changes: 110 additions & 0 deletions graph_weather/models/data_assimilation/data_assimilation_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Base classes for data assimilation modules."""

import abc
from typing import Any, Dict, Union

import torch
from torch_geometric.data import Data


class EnsembleGenerator:
"""Class to generate ensemble members from a background state."""

def __init__(self, noise_std: float = 0.1, method: str = "gaussian"):
self.noise_std = noise_std
self.method = method

def generate_ensemble(self, state: Union[torch.Tensor, Data], num_members: int):
if isinstance(state, torch.Tensor):
return self._generate_tensor_ensemble(state, num_members)
elif isinstance(state, Data):
return self._generate_graph_ensemble(state, num_members)
else:
raise TypeError(f"Unsupported state type: {type(state)}")

def _generate_tensor_ensemble(self, state: torch.Tensor, num_members: int) -> torch.Tensor:
batch_size, nodes, features = state.shape
ensemble = torch.zeros(batch_size, num_members, nodes, features, device=state.device)

for i in range(num_members):
if self.method == "gaussian":
noise = torch.randn_like(state) * self.noise_std
ensemble[:, i] = state + noise
elif self.method == "dropout":
mask = torch.bernoulli(torch.ones_like(state) * 0.9) # Keep 90% of values
noise = torch.randn_like(state) * self.noise_std * 0.1
ensemble[:, i] = (state * mask) + noise
elif self.method == "perturbation":
perturbation = (
torch.randn_like(state)
* self.noise_std
* torch.linspace(0.1, 1.0, num_members)[i]
)
ensemble[:, i] = state + perturbation
else:
raise ValueError(f"Unknown ensemble generation method: {self.method}")

return ensemble

def _generate_graph_ensemble(self, state: Data, num_members: int) -> Data:
x_expanded = torch.zeros(
state.x.shape[0], num_members, state.x.shape[1], device=state.x.device
)

for i in range(num_members):
if self.method == "gaussian":
noise = torch.randn_like(state.x) * self.noise_std
x_expanded[:, i] = state.x + noise
elif self.method == "dropout":
mask = torch.bernoulli(torch.ones_like(state.x) * 0.9)
noise = torch.randn_like(state.x) * self.noise_std * 0.1
x_expanded[:, i] = (state.x * mask) + noise
elif self.method == "perturbation":
perturbation = (
torch.randn_like(state.x)
* self.noise_std
* torch.linspace(0.1, 1.0, num_members)[i]
)
x_expanded[:, i] = state.x + perturbation
else:
raise ValueError(f"Unknown ensemble generation method: {self.method}")

new_state = Data(
x=x_expanded,
edge_index=state.edge_index,
edge_attr=getattr(state, "edge_attr", None),
pos=getattr(state, "pos", None),
)

return new_state


class DataAssimilationBase(abc.ABC):
"""Abstract base class for data assimilation modules."""

def __init__(self, config: Dict[str, Any]):
self.config = config
self.ensemble_generator = EnsembleGenerator(
noise_std=config.get("noise_std", 0.1), method=config.get("ensemble_method", "gaussian")
)

@abc.abstractmethod
def initialize_ensemble(self, background_state: Union[torch.Tensor, Data], num_members: int):
pass

@abc.abstractmethod
def assimilate(self, ensemble: Union[torch.Tensor, Data], observations: torch.Tensor):
pass

@abc.abstractmethod
def _compute_analysis(self, ensemble: Union[torch.Tensor, Data]) -> Union[torch.Tensor, Data]:
pass

def forward(
self, state: Union[torch.Tensor, Data], observations: torch.Tensor, num_ensemble: int = 10
):
ensemble = self.initialize_ensemble(state, num_ensemble)
updated_ensemble = self.assimilate(ensemble, observations)
analysis = self._compute_analysis(updated_ensemble)

return updated_ensemble, analysis
89 changes: 89 additions & 0 deletions tests/models/data_assimilation/test_data_assimilation_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
import torch
from torch_geometric.data import Data

from graph_weather.models.data_assimilation import DataAssimilationBase, EnsembleGenerator


class MockDA(DataAssimilationBase):
"""Mock implementation of DataAssimilationBase for testing purposes."""

def initialize_ensemble(self, background_state, num_members):
return self.ensemble_generator.generate_ensemble(background_state, num_members)

def assimilate(self, ensemble, observations):
return ensemble # Return unchanged for testing

def _compute_analysis(self, ensemble):
if isinstance(ensemble, torch.Tensor):
return torch.mean(ensemble, dim=1)
elif isinstance(ensemble, Data):
return ensemble # Return as is for testing
else:
raise TypeError(f"Unsupported ensemble type: {type(ensemble)}")


def test_ensemble_generator_tensor():
"""Test ensemble generation for tensor inputs."""
generator = EnsembleGenerator(noise_std=0.1, method="gaussian")

# Test tensor input
state = torch.randn(2, 5, 3) # [batch, nodes, features]
ensemble = generator.generate_ensemble(state, 4)

assert ensemble.shape == (2, 4, 5, 3) # [batch, members, nodes, features]
assert not torch.equal(state, ensemble[:, 0]) # Should have noise added


def test_ensemble_generator_graph():
"""Test ensemble generation for graph inputs."""
generator = EnsembleGenerator(noise_std=0.1, method="gaussian")

# Test graph input
x = torch.randn(10, 4) # Node features
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
graph_state = Data(x=x, edge_index=edge_index)

ensemble = generator.generate_ensemble(graph_state, 3)

# Check that ensemble preserves structure
assert hasattr(ensemble, "x")
assert hasattr(ensemble, "edge_index")
assert ensemble.x.shape[1] == 3 # Ensemble dimension


def test_data_assimilation_base_abstract_methods():
"""Test that abstract methods are properly defined."""
config = {"param": "value"}
da_module = MockDA(config)

assert da_module.config == config

# Test ensemble generation
state = torch.randn(2, 5, 3)
ensemble = da_module.initialize_ensemble(state, 4)
assert ensemble.shape == (2, 4, 5, 3)


def test_compute_analysis_tensor():
"""Test analysis computation for tensor ensembles."""
da_module = MockDA({})

# Create ensemble: [batch, members, nodes, features]
ensemble = torch.stack(
[
torch.ones(2, 5, 3), # First member
2 * torch.ones(2, 5, 3), # Second member
3 * torch.ones(2, 5, 3), # Third member
],
dim=1,
) # Shape: [2, 3, 5, 3]

analysis = da_module._compute_analysis(ensemble)

# Mean should be (1 + 2 + 3) / 3 = 2
expected = 2 * torch.ones(2, 5, 3)
assert torch.allclose(analysis, expected)