# File: graph_weather/__init__.py
"""Main import for the complete models"""

from importlib import import_module

# Public name -> submodule (relative to this package) that defines it.
# Entries are resolved on first attribute access instead of at package import
# time, so importing the package does not pull in every submodule (and its
# dependencies) up front.
_LAZY_EXPORTS = {
    "GraphWeatherAssimilator": ".models.analysis",
    "GraphWeatherForecaster": ".models.forecast",
    "SensorDataset": ".data.nnja_ai",
    "WeatherStationReader": ".data.weather_station_reader",
}

# Kept as a literal list so static tooling can read it; must mirror the keys
# of _LAZY_EXPORTS.
__all__ = [
    "GraphWeatherAssimilator",
    "GraphWeatherForecaster",
    "SensorDataset",
    "WeatherStationReader",
]


def __getattr__(name):
    """Lazily import and return one of the package's public names (PEP 562).

    Python calls this hook only when normal module attribute lookup fails.
    The resolved object is stored in the module globals, so each name is
    imported at most once and later lookups never reach this function.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        The object called ``name`` from the submodule listed in
        ``_LAZY_EXPORTS``.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    try:
        module_path = _LAZY_EXPORTS[name]
    except KeyError:
        # `from None` hides the internal KeyError; callers only care that the
        # attribute does not exist.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None

    value = getattr(import_module(module_path, __name__), name)
    globals()[name] = value  # cache so the import runs only once
    return value


def __dir__():
    """Return the module's attribute names, including not-yet-loaded lazy ones."""
    return sorted(set(globals()) | set(__all__))
# File: graph_weather/models/__init__.py
"""Models"""

from importlib import import_module

# Public name -> submodule (relative to this package) that defines it.
# Entries are resolved on first attribute access so that importing
# ``graph_weather.models`` (or one of its subpackages) does not eagerly import
# every layer implementation and its dependencies.
_LAZY_EXPORTS = {
    "ImageMetaModel": ".fengwu_ghr.layers",
    "LoRAModule": ".fengwu_ghr.layers",
    "MetaModel": ".fengwu_ghr.layers",
    "WrapperImageModel": ".fengwu_ghr.layers",
    "WrapperMetaModel": ".fengwu_ghr.layers",
    "AssimilatorDecoder": ".layers.assimilator_decoder",
    "AssimilatorEncoder": ".layers.assimilator_encoder",
    "Decoder": ".layers.decoder",
    "Encoder": ".layers.encoder",
    "Processor": ".layers.processor",
    "StochasticDecompositionLayer": ".layers.stochastic_decomposition",
}

# Lists every lazily exported name so `from graph_weather.models import *`
# keeps exporting the same names the eager imports used to provide.
# Must mirror the keys of _LAZY_EXPORTS.
__all__ = [
    "AssimilatorDecoder",
    "AssimilatorEncoder",
    "Decoder",
    "Encoder",
    "ImageMetaModel",
    "LoRAModule",
    "MetaModel",
    "Processor",
    "StochasticDecompositionLayer",
    "WrapperImageModel",
    "WrapperMetaModel",
]


def __getattr__(name):
    """Lazily import and return one of the package's public names (PEP 562).

    Python calls this hook only when normal module attribute lookup fails.
    The resolved object is stored in the module globals, so each name is
    imported at most once and later lookups never reach this function.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        The object called ``name`` from the submodule listed in
        ``_LAZY_EXPORTS``.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    try:
        module_path = _LAZY_EXPORTS[name]
    except KeyError:
        # `from None` hides the internal KeyError; callers only care that the
        # attribute does not exist.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None

    value = getattr(import_module(module_path, __name__), name)
    globals()[name] = value  # cache so the import runs only once
    return value


def __dir__():
    """Return the module's attribute names, including not-yet-loaded lazy ones."""
    return sorted(set(globals()) | set(__all__))


# NOTE(review): this span of the flattened patch also carried
# graph_weather/models/data_assimilation/__init__.py, an import-only module
# that re-exports DataAssimilationBase and EnsembleGenerator from
# .data_assimilation_base. It contains no definitions and is not reproduced here.
# File: graph_weather/models/data_assimilation/data_assimilation_base.py
"""Base classes for data assimilation modules."""

import abc
from typing import Any, Dict, Union

import torch

try:
    from torch_geometric.data import Data
except ImportError:  # tensor-only workflows do not need torch_geometric
    Data = None


class EnsembleGenerator:
    """Generate perturbed ensemble members from a single background state.

    Three perturbation schemes are supported, selected by ``method``:

    * ``"gaussian"``: add ``randn * noise_std`` to the state.
    * ``"dropout"``: zero each value with probability
      ``1 - DROPOUT_KEEP_PROB`` and add noise scaled by
      ``noise_std * DROPOUT_NOISE_SCALE``.
    * ``"perturbation"``: add ``randn * noise_std`` scaled by a per-member
      factor that grows linearly from 0.1 (first member) to 1.0 (last member).
    """

    # Names accepted for ``method``.
    METHODS = ("gaussian", "dropout", "perturbation")
    # Probability of keeping each value under the "dropout" method.
    DROPOUT_KEEP_PROB = 0.9
    # Extra factor applied to ``noise_std`` under the "dropout" method.
    DROPOUT_NOISE_SCALE = 0.1

    def __init__(self, noise_std: float = 0.1, method: str = "gaussian"):
        """Store the noise scale and the perturbation scheme.

        Args:
            noise_std: Standard deviation of the random noise.
            method: One of ``METHODS``. It is validated when an ensemble is
                generated, not here.
        """
        self.noise_std = noise_std
        self.method = method

    def generate_ensemble(self, state: Union[torch.Tensor, Data], num_members: int):
        """Create ``num_members`` perturbed copies of ``state``.

        Args:
            state: A tensor, or a ``torch_geometric`` ``Data`` graph whose
                ``x`` attribute holds the values to perturb.
            num_members: Number of ensemble members to generate.

        Returns:
            For a tensor: a tensor with a new member dimension inserted at
            index 1. For a graph: a new ``Data`` object whose ``x`` has the
            member dimension at index 1.

        Raises:
            TypeError: If ``state`` is neither a tensor nor a ``Data`` object.
            ValueError: If ``self.method`` is not a known method.
        """
        if isinstance(state, torch.Tensor):
            return self._generate_tensor_ensemble(state, num_members)
        if Data is not None and isinstance(state, Data):
            return self._generate_graph_ensemble(state, num_members)
        raise TypeError(f"Unsupported state type: {type(state)}")

    def _generate_tensor_ensemble(self, state: torch.Tensor, num_members: int) -> torch.Tensor:
        """Return the ensemble for a tensor state.

        For a ``[batch, nodes, features]`` input the result is
        ``[batch, num_members, nodes, features]``; any input with at least one
        dimension is accepted and gets the member dimension at index 1.
        """
        return self._perturb(state, num_members)

    def _generate_graph_ensemble(self, state: Data, num_members: int) -> Data:
        """Return a new graph whose ``x`` holds the ensemble of ``state.x``.

        For ``state.x`` of shape ``[nodes, features]`` the new ``x`` has shape
        ``[nodes, num_members, features]``.

        NOTE: only ``edge_index``, ``edge_attr`` and ``pos`` are carried over
        to the returned graph; any other attributes of ``state`` are not copied.
        """
        return Data(
            x=self._perturb(state.x, num_members),
            edge_index=state.edge_index,
            edge_attr=getattr(state, "edge_attr", None),
            pos=getattr(state, "pos", None),
        )

    def _perturb(self, values: torch.Tensor, num_members: int) -> torch.Tensor:
        """Build the perturbed members of ``values`` along a new dimension 1.

        Shared by the tensor and graph paths so both apply identical noise
        schemes.

        Raises:
            ValueError: If ``self.method`` is unknown, or ``values`` is 0-d.
        """
        # Validate before looping so a bad method is reported even when
        # num_members == 0.
        if self.method not in self.METHODS:
            raise ValueError(f"Unknown ensemble generation method: {self.method}")
        if values.dim() < 1:
            raise ValueError("state must have at least one dimension")

        # new_empty keeps the dtype and device of the input, so float64/float16
        # states are not silently converted to float32.
        members = values.new_empty((values.shape[0], num_members, *values.shape[1:]))
        # Per-member scale for the "perturbation" method; loop-invariant.
        scales = torch.linspace(0.1, 1.0, num_members)

        for i in range(num_members):
            if self.method == "gaussian":
                noise = torch.randn_like(values) * self.noise_std
                members[:, i] = values + noise
            elif self.method == "dropout":
                mask = torch.bernoulli(torch.full_like(values, self.DROPOUT_KEEP_PROB))
                noise = torch.randn_like(values) * self.noise_std * self.DROPOUT_NOISE_SCALE
                members[:, i] = (values * mask) + noise
            else:  # "perturbation"
                noise = torch.randn_like(values) * self.noise_std * scales[i]
                members[:, i] = values + noise

        return members


class DataAssimilationBase(abc.ABC):
    """Abstract base class for data assimilation modules.

    Subclasses implement ``initialize_ensemble``, ``assimilate`` and
    ``_compute_analysis``; ``forward`` chains the three steps.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and build the ensemble generator from it.

        Args:
            config: Settings mapping. ``"noise_std"`` (default 0.1) and
                ``"ensemble_method"`` (default ``"gaussian"``) configure the
                ``EnsembleGenerator``; the mapping is kept on ``self.config``.
        """
        self.config = config
        self.ensemble_generator = EnsembleGenerator(
            noise_std=config.get("noise_std", 0.1),
            method=config.get("ensemble_method", "gaussian"),
        )

    @abc.abstractmethod
    def initialize_ensemble(self, background_state: Union[torch.Tensor, Data], num_members: int):
        """Create the initial ensemble from ``background_state``."""

    @abc.abstractmethod
    def assimilate(self, ensemble: Union[torch.Tensor, Data], observations: torch.Tensor):
        """Update ``ensemble`` using ``observations`` and return the result."""

    @abc.abstractmethod
    def _compute_analysis(self, ensemble: Union[torch.Tensor, Data]) -> Union[torch.Tensor, Data]:
        """Reduce the updated ensemble to a single analysis state."""

    def forward(
        self, state: Union[torch.Tensor, Data], observations: torch.Tensor, num_ensemble: int = 10
    ):
        """Run one assimilation cycle.

        Args:
            state: Background state passed to ``initialize_ensemble``.
            observations: Observations passed to ``assimilate``.
            num_ensemble: Number of ensemble members to create.

        Returns:
            A tuple ``(updated_ensemble, analysis)``.
        """
        ensemble = self.initialize_ensemble(state, num_ensemble)
        updated_ensemble = self.assimilate(ensemble, observations)
        analysis = self._compute_analysis(updated_ensemble)

        return updated_ensemble, analysis
# File: tests/models/data_assimilation/test_data_assimilation_base.py
"""Tests for the data assimilation base classes."""

import pytest
import torch
from torch_geometric.data import Data

from graph_weather.models.data_assimilation import DataAssimilationBase, EnsembleGenerator


class MockDA(DataAssimilationBase):
    """Mock implementation of DataAssimilationBase for testing purposes."""

    def initialize_ensemble(self, background_state, num_members):
        """Delegate ensemble creation to the configured generator."""
        return self.ensemble_generator.generate_ensemble(background_state, num_members)

    def assimilate(self, ensemble, observations):
        """Return the ensemble unchanged; no real update is needed in tests."""
        return ensemble

    def _compute_analysis(self, ensemble):
        """Average tensor ensembles over the member dimension; pass graphs through."""
        if isinstance(ensemble, torch.Tensor):
            return torch.mean(ensemble, dim=1)
        elif isinstance(ensemble, Data):
            return ensemble  # Return as is for testing
        else:
            raise TypeError(f"Unsupported ensemble type: {type(ensemble)}")


def test_ensemble_generator_tensor():
    """Test ensemble generation for tensor inputs."""
    generator = EnsembleGenerator(noise_std=0.1, method="gaussian")

    state = torch.randn(2, 5, 3)  # [batch, nodes, features]
    ensemble = generator.generate_ensemble(state, 4)

    assert ensemble.shape == (2, 4, 5, 3)  # [batch, members, nodes, features]
    assert not torch.equal(state, ensemble[:, 0])  # Should have noise added


def test_ensemble_generator_graph():
    """Test ensemble generation for graph inputs."""
    generator = EnsembleGenerator(noise_std=0.1, method="gaussian")

    x = torch.randn(10, 4)  # Node features
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
    graph_state = Data(x=x, edge_index=edge_index)

    ensemble = generator.generate_ensemble(graph_state, 3)

    # The graph structure must survive and x must gain the member dimension.
    assert hasattr(ensemble, "x")
    assert hasattr(ensemble, "edge_index")
    assert ensemble.x.shape == (10, 3, 4)  # [nodes, members, features]
    assert torch.equal(ensemble.edge_index, edge_index)


def test_data_assimilation_base_abstract_methods():
    """Test that the base class is abstract and a full subclass is usable."""
    # The base class declares abstract methods, so it cannot be instantiated.
    with pytest.raises(TypeError):
        DataAssimilationBase({})

    config = {"param": "value"}
    da_module = MockDA(config)

    assert da_module.config == config

    # Test ensemble generation
    state = torch.randn(2, 5, 3)
    ensemble = da_module.initialize_ensemble(state, 4)
    assert ensemble.shape == (2, 4, 5, 3)


def test_compute_analysis_tensor():
    """Test analysis computation for tensor ensembles."""
    da_module = MockDA({})

    # Create ensemble: [batch, members, nodes, features]
    ensemble = torch.stack(
        [
            torch.ones(2, 5, 3),  # First member
            2 * torch.ones(2, 5, 3),  # Second member
            3 * torch.ones(2, 5, 3),  # Third member
        ],
        dim=1,
    )  # Shape: [2, 3, 5, 3]

    analysis = da_module._compute_analysis(ensemble)

    # Mean should be (1 + 2 + 3) / 3 = 2
    expected = 2 * torch.ones(2, 5, 3)
    assert torch.allclose(analysis, expected)


def test_forward_returns_ensemble_and_analysis():
    """Test that forward chains initialization, assimilation and analysis."""
    da_module = MockDA({})

    state = torch.randn(2, 5, 3)
    observations = torch.randn(2, 5, 3)
    updated_ensemble, analysis = da_module.forward(state, observations, num_ensemble=4)

    assert updated_ensemble.shape == (2, 4, 5, 3)
    assert analysis.shape == (2, 5, 3)
    assert torch.allclose(analysis, updated_ensemble.mean(dim=1))