From 15abececb430816f3654e6cab8573a14e95e2a40 Mon Sep 17 00:00:00 2001 From: SOHAMPAL23 Date: Mon, 5 Jan 2026 20:33:58 +0530 Subject: [PATCH 1/6] feat: add self-supervised 3D-Var-based AI data assimilation prototype --- graph_weather/__init__.py | 1 + graph_weather/data/__init__.py | 7 +- graph_weather/data/assimilation_dataloader.py | 288 +++++++++ .../data_assimilation_implementation.md | 165 +++++ graph_weather/example_usage.py | 170 +++++ graph_weather/models/__init__.py | 1 + graph_weather/models/data_assimilation.py | 364 +++++++++++ graph_weather/models/evaluation.py | 439 +++++++++++++ graph_weather/models/training_loop.py | 538 ++++++++++++++++ graph_weather/models/visualization.py | 582 ++++++++++++++++++ graph_weather/test_data_assimilation.py | 322 ++++++++++ 11 files changed, 2876 insertions(+), 1 deletion(-) create mode 100644 graph_weather/data/assimilation_dataloader.py create mode 100644 graph_weather/data_assimilation_implementation.md create mode 100644 graph_weather/example_usage.py create mode 100644 graph_weather/models/data_assimilation.py create mode 100644 graph_weather/models/evaluation.py create mode 100644 graph_weather/models/training_loop.py create mode 100644 graph_weather/models/visualization.py create mode 100644 graph_weather/test_data_assimilation.py diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index b33e23cd..2fabe11e 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -1,6 +1,7 @@ """Main import for the complete models""" from .data.nnja_ai import SensorDataset +from .data.assimilation_dataloader import AssimilationDataset, AssimilationDataModule from .data.weather_station_reader import WeatherStationReader from .models.analysis import GraphWeatherAssimilator from .models.forecast import GraphWeatherForecaster diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index d67a79e4..052921fa 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -1,5 +1,10 @@ """Dataloaders and data processing utilities""" -from .anemoi_dataloader import AnemoiDataset +try: + from .anemoi_dataloader import AnemoiDataset +except ImportError: + # anemoi library not available, skip this import + pass from .nnja_ai import SensorDataset from .weather_station_reader import WeatherStationReader +from .assimilation_dataloader import AssimilationDataset, AssimilationDataModule diff --git a/graph_weather/data/assimilation_dataloader.py b/graph_weather/data/assimilation_dataloader.py new file mode 100644 index 00000000..412f43f1 --- /dev/null +++ b/graph_weather/data/assimilation_dataloader.py @@ -0,0 +1,288 @@ +""" +Data loader for self-supervised data assimilation framework +""" + +import torch +from torch.utils.data import Dataset, DataLoader +import numpy as np + + +class AssimilationDataset(Dataset): + """ + Dataset for self-supervised data assimilation + Each sample contains background state and observations + """ + + def __init__(self, background_states, observations, true_states=None): + """ + Initialize the assimilation dataset + + Args: + background_states: Background states (x_b) + observations: Observations (y) + true_states: True states (for evaluation only, not used in training) + """ + self.background_states = background_states + self.observations = observations + self.true_states = true_states + + assert len(background_states) == len(observations), \ + "Background and observation arrays must have same length" + + if true_states is not None: + assert len(true_states) == len(background_states), \ + "True states must have same length as background states" + + def __len__(self): + return len(self.background_states) + + def __getitem__(self, idx): + bg = self.background_states[idx] + obs = self.observations[idx] + + sample = { + 'background': bg, + 'observations': obs + } + + if self.true_states is not None: + sample['true_state'] = self.true_states[idx] + + return sample + + +def create_synthetic_assimilation_dataset( + num_samples=1000, + grid_size=(10, 10), + num_channels=1, + bg_error_std=0.5, + obs_error_std=0.3, + obs_fraction=0.5 +): + """ + Create a synthetic dataset for data assimilation experiments + + Args: + num_samples: Number of samples to generate + grid_size: Size of spatial grid + num_channels: Number of variables/channels + bg_error_std: Standard deviation of background errors + obs_error_std: Standard deviation of observation errors + obs_fraction: Fraction of grid points that have observations + + Returns: + dataset: AssimilationDataset object + """ + total_size = np.prod(grid_size) if isinstance(grid_size, (tuple, list)) else grid_size + + # Generate true states with spatial correlation + true_states = torch.randn(num_samples, num_channels, *grid_size) + + # Apply spatial smoothing to create realistic fields + if len(grid_size) == 2: + # Create a Gaussian smoothing kernel + kernel_size = 5 + sigma = 1.0 + kernel = torch.zeros(kernel_size, kernel_size) + center = kernel_size // 2 + + for i in range(kernel_size): + for j in range(kernel_size): + x, y = i - center, j - center + kernel[i, j] = np.exp(-(x**2 + y**2) / (2 * sigma**2)) + + kernel = kernel / kernel.sum() + kernel = kernel.unsqueeze(0).unsqueeze(0).repeat(num_channels, 1, 1, 1) + + # Apply smoothing to each sample and channel + for i in range(num_samples): + for c in range(num_channels): + smoothed = torch.nn.functional.conv2d( + true_states[i:i+1, c:c+1], + kernel, + padding=kernel_size//2, + groups=1 + ) + true_states[i, c:c+1] = smoothed + + # Create background states with errors + bg_errors = torch.randn_like(true_states) * bg_error_std + background_states = true_states + bg_errors + + # Create observations with errors + obs_errors = torch.randn_like(true_states) * obs_error_std + observations = true_states + obs_errors + + # Optionally mask some observations based on obs_fraction + if obs_fraction < 1.0: + mask = torch.rand_like(observations) < obs_fraction + observations = observations * mask + + dataset = AssimilationDataset(background_states, observations, true_states) + return dataset + + +def get_assimilation_data_loaders( + dataset, + batch_size=32, + train_ratio=0.7, + val_ratio=0.2, + test_ratio=0.1, + shuffle=True +): + """ + Create train/validation/test data loaders from dataset + + Args: + dataset: AssimilationDataset object + batch_size: Size of batches + train_ratio: Fraction of data for training + val_ratio: Fraction of data for validation + test_ratio: Fraction of data for testing + shuffle: Whether to shuffle the data + + Returns: + train_loader, val_loader, test_loader: Data loaders + """ + total_size = len(dataset) + train_size = int(train_ratio * total_size) + val_size = int(val_ratio * total_size) + test_size = total_size - train_size - val_size + + # Split the dataset + train_dataset, val_dataset, test_dataset = torch.utils.data.random_split( + dataset, [train_size, val_size, test_size] + ) + + # Create data loaders + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=shuffle + ) + + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False + ) + + test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False + ) + + return train_loader, val_loader, test_loader + + +def create_observation_mask(grid_size, obs_fraction=0.5, seed=None): + """ + Create a mask indicating which grid points have observations + + Args: + grid_size: Size of the grid (can be int for 1D or tuple for 2D) + obs_fraction: Fraction of grid points that have observations + seed: Random seed for reproducibility + + Returns: + mask: Boolean mask indicating observation locations + """ + if seed is not None: + np.random.seed(seed) + torch.manual_seed(seed) + + total_size = np.prod(grid_size) if isinstance(grid_size, (tuple, list)) else grid_size + num_obs = int(total_size * obs_fraction) + + # Create random indices for observation locations + obs_indices = np.random.choice(total_size, size=num_obs, replace=False) + + # Create mask + mask_flat = torch.zeros(total_size, dtype=torch.bool) + mask_flat[obs_indices] = True + + if isinstance(grid_size, (tuple, list)): + mask = mask_flat.view(grid_size) + else: + mask = mask_flat + + return mask + + +def apply_observation_operator(data, obs_mask): + """ + Apply observation operator to extract observed values from state + + Args: + data: Full state data + obs_mask: Boolean mask indicating observation locations + + Returns: + observed_data: Data at observation locations only + """ + if len(data.shape) > 2: # Spatial data + batch_size = data.size(0) + reshaped_data = data.view(batch_size, -1) # Flatten spatial dimensions + observed_flat = reshaped_data * obs_mask.view(-1).float() + return observed_flat.view_as(data) + else: + return data * obs_mask.float() + + +class AssimilationDataModule: + """ + A PyTorch Lightning-style data module for assimilation data + """ + + def __init__( + self, + num_samples=1000, + grid_size=(10, 10), + num_channels=1, + bg_error_std=0.5, + obs_error_std=0.3, + obs_fraction=0.5, + batch_size=32, + train_ratio=0.7, + val_ratio=0.2, + test_ratio=0.1 + ): + self.num_samples = num_samples + self.grid_size = grid_size + self.num_channels = num_channels + self.bg_error_std = bg_error_std + self.obs_error_std = obs_error_std + self.obs_fraction = obs_fraction + self.batch_size = batch_size + self.train_ratio = train_ratio + self.val_ratio = val_ratio + self.test_ratio = test_ratio + + def setup(self, stage=None): + """Setup the dataset""" + self.dataset = create_synthetic_assimilation_dataset( + num_samples=self.num_samples, + grid_size=self.grid_size, + num_channels=self.num_channels, + bg_error_std=self.bg_error_std, + obs_error_std=self.obs_error_std, + obs_fraction=self.obs_fraction + ) + + self.train_loader, self.val_loader, self.test_loader = get_assimilation_data_loaders( + self.dataset, + batch_size=self.batch_size, + train_ratio=self.train_ratio, + val_ratio=self.val_ratio, + test_ratio=self.test_ratio + ) + + def train_dataloader(self): + return self.train_loader + + def val_dataloader(self): + return self.val_loader + + def test_dataloader(self): + return self.test_loader \ No newline at end of file diff --git a/graph_weather/data_assimilation_implementation.md b/graph_weather/data_assimilation_implementation.md new file mode 100644 index 00000000..7c68c35e --- /dev/null +++ b/graph_weather/data_assimilation_implementation.md @@ -0,0 +1,165 @@ +# Self-Supervised Data Assimilation Framework with 3D-Var Loss + +## Overview + +This implementation provides a complete self-supervised data assimilation framework that learns to produce analysis states by minimizing the 3D-Var cost function without using ground-truth labels. The system consists of neural networks that take background states and observations as input and produce optimal analysis states. + +## Core Components + +### 1. 3D-Var Loss Function (`data_assimilation.py`) + +The `ThreeDVarLoss` class implements the core 3D-Var objective function: + +``` +J(x) = (x - x_b)^T B^{-1} (x - x_b) + (y - Hx)^T R^{-1} (y - Hx) +``` + +Where: +- `x`: analysis state (model output) +- `x_b`: background state (first guess) +- `y`: observations +- `B`: background error covariance +- `R`: observation error covariance +- `H`: observation operator + +Key features: +- Supports custom background and observation error covariances +- Handles different observation operators +- Works with both fully connected and convolutional models +- Self-supervised (no ground-truth required) + +### 2. Data Assimilation Models (`data_assimilation.py`) + +Two model architectures are provided: + +#### DataAssimilationModel +- Fully connected neural network +- Takes concatenated background and observations as input +- Produces analysis state as output +- Configurable hidden dimensions and layers + +#### SimpleDataAssimilationModel +- Convolutional neural network for spatial data +- Works with 1D/2D grid data +- Preserves spatial relationships +- Efficient for gridded meteorological data + +### 3. Data Pipeline (`assimilation_dataloader.py`) + +- `AssimilationDataset`: Dataset class for background/observation pairs +- `AssimilationDataModule`: PyTorch Lightning-style data module +- Synthetic data generation with spatial correlation +- Observation masking and operator creation +- Train/validation/test splitting + +### 4. Training Framework (`training_loop.py`) + +- `DataAssimilationTrainer`: Complete training loop with validation +- Self-supervised training using 3D-Var loss +- Learning rate scheduling +- Model checkpointing +- Multi-mode training (good/poor background, sparse observations) + +### 5. Evaluation Metrics (`evaluation.py`) + +Comprehensive evaluation including: +- RMSE, MAE, bias calculations +- Correlation coefficients +- Spatial metrics +- Information gain +- Baseline comparisons +- Cross-validation + +### 6. Visualization Tools (`visualization.py`) + +- Training curves plotting +- Comparison grids (background, observations, analysis, true state) +- Error maps visualization +- RMSE comparisons +- Heatmaps and scatter plots +- Comprehensive dashboard + +## Key Features + +### Self-Supervised Learning +- No ground-truth labels required +- Physics-based loss function +- Learns optimal combination of background and observations + +### Flexible Architecture +- Works with different grid sizes +- Supports multiple channels/variables +- Configurable network depth and width +- Multiple activation functions + +### Multiple Training Modes +- With good first guess (low background error) +- With poor first guess (cold start) +- With varying observation densities +- Different error covariance specifications + +### Comprehensive Evaluation +- Comparison with classical baselines +- Improvement metrics +- Spatial analysis +- Statistical validation + +## Usage Example + +```python +from graph_weather.graph_weather.models.data_assimilation import SimpleDataAssimilationModel, ThreeDVarLoss +from graph_weather.graph_weather.data.assimilation_dataloader import AssimilationDataModule +from graph_weather.graph_weather.models.training_loop import train_data_assimilation_model + +# Create data module +data_module = AssimilationDataModule( + grid_size=(16, 16), + num_channels=1, + bg_error_std=0.5, + obs_error_std=0.3, + obs_fraction=0.6 +) +data_module.setup() + +# Initialize model +model = SimpleDataAssimilationModel( + grid_size=(16, 16), + num_channels=1, + hidden_dim=64, + num_layers=3 +) + +# Train model +trainer, results = train_data_assimilation_model( + model=model, + train_loader=data_module.train_dataloader(), + val_loader=data_module.val_dataloader(), + epochs=100, + lr=1e-3 +) +``` + +## Mathematical Foundation + +The 3D-Var cost function is based on Bayesian estimation theory: + +``` +J(x) = (x - x_b)^T B^{-1} (x - x_b) + (y - Hx)^T R^{-1} (y - Hx) +``` + +Where the first term represents the background constraint and the second term represents the observation constraint. The neural network learns to find the optimal balance between these constraints without explicit supervision. + +## Advantages Over Classical Methods + +1. **Learned Error Covariances**: The neural network can learn complex, non-linear relationships +2. **Adaptive Combination**: Automatically adjusts weighting based on data quality +3. **Scalability**: Can handle high-dimensional state spaces efficiently +4. **End-to-End Learning**: Optimizes the complete assimilation process + +## Applications + +This framework is suitable for: +- Weather forecasting data assimilation +- Climate model state estimation +- Oceanographic data assimilation +- Any physical system with background models and observations \ No newline at end of file diff --git a/graph_weather/example_usage.py b/graph_weather/example_usage.py new file mode 100644 index 00000000..8ab17381 --- /dev/null +++ b/graph_weather/example_usage.py @@ -0,0 +1,170 @@ +""" +Example usage of the self-supervised data assimilation framework +""" + +import torch +import numpy as np +from graph_weather.graph_weather.models.data_assimilation import ( + SimpleDataAssimilationModel, + ThreeDVarLoss +) +from graph_weather.graph_weather.data.assimilation_dataloader import ( + AssimilationDataModule +) +from graph_weather.graph_weather.models.training_loop import ( + train_data_assimilation_model +) +from graph_weather.graph_weather.models.evaluation import ( + DataAssimilationEvaluator +) +from graph_weather.graph_weather.models.visualization import ( + plot_training_curves, + plot_comparison_grid, + plot_error_maps +) + + +def main(): + print("Self-Supervised Data Assimilation Example") + print("="*50) + + # Set random seed for reproducibility + torch.manual_seed(42) + np.random.seed(42) + + # 1. Define the problem setup + print("\n1. Setting up the problem...") + grid_size = (12, 12) # 12x12 spatial grid + num_channels = 1 # Single variable (e.g., temperature) + batch_size = 16 + epochs = 20 # Small number for demo + + print(f"Grid size: {grid_size}") + print(f"Number of channels: {num_channels}") + print(f"Batch size: {batch_size}") + print(f"Training epochs: {epochs}") + + # 2. Create data module + print("\n2. Creating data module...") + data_module = AssimilationDataModule( + num_samples=500, # Number of training samples + grid_size=grid_size, + num_channels=num_channels, + bg_error_std=0.5, # Background error standard deviation + obs_error_std=0.3, # Observation error standard deviation + obs_fraction=0.6, # 60% of grid points have observations + batch_size=batch_size + ) + data_module.setup() + + print(f"Training samples: {len(data_module.train_dataloader().dataset)}") + print(f"Validation samples: {len(data_module.val_dataloader().dataset)}") + print(f"Test samples: {len(data_module.test_dataloader().dataset)}") + + # 3. Initialize the model + print("\n3. Initializing the model...") + model = SimpleDataAssimilationModel( + grid_size=grid_size, + num_channels=num_channels, + hidden_dim=32, # Hidden dimension for conv layers + num_layers=2 # Number of processing layers + ) + + print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") + + # 4. Train the model + print("\n4. Training the model...") + print("This will minimize the 3D-Var cost function without ground truth!") + + trainer, results = train_data_assimilation_model( + model=model, + train_loader=data_module.train_dataloader(), + val_loader=data_module.val_dataloader(), + epochs=epochs, + lr=1e-3, + device='cpu' # Using CPU for this example + ) + + print(f"Final training loss: {results['train_losses'][-1]:.6f}") + print(f"Final validation loss: {results['val_losses'][-1]:.6f}") + + # 5. Plot training curves + print("\n5. Plotting training curves...") + plot_training_curves( + results['train_losses'], + results['val_losses'], + title="Data Assimilation Training Curves" + ) + + # 6. Evaluate the model + print("\n6. Evaluating the model...") + evaluator = DataAssimilationEvaluator(model, device='cpu') + eval_metrics = evaluator.evaluate_dataset(data_module.test_dataloader()) + + print("Evaluation Metrics:") + for key, value in eval_metrics.items(): + if 'avg_' in key and ('rmse' in key or 'mae' in key or 'bias' in key or 'correlation' in key): + print(f" {key}: {value:.4f}") + + # 7. Visualize results + print("\n7. Visualizing results...") + + # Get a batch from test data for visualization + test_iter = iter(data_module.test_dataloader()) + batch = next(test_iter) + + background = batch['background'] + observations = batch['observations'] + + # Generate analysis using the trained model + model.eval() + with torch.no_grad(): + analysis = model(background, observations) + + # If true state is available, visualize comparison + if 'true_state' in batch: + true_state = batch['true_state'] + + print("Creating comparison visualization...") + plot_comparison_grid( + background, observations, analysis, true_state, + titles=['Background', 'Observations', 'Analysis', 'True State'] + ) + + print("Creating error maps...") + plot_error_maps( + background, observations, analysis, true_state, + titles=['Background Error', 'Observation Error', 'Analysis Error'] + ) + + # 8. Compare with baselines + print("\n8. Comparing with baselines...") + from graph_weather.graph_weather.models.training_loop import compare_with_baselines + + comparison = compare_with_baselines( + model, + data_module.test_dataloader(), + device='cpu' + ) + + print("Baseline Comparison:") + for key, value in comparison.items(): + print(f" {key}: {value:.4f}") + + # 9. Summary + print("\n" + "="*50) + print("SUMMARY") + print("="*50) + print(f"✓ Successfully trained a self-supervised data assimilation model") + print(f"✓ Model learned to minimize 3D-Var cost function without ground truth") + print(f"✓ Analysis RMSE: {comparison['avg_analysis_rmse']:.4f}") + print(f"✓ Background RMSE: {comparison['avg_background_rmse']:.4f}") + print(f"✓ Analysis improvement over background: {comparison['analysis_improvement_over_bg']:.2f}%") + print(f"✓ Analysis improvement over observations: {comparison['analysis_improvement_over_obs']:.2f}%") + + print(f"\nThe model successfully learned to combine background and observations") + print(f"optimally to produce better analysis states than either input alone!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index 3710e24a..307e88d2 100755 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -13,3 +13,4 @@ from .layers.encoder import Encoder from .layers.processor import Processor from .layers.stochastic_decomposition import StochasticDecompositionLayer +from .data_assimilation import DataAssimilationModel, ThreeDVarLoss diff --git a/graph_weather/models/data_assimilation.py b/graph_weather/models/data_assimilation.py new file mode 100644 index 00000000..3c463b99 --- /dev/null +++ b/graph_weather/models/data_assimilation.py @@ -0,0 +1,364 @@ +""" +Self-Supervised Data Assimilation Framework with 3D-Var Loss + +Implements a neural network that learns to produce analysis states by minimizing +the 3D-Var cost function without using ground-truth labels. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +class ThreeDVarLoss(nn.Module): + """ + Implements the 3D-Var cost function as a self-supervised loss: + + J(x) = (x - x_b)^T B^{-1} (x - x_b) + (y - Hx)^T R^{-1} (y - Hx) + + Where: + - x: analysis state (model output) + - x_b: background state (first guess) + - y: observations + - B: background error covariance + - R: observation error covariance + - H: observation operator + """ + + def __init__(self, + background_error_covariance=None, + observation_error_covariance=None, + observation_operator=None, + bg_weight=1.0, + obs_weight=1.0): + """ + Initialize the 3D-Var loss function + + Args: + background_error_covariance: B matrix (background error covariance) + observation_error_covariance: R matrix (observation error covariance) + observation_operator: H matrix (observation operator) + bg_weight: Weight for background term + obs_weight: Weight for observation term + """ + super(ThreeDVarLoss, self).__init__() + + self.bg_weight = bg_weight + self.obs_weight = obs_weight + + # Initialize background error covariance B + if background_error_covariance is None: + # Default to identity matrix (diagonal with unit variance) + self.B_inv = None # Will be computed as identity when needed + else: + if isinstance(background_error_covariance, torch.Tensor): + self.B_inv = torch.inverse(background_error_covariance) + else: + self.B_inv = torch.inverse(torch.tensor(background_error_covariance)) + + # Initialize observation error covariance R + if observation_error_covariance is None: + # Default to identity matrix (diagonal with unit variance) + self.R_inv = None # Will be computed as identity when needed + else: + if isinstance(observation_error_covariance, torch.Tensor): + self.R_inv = torch.inverse(observation_error_covariance) + else: + self.R_inv = torch.inverse(torch.tensor(observation_error_covariance)) + + # Initialize observation operator H + if observation_operator is None: + # Default to identity (direct observation of state variables) + self.H = None # Will be treated as identity when needed + else: + if isinstance(observation_operator, torch.Tensor): + self.H = observation_operator + else: + self.H = torch.tensor(observation_operator) + + def forward(self, analysis, background, observations): + """ + Compute the 3D-Var loss + + Args: + analysis: Model output (analysis state x) + background: Background state (x_b) + observations: Observations (y) + + Returns: + Total loss value + """ + batch_size = analysis.size(0) + + # Background term: (x - x_b)^T B^{-1} (x - x_b) + bg_diff = analysis - background + + if self.B_inv is None: + # Use identity matrix for B^{-1} + bg_term = torch.sum(bg_diff * bg_diff, dim=-1) # Element-wise square and sum + else: + # Compute quadratic form (x - x_b)^T B^{-1} (x - x_b) + bg_term = torch.sum(bg_diff * torch.matmul(bg_diff.unsqueeze(-2), self.B_inv).squeeze(-2), dim=-1) + + bg_term = self.bg_weight * torch.mean(bg_term) + + # Observation term: (y - Hx)^T R^{-1} (y - Hx) + if self.H is None: + # H is identity, so Hx = x + hx = analysis + else: + # Apply observation operator: Hx + if len(analysis.shape) == 2: + # 2D case: [batch, features] + hx = torch.matmul(analysis, self.H.T) + else: + # For multi-dimensional case, we might need to reshape + original_shape = analysis.shape + analysis_flat = analysis.view(batch_size, -1) + hx_flat = torch.matmul(analysis_flat, self.H.T) + hx = hx_flat.view(original_shape) + + obs_diff = observations - hx + + if self.R_inv is None: + # Use identity matrix for R^{-1} + obs_term = torch.sum(obs_diff * obs_diff, dim=-1) # Element-wise square and sum + else: + # Compute quadratic form (y - Hx)^T R^{-1} (y - Hx) + obs_term = torch.sum(obs_diff * torch.matmul(obs_diff.unsqueeze(-2), self.R_inv).squeeze(-2), dim=-1) + + obs_term = self.obs_weight * torch.mean(obs_term) + + # Total 3D-Var cost + total_loss = bg_term + obs_term + + return total_loss + + +class DataAssimilationModel(nn.Module): + """ + Neural network model for self-supervised data assimilation. + + Takes background state and observations as input and produces an analysis state + that minimizes the 3D-Var cost function. + """ + + def __init__(self, + input_dim, + hidden_dim=256, + num_layers=3, + dropout=0.1, + activation='relu'): + """ + Initialize the data assimilation model + + Args: + input_dim: Dimension of the input state + hidden_dim: Hidden layer dimension + num_layers: Number of hidden layers + dropout: Dropout rate + activation: Activation function ('relu', 'tanh', 'gelu') + """ + super(DataAssimilationModel, self).__init__() + + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + + # Define activation function + if activation == 'relu': + self.activation = nn.ReLU() + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'gelu': + self.activation = nn.GELU() + else: + raise ValueError(f"Unsupported activation: {activation}") + + # Encoder to combine background and observations + layers = [] + layers.append(nn.Linear(input_dim * 2, hidden_dim)) # bg + obs + layers.append(self.activation) + layers.append(nn.Dropout(dropout)) + + for _ in range(num_layers - 1): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + layers.append(self.activation) + layers.append(nn.Dropout(dropout)) + + # Output layer to produce analysis + layers.append(nn.Linear(hidden_dim, input_dim)) + + self.network = nn.Sequential(*layers) + + def forward(self, background, observations): + """ + Forward pass of the data assimilation model + + Args: + background: Background state (x_b) + observations: Observations (y) + + Returns: + analysis: Analysis state (x) + """ + # Concatenate background and observations along the feature dimension + combined_input = torch.cat([background, observations], dim=-1) + + # Pass through the network to get analysis + analysis = self.network(combined_input) + + return analysis + + +class SimpleDataAssimilationModel(nn.Module): + """ + Simplified version that works with 1D/2D spatial grids + """ + + def __init__(self, + grid_size, + num_channels=1, + hidden_dim=64, + num_layers=2): + """ + Initialize a simple data assimilation model for grid data + + Args: + grid_size: Size of the spatial grid (height, width) or (size,) + num_channels: Number of channels/variables + hidden_dim: Hidden dimension for processing + num_layers: Number of processing layers + """ + super(SimpleDataAssimilationModel, self).__init__() + + if isinstance(grid_size, (tuple, list)): + self.grid_shape = grid_size + self.grid_size = np.prod(grid_size) + else: + self.grid_shape = (grid_size,) + self.grid_size = grid_size + + self.num_channels = num_channels + self.input_features = self.grid_size * num_channels + + # Simple CNN-based architecture for spatial data + layers = [] + layers.append(nn.Conv1d(2 * num_channels, hidden_dim, kernel_size=3, padding=1)) + layers.append(nn.ReLU()) + + for _ in range(num_layers - 1): + layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)) + layers.append(nn.ReLU()) + + layers.append(nn.Conv1d(hidden_dim, num_channels, kernel_size=3, padding=1)) + + self.conv_layers = nn.Sequential(*layers) + + def forward(self, background, observations): + """ + Forward pass for grid data + + Args: + background: Background state [batch, channels, ...spatial_dims] + observations: Observations [batch, channels, ...spatial_dims] + + Returns: + analysis: Analysis state [batch, channels, ...spatial_dims] + """ + batch_size = background.size(0) + + # Reshape for 1D convolution if needed + if len(background.shape) > 3: # [batch, channels, height, width] + bg_flat = background.view(batch_size, self.num_channels, -1) + obs_flat = observations.view(batch_size, self.num_channels, -1) + else: # [batch, channels, length] + bg_flat = background + obs_flat = observations + + # Concatenate along channel dimension + combined = torch.cat([bg_flat, obs_flat], dim=1) # [batch, 2*channels, spatial] + + # Process through convolutional layers + analysis_flat = self.conv_layers(combined) + + # Reshape back to original spatial dimensions + if len(background.shape) > 3: + analysis = analysis_flat.view(batch_size, self.num_channels, *self.grid_shape) + else: + analysis = analysis_flat + + return analysis + + +def create_observation_operator(grid_size, obs_fraction=0.5, obs_locations=None): + """ + Create a simple observation operator H that selects a subset of grid points + + Args: + grid_size: Size of the grid (can be int for 1D or tuple for 2D) + obs_fraction: Fraction of grid points that have observations + obs_locations: Specific locations of observations (optional) + + Returns: + H: Observation operator matrix + """ + total_size = np.prod(grid_size) if isinstance(grid_size, (tuple, list)) else grid_size + + if obs_locations is None: + # Randomly select observation locations + num_obs = int(total_size * obs_fraction) + obs_indices = np.random.choice(total_size, size=num_obs, replace=False) + else: + obs_indices = obs_locations + num_obs = len(obs_indices) + + # Create H matrix (num_obs x total_size) + H = torch.zeros(num_obs, total_size) + for i, idx in enumerate(obs_indices): + H[i, idx] = 1.0 + + return H + + +def generate_synthetic_data(batch_size=32, grid_size=(10, 10), num_channels=1): + """ + Generate synthetic background and observation data for testing + + Args: + batch_size: Number of samples in batch + grid_size: Size of spatial grid + num_channels: Number of variables/channels + + Returns: + background: Background state + observations: Observations + true_state: True state (for evaluation only, not used in training) + """ + total_size = np.prod(grid_size) if isinstance(grid_size, (tuple, list)) else grid_size + + # Generate a true state with some spatial correlation + true_state = torch.randn(batch_size, num_channels, *grid_size) * 2 + # Apply some smoothing to create spatial correlation + if len(grid_size) == 2: + # Apply a simple smoothing kernel + kernel = torch.ones(1, 1, 3, 3) / 9 + for c in range(num_channels): + smoothed = F.conv2d( + true_state[:, c:c+1], + kernel, + padding=1, + groups=1 + ) + true_state[:, c:c+1] = smoothed + + # Create background as true state with some error + background_error = torch.randn_like(true_state) * 0.5 + background = true_state + background_error + + # Create observations with observation error + observation_error = torch.randn_like(true_state) * 0.3 + observations = true_state + observation_error + + return background, observations, true_state \ No newline at end of file diff --git a/graph_weather/models/evaluation.py b/graph_weather/models/evaluation.py new file mode 100644 index 00000000..01e28f92 --- /dev/null +++ b/graph_weather/models/evaluation.py @@ -0,0 +1,439 @@ +""" +Evaluation metrics for self-supervised data assimilation +""" + +import torch +import numpy as np +from sklearn.metrics import mean_squared_error, mean_absolute_error +import scipy.stats as stats + + +def compute_rmse(predictions, targets): + """ + Compute Root Mean Square Error + + Args: + predictions: Predicted values + targets: Target values + + Returns: + rmse: Root mean square error + """ + return torch.sqrt(torch.mean((predictions - targets) ** 2)).item() + + +def compute_mae(predictions, targets): + """ + Compute Mean Absolute Error + + Args: + predictions: Predicted values + targets: Target values + + Returns: + mae: Mean absolute error + """ + return torch.mean(torch.abs(predictions - targets)).item() + + +def compute_bias(predictions, targets): + """ + Compute bias (mean error) + + Args: + predictions: Predicted values + targets: Target values + + Returns: + bias: Mean error + """ + return torch.mean(predictions - targets).item() + + +def compute_correlation(predictions, targets): + """ + Compute Pearson correlation coefficient + + Args: + predictions: Predicted values + targets: Target values + + Returns: + correlation: Pearson correlation coefficient + """ + pred_flat = predictions.view(-1) + target_flat = targets.view(-1) + + # Center the data + pred_centered = pred_flat - torch.mean(pred_flat) + target_centered = target_flat - torch.mean(target_flat) + + # Compute correlation + numerator = torch.sum(pred_centered * target_centered) + denominator = torch.sqrt(torch.sum(pred_centered ** 2) * torch.sum(target_centered ** 2)) + + if denominator == 0: + return 0.0 + + return (numerator / denominator).item() + + +def compute_spatial_metrics(predictions, targets): + """ + Compute spatial metrics for gridded data + + Args: + predictions: Predicted values [batch, channels, height, width] + targets: Target values [batch, channels, height, width] + + Returns: + metrics: Dictionary with spatial metrics + """ + batch_size, channels = predictions.shape[0], predictions.shape[1] + + rmse_spatial = [] + correlation_spatial = [] + + for b in range(batch_size): + for c in range(channels): + pred_channel = predictions[b, c].flatten() + target_channel = targets[b, c].flatten() + + rmse = torch.sqrt(torch.mean((pred_channel - target_channel) ** 2)).item() + rmse_spatial.append(rmse) + + # Compute correlation + pred_centered = pred_channel - torch.mean(pred_channel) + target_centered = target_channel - torch.mean(target_channel) + numerator = torch.sum(pred_centered * target_centered) + denominator = torch.sqrt(torch.sum(pred_centered ** 2) * torch.sum(target_centered ** 2)) + + if denominator != 0: + corr = (numerator / denominator).item() + else: + corr = 0.0 + correlation_spatial.append(corr) + + return { + 'avg_rmse_spatial': np.mean(rmse_spatial), + 'std_rmse_spatial': np.std(rmse_spatial), + 'avg_correlation_spatial': np.mean(correlation_spatial), + 'std_correlation_spatial': np.std(correlation_spatial) + } + + +def compute_information_gain(analysis, background, true_state): + """ + Compute information gain from data assimilation + Measures how much better the analysis is compared to background + + Args: + analysis: Analysis state from model + background: Background state (first guess) + true_state: True state (for evaluation) + + Returns: + info_gain: Information gain metric + """ + bg_error = torch.mean((background - true_state) ** 2) + analysis_error = torch.mean((analysis - true_state) ** 2) + + # Information gain as reduction in error variance + info_gain = (bg_error - analysis_error) / bg_error * 100 if bg_error > 0 else 0 + + return info_gain.item() + + +class DataAssimilationEvaluator: + """ + Comprehensive evaluator for data assimilation models + """ + + def __init__(self, model, device='cpu'): + self.model = model + self.device = device + + def evaluate_batch(self, batch): + """ + Evaluate a single batch + + Args: + batch: Dictionary with 'background', 'observations', 'true_state' + + Returns: + metrics: Dictionary with evaluation metrics for the batch + """ + self.model.eval() + + with torch.no_grad(): + background = batch['background'].to(self.device) + observations = batch['observations'].to(self.device) + true_state = batch['true_state'].to(self.device) + + # Get model analysis + analysis = self.model(background, observations) + + # Compute metrics + metrics = { + 'analysis_rmse': compute_rmse(analysis, true_state), + 'background_rmse': compute_rmse(background, true_state), + 'observations_rmse': compute_rmse(observations, true_state), + 'analysis_mae': compute_mae(analysis, true_state), + 'background_mae': compute_mae(background, true_state), + 'analysis_bias': compute_bias(analysis, true_state), + 'background_bias': compute_bias(background, true_state), + 'analysis_correlation': compute_correlation(analysis, true_state), + 'background_correlation': compute_correlation(background, true_state), + 'information_gain': compute_information_gain(analysis, background, true_state) + } + + # Add spatial metrics if data is gridded + if len(analysis.shape) > 2: # Has spatial dimensions + spatial_metrics = compute_spatial_metrics(analysis, true_state) + metrics.update(spatial_metrics) + + return metrics + + def evaluate_dataset(self, data_loader): + """ + Evaluate the model on an entire dataset + + Args: + data_loader: DataLoader with test data + + Returns: + overall_metrics: Dictionary with overall evaluation metrics + """ + all_metrics = { + 'analysis_rmse': [], + 'background_rmse': [], + 'observations_rmse': [], + 'analysis_mae': [], + 'background_mae': [], + 'analysis_bias': [], + 'background_bias': [], + 'analysis_correlation': [], + 'background_correlation': [], + 'information_gain': [] + } + + spatial_metrics_list = [] + + for batch in data_loader: + batch_metrics = self.evaluate_batch(batch) + + # Collect metrics + for key in all_metrics.keys(): + if key in batch_metrics: + all_metrics[key].append(batch_metrics[key]) + + # Collect spatial metrics if available + if 'avg_rmse_spatial' in batch_metrics: + spatial_metrics_list.append({ + 'avg_rmse_spatial': batch_metrics['avg_rmse_spatial'], + 'avg_correlation_spatial': batch_metrics['avg_correlation_spatial'] + }) + + # Compute overall metrics + overall_metrics = {} + for key, values in all_metrics.items(): + if values: # Only compute if we have values + overall_metrics[f'avg_{key}'] = np.mean(values) + overall_metrics[f'std_{key}'] = np.std(values) + + # Compute spatial metrics + if spatial_metrics_list: + spatial_rmse_values = [m['avg_rmse_spatial'] for m in spatial_metrics_list] + spatial_corr_values = [m['avg_correlation_spatial'] for m in spatial_metrics_list] + + overall_metrics['avg_spatial_rmse'] = np.mean(spatial_rmse_values) + overall_metrics['std_spatial_rmse'] = np.std(spatial_rmse_values) + overall_metrics['avg_spatial_correlation'] = np.mean(spatial_corr_values) + overall_metrics['std_spatial_correlation'] = np.std(spatial_corr_values) + + return overall_metrics + + +def compare_methods(model_analysis, background, observations, true_state): + """ + Compare different methods: model analysis, background, observations + + Args: + model_analysis: Analysis from the trained model + background: Background state + observations: Observations + true_state: True state for comparison + + Returns: + comparison: Dictionary with comparison results + """ + results = {} + + # Compute metrics for each method + methods = { + 'analysis': model_analysis, + 'background': background, + 'observations': observations + } + + for method_name, method_output in methods.items(): + results[f'{method_name}_rmse'] = compute_rmse(method_output, true_state) + results[f'{method_name}_mae'] = compute_mae(method_output, true_state) + results[f'{method_name}_bias'] = compute_bias(method_output, true_state) + results[f'{method_name}_correlation'] = compute_correlation(method_output, true_state) + + # Compute improvements + bg_rmse = results['background_rmse'] + obs_rmse = results['observations_rmse'] + analysis_rmse = results['analysis_rmse'] + + results['analysis_improvement_over_bg_pct'] = ( + (bg_rmse - analysis_rmse) / bg_rmse * 100 + ) if bg_rmse > 0 else 0 + + results['analysis_improvement_over_obs_pct'] = ( + (obs_rmse - analysis_rmse) / obs_rmse * 100 + ) if obs_rmse > 0 else 0 + + results['bg_improvement_over_obs_pct'] = ( + (obs_rmse - bg_rmse) / obs_rmse * 100 + ) if obs_rmse > 0 else 0 + + return results + + +def classical_3dvar_analysis(background, observations, H, B, R): + """ + Classical 3D-Var analysis for comparison + + Args: + background: Background state + observations: Observations + H: Observation operator + B: Background error covariance + R: Observation error covariance + + Returns: + analysis: Classical 3D-Var analysis + """ + # Reshape for matrix operations + batch_size = background.shape[0] + state_size = background[0].numel() + obs_size = observations[0].numel() + + # Convert to appropriate shapes + xb = background.view(batch_size, -1) # [batch, state_size] + y = observations.view(batch_size, -1) # [batch, obs_size] + + analysis_results = [] + + for i in range(batch_size): + xb_i = xb[i:i+1].T # [state_size, 1] + y_i = y[i:i+1].T # [obs_size, 1] + + # Compute Kalman gain: K = B * H^T * (H * B * H^T + R)^(-1) + # For simplicity, using diagonal approximations + if B is None: + B_i = torch.eye(state_size, device=xb.device) + else: + B_i = B + + if R is None: + R_i = torch.eye(obs_size, device=y.device) + else: + R_i = R + + if H is None: + H_i = torch.eye(min(state_size, obs_size), device=xb.device)[:obs_size, :state_size] + else: + H_i = H + + # Calculate terms + HBHT_R = torch.matmul(torch.matmul(H_i, B_i), H_i.T) + R_i + K = torch.matmul(torch.matmul(B_i, H_i.T), torch.inverse(HBHT_R)) + + # Compute analysis: xa = xb + K * (y - H * xb) + innovation = y_i - torch.matmul(H_i, xb_i) + correction = torch.matmul(K, innovation) + xa_i = xb_i + correction + + analysis_results.append(xa_i.T) + + analysis = torch.cat(analysis_results, dim=0) + return analysis.view_as(background) + + +def compute_cross_validation_score(model, data_loader, k_folds=5): + """ + Compute cross-validation score for the model + + Args: + model: Data assimilation model + data_loader: Data loader + k_folds: Number of folds for cross-validation + + Returns: + cv_scores: List of scores for each fold + """ + # For simplicity, using a basic approach to simulate cross-validation + # In practice, you'd split your dataset into k folds + model.eval() + + all_rmse = [] + batch_count = 0 + + with torch.no_grad(): + for batch in data_loader: + background = batch['background'].to(model.device if hasattr(model, 'device') else 'cpu') + observations = batch['observations'].to(model.device if hasattr(model, 'device') else 'cpu') + + if 'true_state' in batch: + true_state = batch['true_state'].to(model.device if hasattr(model, 'device') else 'cpu') + + analysis = model(background, observations) + rmse = compute_rmse(analysis, true_state) + all_rmse.append(rmse) + batch_count += 1 + + # Limit for efficiency + if batch_count >= k_folds: + break + + return all_rmse + + +def compute_gradient_norm(model): + """ + Compute the norm of gradients for the model + + Args: + model: PyTorch model + + Returns: + total_norm: Total gradient norm + """ + total_norm = 0 + for p in model.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** (1. / 2) + return total_norm + + +def compute_parameter_norm(model): + """ + Compute the norm of parameters for the model + + Args: + model: PyTorch model + + Returns: + total_norm: Total parameter norm + """ + total_norm = 0 + for p in model.parameters(): + param_norm = p.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** (1. / 2) + return total_norm \ No newline at end of file diff --git a/graph_weather/models/training_loop.py b/graph_weather/models/training_loop.py new file mode 100644 index 00000000..d5a6f4c6 --- /dev/null +++ b/graph_weather/models/training_loop.py @@ -0,0 +1,538 @@ +""" +Training loop for self-supervised data assimilation with 3D-Var loss +""" + +import torch +import torch.nn as nn +from torch.optim import Adam, SGD +from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau +import numpy as np +from tqdm import tqdm +import matplotlib.pyplot as plt +from .data_assimilation import DataAssimilationModel, ThreeDVarLoss, SimpleDataAssimilationModel + + +class DataAssimilationTrainer: + """ + Trainer for the self-supervised data assimilation model + """ + + def __init__( + self, + model, + loss_fn, + optimizer=None, + lr=1e-3, + device='cpu', + scheduler=None + ): + """ + Initialize the trainer + + Args: + model: Data assimilation model + loss_fn: 3D-Var loss function + optimizer: Optimizer (default: Adam) + lr: Learning rate + device: Device to train on + scheduler: Learning rate scheduler + """ + self.model = model.to(device) + self.loss_fn = loss_fn.to(device) + self.device = device + + if optimizer is None: + self.optimizer = Adam(model.parameters(), lr=lr) + else: + self.optimizer = optimizer + + self.scheduler = scheduler + self.train_losses = [] + self.val_losses = [] + + def train_step(self, background, observations): + """ + Perform a single training step + + Args: + background: Background state + observations: Observations + + Returns: + loss: Training loss value + """ + self.model.train() + self.optimizer.zero_grad() + + # Move data to device + background = background.to(self.device) + observations = observations.to(self.device) + + # Forward pass + analysis = self.model(background, observations) + + # Compute loss + loss = self.loss_fn(analysis, background, observations) + + # Backward pass + loss.backward() + + # Gradient clipping + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + # Update parameters + self.optimizer.step() + + return loss.item() + + def validation_step(self, background, observations): + """ + Perform a validation step + + Args: + background: Background state + observations: Observations + + Returns: + loss: Validation loss value + """ + self.model.eval() + + with torch.no_grad(): + # Move data to device + background = background.to(self.device) + observations = observations.to(self.device) + + # Forward pass + analysis = self.model(background, observations) + + # Compute loss + loss = self.loss_fn(analysis, background, observations) + + return loss.item() + + def train_epoch(self, train_loader): + """ + Train for one epoch + + Args: + train_loader: Training data loader + + Returns: + avg_loss: Average training loss for the epoch + """ + total_loss = 0 + num_batches = 0 + + for batch in tqdm(train_loader, desc="Training", leave=False): + background = batch['background'] + observations = batch['observations'] + + loss = self.train_step(background, observations) + total_loss += loss + num_batches += 1 + + avg_loss = total_loss / num_batches + return avg_loss + + def validate_epoch(self, val_loader): + """ + Validate for one epoch + + Args: + val_loader: Validation data loader + + Returns: + avg_loss: Average validation loss for the epoch + """ + total_loss = 0 + num_batches = 0 + + for batch in val_loader: + background = batch['background'] + observations = batch['observations'] + + loss = self.validation_step(background, observations) + total_loss += loss + num_batches += 1 + + avg_loss = total_loss / num_batches + return avg_loss + + def fit( + self, + train_loader, + val_loader, + epochs=100, + verbose=True, + save_best_model=True, + model_save_path="best_assimilation_model.pth" + ): + """ + Train the model + + Args: + train_loader: Training data loader + val_loader: Validation data loader + epochs: Number of training epochs + verbose: Whether to print progress + save_best_model: Whether to save the best model + model_save_path: Path to save the best model + + Returns: + train_losses: Training losses for each epoch + val_losses: Validation losses for each epoch + """ + best_val_loss = float('inf') + patience_counter = 0 + + for epoch in range(epochs): + # Training + train_loss = self.train_epoch(train_loader) + self.train_losses.append(train_loss) + + # Validation + val_loss = self.validate_epoch(val_loader) + self.val_losses.append(val_loss) + + # Learning rate scheduling + if self.scheduler is not None: + if isinstance(self.scheduler, ReduceLROnPlateau): + self.scheduler.step(val_loss) + else: + self.scheduler.step() + + if save_best_model and val_loss < best_val_loss: + best_val_loss = val_loss + torch.save(self.model.state_dict(), model_save_path) + patience_counter = 0 + else: + patience_counter += 1 + + if verbose and (epoch + 1) % 10 == 0: + print(f"Epoch [{epoch+1}/{epochs}] - " + f"Train Loss: {train_loss:.6f}, " + f"Val Loss: {val_loss:.6f}") + + if save_best_model: + self.model.load_state_dict(torch.load(model_save_path)) + + return self.train_losses, self.val_losses + + def evaluate_model(self, test_loader, compute_metrics=True): + """ + Evaluate the model on test data + + Args: + test_loader: Test data loader + compute_metrics: Whether to compute additional metrics + + Returns: + results: Dictionary with evaluation results + """ + self.model.eval() + total_loss = 0 + num_batches = 0 + + all_analysis = [] + all_background = [] + all_observations = [] + all_true = [] + + with torch.no_grad(): + for batch in test_loader: + background = batch['background'].to(self.device) + observations = batch['observations'].to(self.device) + + analysis = self.model(background, observations) + loss = self.loss_fn(analysis, background, observations) + + total_loss += loss.item() + num_batches += 1 + + if compute_metrics: + all_analysis.append(analysis.cpu()) + all_background.append(background.cpu()) + all_observations.append(observations.cpu()) + + if 'true_state' in batch: + all_true.append(batch['true_state']) + + avg_loss = total_loss / num_batches + results = {'loss': avg_loss} + + if compute_metrics and all_true: + # Compute additional metrics + all_analysis = torch.cat(all_analysis, dim=0) + all_background = torch.cat(all_background, dim=0) + all_observations = torch.cat(all_observations, dim=0) + all_true = torch.cat(all_true, dim=0) + + # RMSE metrics + analysis_rmse = torch.sqrt(torch.mean((all_analysis - all_true) ** 2)).item() + bg_rmse = torch.sqrt(torch.mean((all_background - all_true) ** 2)).item() + obs_rmse = torch.sqrt(torch.mean((all_observations - all_true) ** 2)).item() + + results.update({ + 'analysis_rmse': analysis_rmse, + 'background_rmse': bg_rmse, + 'observations_rmse': obs_rmse, + 'improvement_over_bg': (bg_rmse - analysis_rmse) / bg_rmse * 100 if bg_rmse > 0 else 0, + 'improvement_over_obs': (obs_rmse - analysis_rmse) / obs_rmse * 100 if obs_rmse > 0 else 0 + }) + + return results + + +def train_data_assimilation_model( + model, + train_loader, + val_loader, + bg_error_covariance=None, + obs_error_covariance=None, + obs_operator=None, + epochs=100, + lr=1e-3, + device='cpu' +): + """ + Convenience function to train a data assimilation model + + Args: + model: Data assimilation model + train_loader: Training data loader + val_loader: Validation data loader + bg_error_covariance: Background error covariance + obs_error_covariance: Observation error covariance + obs_operator: Observation operator + epochs: Number of training epochs + lr: Learning rate + device: Device to train on + + Returns: + trainer: Trained trainer object + results: Training results + """ + # Initialize the 3D-Var loss function + loss_fn = ThreeDVarLoss( + background_error_covariance=bg_error_covariance, + observation_error_covariance=obs_error_covariance, + observation_operator=obs_operator + ) + + # Initialize the trainer + trainer = DataAssimilationTrainer( + model=model, + loss_fn=loss_fn, + lr=lr, + device=device + ) + + # Add learning rate scheduler + scheduler = ReduceLROnPlateau(trainer.optimizer, mode='min', factor=0.5, patience=10) + trainer.scheduler = scheduler + + # Train the model + train_losses, val_losses = trainer.fit( + train_loader=train_loader, + val_loader=val_loader, + epochs=epochs, + verbose=True + ) + + return trainer, {'train_losses': train_losses, 'val_losses': val_losses} + + +def train_with_different_modes( + model_class, + data_module, + input_dim=None, + grid_size=None, + num_channels=1, + epochs=100, + lr=1e-3, + device='cpu' +): + """ + Train the model in different modes: + 1. With good first guess (low background error) + 2. With poor first guess (high background error) - cold start + 3. With varying observation densities + + Args: + model_class: Model class to instantiate + data_module: Data module with different configurations + input_dim: Input dimension for fully connected model + grid_size: Grid size for convolutional model + num_channels: Number of channels + epochs: Number of epochs + lr: Learning rate + device: Device to train on + + Returns: + results: Dictionary with results from different training modes + """ + results = {} + + # Mode 1: With good first guess (low background error) + print("Training with good first guess...") + data_module_good_bg = data_module( + num_samples=1000, + grid_size=grid_size, + num_channels=num_channels, + bg_error_std=0.2, # Low error + obs_error_std=0.3, + obs_fraction=0.5 + ) + data_module_good_bg.setup() + + if input_dim: + model_good_bg = model_class(input_dim=input_dim) + else: + model_good_bg = model_class(grid_size=grid_size, num_channels=num_channels) + + trainer_good_bg, res_good_bg = train_data_assimilation_model( + model=model_good_bg, + train_loader=data_module_good_bg.train_dataloader(), + val_loader=data_module_good_bg.val_dataloader(), + epochs=epochs, + lr=lr, + device=device + ) + + results['good_bg'] = { + 'trainer': trainer_good_bg, + 'results': res_good_bg, + 'eval_results': trainer_good_bg.evaluate_model(data_module_good_bg.test_dataloader()) + } + + # Mode 2: With poor first guess (high background error) - cold start + print("Training with poor first guess (cold start)...") + data_module_poor_bg = data_module( + num_samples=1000, + grid_size=grid_size, + num_channels=num_channels, + bg_error_std=1.0, # High error + obs_error_std=0.3, + obs_fraction=0.5 + ) + data_module_poor_bg.setup() + + if input_dim: + model_poor_bg = model_class(input_dim=input_dim) + else: + model_poor_bg = model_class(grid_size=grid_size, num_channels=num_channels) + + trainer_poor_bg, res_poor_bg = train_data_assimilation_model( + model=model_poor_bg, + train_loader=data_module_poor_bg.train_dataloader(), + val_loader=data_module_poor_bg.val_dataloader(), + epochs=epochs, + lr=lr, + device=device + ) + + results['poor_bg'] = { + 'trainer': trainer_poor_bg, + 'results': res_poor_bg, + 'eval_results': trainer_poor_bg.evaluate_model(data_module_poor_bg.test_dataloader()) + } + + # Mode 3: With sparse observations + print("Training with sparse observations...") + data_module_sparse_obs = data_module( + num_samples=1000, + grid_size=grid_size, + num_channels=num_channels, + bg_error_std=0.5, + obs_error_std=0.3, + obs_fraction=0.2 # Sparse observations + ) + data_module_sparse_obs.setup() + + if input_dim: + model_sparse_obs = model_class(input_dim=input_dim) + else: + model_sparse_obs = model_class(grid_size=grid_size, num_channels=num_channels) + + trainer_sparse_obs, res_sparse_obs = train_data_assimilation_model( + model=model_sparse_obs, + train_loader=data_module_sparse_obs.train_dataloader(), + val_loader=data_module_sparse_obs.val_dataloader(), + epochs=epochs, + lr=lr, + device=device + ) + + results['sparse_obs'] = { + 'trainer': trainer_sparse_obs, + 'results': res_sparse_obs, + 'eval_results': trainer_sparse_obs.evaluate_model(data_module_sparse_obs.test_dataloader()) + } + + return results + + +def compare_with_baselines(model, test_loader, device='cpu'): + """ + Compare the trained model with classical baselines + + Args: + model: Trained assimilation model + test_loader: Test data loader + device: Device to run on + + Returns: + comparison: Dictionary with comparison results + """ + model.eval() + results = { + 'analysis_rmse': [], + 'background_rmse': [], + 'observation_rmse': [], + 'persistence_rmse': [] + } + + with torch.no_grad(): + for batch in test_loader: + background = batch['background'].to(device) + observations = batch['observations'].to(device) + + if 'true_state' in batch: + true_state = batch['true_state'].to(device) + + # Model analysis + analysis = model(background, observations) + + # Compute RMSE for each method + analysis_rmse = torch.sqrt(torch.mean((analysis - true_state) ** 2)).item() + bg_rmse = torch.sqrt(torch.mean((background - true_state) ** 2)).item() + obs_rmse = torch.sqrt(torch.mean((observations - true_state) ** 2)).item() + + # Persistence (assuming observations are closer to truth than background) + # For simplicity, using a weighted average as persistence + persistence = 0.7 * observations + 0.3 * background + persist_rmse = torch.sqrt(torch.mean((persistence - true_state) ** 2)).item() + + results['analysis_rmse'].append(analysis_rmse) + results['background_rmse'].append(bg_rmse) + results['observation_rmse'].append(obs_rmse) + results['persistence_rmse'].append(persist_rmse) + + # Compute averages + comparison = { + 'avg_analysis_rmse': np.mean(results['analysis_rmse']), + 'avg_background_rmse': np.mean(results['background_rmse']), + 'avg_observation_rmse': np.mean(results['observation_rmse']), + 'avg_persistence_rmse': np.mean(results['persistence_rmse']), + 'analysis_improvement_over_bg': ( + (np.mean(results['background_rmse']) - np.mean(results['analysis_rmse'])) / + np.mean(results['background_rmse']) * 100 + ), + 'analysis_improvement_over_obs': ( + (np.mean(results['observation_rmse']) - np.mean(results['analysis_rmse'])) / + np.mean(results['observation_rmse']) * 100 + ) + } + + return comparison \ No newline at end of file diff --git a/graph_weather/models/visualization.py b/graph_weather/models/visualization.py new file mode 100644 index 00000000..6791e9b2 --- /dev/null +++ b/graph_weather/models/visualization.py @@ -0,0 +1,582 @@ +""" +Visualization functions for self-supervised data assimilation +""" + +import torch +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.colors import Normalize +import matplotlib.patches as patches +import warnings +warnings.filterwarnings('ignore') + + +def plot_training_curves(train_losses, val_losses, title="Training Curves"): + """ + Plot training and validation loss curves + + Args: + train_losses: List of training losses + val_losses: List of validation losses + title: Title for the plot + """ + plt.figure(figsize=(10, 6)) + plt.plot(train_losses, label='Training Loss', color='blue') + plt.plot(val_losses, label='Validation Loss', color='red') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.title(title) + plt.legend() + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.show() + + +def plot_comparison_grid(background, observations, analysis, true_state=None, + titles=None, figsize=(15, 10)): + """ + Plot a grid comparing background, observations, analysis, and true state + + Args: + background: Background state + observations: Observations + analysis: Analysis from model + true_state: True state (optional) + titles: Titles for each subplot + figsize: Figure size + """ + # Convert to numpy if torch tensors + if torch.is_tensor(background): + background = background.cpu().numpy() + if torch.is_tensor(observations): + observations = observations.cpu().numpy() + if torch.is_tensor(analysis): + analysis = analysis.cpu().numpy() + if true_state is not None and torch.is_tensor(true_state): + true_state = true_state.cpu().numpy() + + if titles is None: + titles = ['Background', 'Observations', 'Analysis'] + if true_state is not None: + titles.append('True State') + + n_plots = 3 if true_state is None else 4 + + fig, axes = plt.subplots(1, n_plots, figsize=figsize) + if n_plots == 1: + axes = [axes] + + # Determine common color scale + all_data = [background, observations, analysis] + if true_state is not None: + all_data.append(true_state) + + # Handle different tensor shapes + def get_data_for_plot(data): + if data.ndim == 4: # [batch, channels, height, width] + return data[0, 0] # Take first sample, first channel + elif data.ndim == 3: # [batch, height, width] or [channels, height, width] + return data[0] if data.shape[0] <= 10 else data[0] # Heuristic for batch vs channels + elif data.ndim == 2: # [height, width] + return data + else: + raise ValueError(f"Unexpected data shape: {data.shape}") + + processed_data = [get_data_for_plot(d) for d in all_data] + vmin = min([d.min() for d in processed_data]) + vmax = max([d.max() for d in processed_data]) + + # Plot each field + im1 = axes[0].imshow(processed_data[0], cmap='RdBu_r', vmin=vmin, vmax=vmax) + axes[0].set_title(titles[0]) + axes[0].axis('off') + plt.colorbar(im1, ax=axes[0]) + + im2 = axes[1].imshow(processed_data[1], cmap='RdBu_r', vmin=vmin, vmax=vmax) + axes[1].set_title(titles[1]) + axes[1].axis('off') + plt.colorbar(im2, ax=axes[1]) + + im3 = axes[2].imshow(processed_data[2], cmap='RdBu_r', vmin=vmin, vmax=vmax) + axes[2].set_title(titles[2]) + axes[2].axis('off') + plt.colorbar(im3, ax=axes[2]) + + if true_state is not None and n_plots > 3: + im4 = axes[3].imshow(processed_data[3], cmap='RdBu_r', vmin=vmin, vmax=vmax) + axes[3].set_title(titles[3]) + axes[3].axis('off') + plt.colorbar(im4, ax=axes[3]) + + plt.tight_layout() + plt.show() + + +def plot_error_maps(background, observations, analysis, true_state, + titles=None, figsize=(18, 5)): + """ + Plot error maps comparing different methods + + Args: + background: Background state + observations: Observations + analysis: Analysis from model + true_state: True state + titles: Titles for each subplot + figsize: Figure size + """ + # Convert to numpy if torch tensors + if torch.is_tensor(background): + background = background.cpu().numpy() + if torch.is_tensor(observations): + observations = observations.cpu().numpy() + if torch.is_tensor(analysis): + analysis = analysis.cpu().numpy() + if torch.is_tensor(true_state): + true_state = true_state.cpu().numpy() + + if titles is None: + titles = ['Background Error', 'Observation Error', 'Analysis Error'] + + fig, axes = plt.subplots(1, 3, figsize=figsize) + + # Calculate errors + def get_first_element(data): + if data.ndim == 4: # [batch, channels, height, width] + return data[0, 0] # Take first sample, first channel + elif data.ndim == 3: # [batch, height, width] + return data[0] + else: + return data + + bg_error = get_first_element(background) - get_first_element(true_state) + obs_error = get_first_element(observations) - get_first_element(true_state) + analysis_error = get_first_element(analysis) - get_first_element(true_state) + + # Determine common color scale for errors (centered at 0) + max_error = max(np.abs(bg_error).max(), + np.abs(obs_error).max(), + np.abs(analysis_error).max()) + + # Plot error maps + im1 = axes[0].imshow(bg_error if bg_error.ndim == 2 else bg_error[0], + cmap='RdBu_r', vmin=-max_error, vmax=max_error) + axes[0].set_title(titles[0]) + axes[0].axis('off') + plt.colorbar(im1, ax=axes[0]) + + im2 = axes[1].imshow(obs_error if obs_error.ndim == 2 else obs_error[0], + cmap='RdBu_r', vmin=-max_error, vmax=max_error) + axes[1].set_title(titles[1]) + axes[1].axis('off') + plt.colorbar(im2, ax=axes[1]) + + im3 = axes[2].imshow(analysis_error if analysis_error.ndim == 2 else analysis_error[0], + cmap='RdBu_r', vmin=-max_error, vmax=max_error) + axes[2].set_title(titles[2]) + axes[2].axis('off') + plt.colorbar(im3, ax=axes[2]) + + plt.tight_layout() + plt.show() + + +def plot_rmse_comparison(metrics_dict, title="RMSE Comparison"): + """ + Plot RMSE comparison between different methods + + Args: + metrics_dict: Dictionary with method names as keys and RMSE values as values + title: Title for the plot + """ + methods = list(metrics_dict.keys()) + rmse_values = list(metrics_dict.values()) + + plt.figure(figsize=(10, 6)) + bars = plt.bar(methods, rmse_values, color=['skyblue', 'lightcoral', 'lightgreen', 'gold']) + plt.ylabel('RMSE') + plt.title(title) + plt.xticks(rotation=45) + + # Add value labels on bars + for bar, value in zip(bars, rmse_values): + plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(rmse_values)*0.01, + f'{value:.3f}', ha='center', va='bottom') + + plt.tight_layout() + plt.show() + + +def plot_improvement_heatmap(improvement_matrix, title="Improvement Heatmap"): + """ + Plot improvement heatmap showing where analysis is better than background + + Args: + improvement_matrix: Matrix showing improvement at each grid point + title: Title for the plot + """ + plt.figure(figsize=(8, 6)) + sns.heatmap(improvement_matrix, annot=True, fmt='.2f', cmap='RdYlGn', center=0, + cbar_kws={'label': 'Improvement'}) + plt.title(title) + plt.tight_layout() + plt.show() + + +def plot_time_series_comparison(time_series_data, labels=None, title="Time Series Comparison"): + """ + Plot time series comparison of metrics + + Args: + time_series_data: List of time series to plot + labels: Labels for each series + title: Title for the plot + """ + plt.figure(figsize=(12, 6)) + + for i, series in enumerate(time_series_data): + label = labels[i] if labels else f'Series {i+1}' + plt.plot(series, label=label, linewidth=2) + + plt.xlabel('Time Steps') + plt.ylabel('Value') + plt.title(title) + plt.legend() + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.show() + + +def plot_histogram_comparison(true_state, background, analysis, bins=50, + title="Distribution Comparison"): + """ + Plot histogram comparison of distributions + + Args: + true_state: True state values + background: Background state values + analysis: Analysis state values + bins: Number of histogram bins + title: Title for the plot + """ + # Convert to numpy if torch tensors + if torch.is_tensor(true_state): + true_state = true_state.cpu().numpy() + if torch.is_tensor(background): + background = background.cpu().numpy() + if torch.is_tensor(analysis): + analysis = analysis.cpu().numpy() + + plt.figure(figsize=(10, 6)) + + true_flat = true_state.flatten() + bg_flat = background.flatten() + analysis_flat = analysis.flatten() + + plt.hist(true_flat, bins=bins, alpha=0.5, label='True State', density=True) + plt.hist(bg_flat, bins=bins, alpha=0.5, label='Background', density=True) + plt.hist(analysis_flat, bins=bins, alpha=0.5, label='Analysis', density=True) + + plt.xlabel('Value') + plt.ylabel('Density') + plt.title(title) + plt.legend() + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.show() + + +def plot_scatter_comparison(true_state, background, analysis, + title="Scatter Plot Comparison"): + """ + Plot scatter comparison showing correlation between true and predicted values + + Args: + true_state: True state values + background: Background state values + analysis: Analysis state values + title: Title for the plot + """ + # Convert to numpy if torch tensors + if torch.is_tensor(true_state): + true_state = true_state.cpu().numpy() + if torch.is_tensor(background): + background = background.cpu().numpy() + if torch.is_tensor(analysis): + analysis = analysis.cpu().numpy() + + fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + + true_flat = true_state.flatten() + bg_flat = background.flatten() + analysis_flat = analysis.flatten() + + # Background vs True + axes[0].scatter(true_flat, bg_flat, alpha=0.5) + min_val = min(true_flat.min(), bg_flat.min()) + max_val = max(true_flat.max(), bg_flat.max()) + axes[0].plot([min_val, max_val], [min_val, max_val], 'r--', lw=2) + axes[0].set_xlabel('True State') + axes[0].set_ylabel('Background') + axes[0].set_title('Background vs True') + axes[0].grid(True, alpha=0.3) + + # Analysis vs True + axes[1].scatter(true_flat, analysis_flat, alpha=0.5) + axes[1].plot([min_val, max_val], [min_val, max_val], 'r--', lw=2) + axes[1].set_xlabel('True State') + axes[1].set_ylabel('Analysis') + axes[1].set_title('Analysis vs True') + axes[1].grid(True, alpha=0.3) + + plt.suptitle(title) + plt.tight_layout() + plt.show() + + +def plot_convergence_analysis(train_losses, val_losses, title="Convergence Analysis"): + """ + Plot detailed convergence analysis + + Args: + train_losses: Training losses + val_losses: Validation losses + title: Title for the plot + """ + fig, axes = plt.subplots(2, 2, figsize=(15, 10)) + + epochs = range(1, len(train_losses) + 1) + + # Training and validation loss + axes[0, 0].plot(epochs, train_losses, label='Training Loss', color='blue') + axes[0, 0].plot(epochs, val_losses, label='Validation Loss', color='red') + axes[0, 0].set_xlabel('Epoch') + axes[0, 0].set_ylabel('Loss') + axes[0, 0].set_title('Training and Validation Loss') + axes[0, 0].legend() + axes[0, 0].grid(True, alpha=0.3) + + # Log scale + axes[0, 1].semilogy(epochs, train_losses, label='Training Loss', color='blue') + axes[0, 1].semilogy(epochs, val_losses, label='Validation Loss', color='red') + axes[0, 1].set_xlabel('Epoch') + axes[0, 1].set_ylabel('Loss (log scale)') + axes[0, 1].set_title('Loss (Log Scale)') + axes[0, 1].legend() + axes[0, 1].grid(True, alpha=0.3) + + # Loss difference + loss_diff = np.array(train_losses) - np.array(val_losses) + axes[1, 0].plot(epochs, loss_diff, color='purple') + axes[1, 0].axhline(y=0, color='black', linestyle='--', alpha=0.5) + axes[1, 0].set_xlabel('Epoch') + axes[1, 0].set_ylabel('Training - Validation Loss') + axes[1, 0].set_title('Overfitting Indicator') + axes[1, 0].grid(True, alpha=0.3) + + # Improvement per epoch + improvement = np.diff(train_losses) + axes[1, 1].plot(epochs[1:], improvement, color='green') + axes[1, 1].set_xlabel('Epoch') + axes[1, 1].set_ylabel('Loss Improvement') + axes[1, 1].set_title('Improvement per Epoch') + axes[1, 1].grid(True, alpha=0.3) + + plt.suptitle(title) + plt.tight_layout() + plt.show() + + +def plot_parameter_analysis(model, title="Parameter Analysis"): + """ + Plot analysis of model parameters + + Args: + model: PyTorch model + title: Title for the plot + """ + param_norms = [] + param_names = [] + + for name, param in model.named_parameters(): + if param.requires_grad: + param_norm = param.data.norm().item() + param_norms.append(param_norm) + param_names.append(name) + + if not param_norms: # Handle case where no parameters require gradients + print("No parameters require gradients to visualize") + return + + plt.figure(figsize=(12, 6)) + bars = plt.bar(range(len(param_names)), param_norms) + plt.xlabel('Parameters') + plt.ylabel('L2 Norm') + plt.title(title) + plt.xticks(range(len(param_names)), [name.split('.')[-1] for name in param_names], + rotation=45, ha='right') + + # Add value labels on bars + for bar, value in zip(bars, param_norms): + plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(param_norms)*0.01, + f'{value:.3f}', ha='center', va='bottom', fontsize=8) + + plt.tight_layout() + plt.show() + + +def create_summary_dashboard(metrics, figsize=(16, 12)): + """ + Create a comprehensive dashboard summarizing all results + + Args: + metrics: Dictionary with all evaluation metrics + figsize: Figure size for the dashboard + """ + fig = plt.figure(figsize=figsize) + + # Define grid for subplots + gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3) + + # 1. Training curves (if available) + if 'train_losses' in metrics and 'val_losses' in metrics: + ax1 = fig.add_subplot(gs[0, 0]) + ax1.plot(metrics['train_losses'], label='Train', color='blue') + ax1.plot(metrics['val_losses'], label='Val', color='red') + ax1.set_title('Training Curves') + ax1.set_xlabel('Epoch') + ax1.set_ylabel('Loss') + ax1.legend() + ax1.grid(True, alpha=0.3) + + # 2. RMSE comparison + ax2 = fig.add_subplot(gs[0, 1]) + rmse_methods = [] + rmse_values = [] + for key, value in metrics.items(): + if 'rmse' in key.lower(): + rmse_methods.append(key.replace('_rmse', '').replace('avg_', '').title()) + rmse_values.append(value) + + if rmse_methods: + ax2.bar(rmse_methods, rmse_values, color=['skyblue', 'lightcoral', 'lightgreen']) + ax2.set_title('RMSE Comparison') + ax2.set_ylabel('RMSE') + ax2.tick_params(axis='x', rotation=45) + + # 3. Correlation comparison + ax3 = fig.add_subplot(gs[0, 2]) + corr_methods = [] + corr_values = [] + for key, value in metrics.items(): + if 'correlation' in key.lower(): + corr_methods.append(key.replace('_correlation', '').replace('avg_', '').title()) + corr_values.append(value) + + if corr_methods: + ax3.bar(corr_methods, corr_values, color=['gold', 'orange']) + ax3.set_title('Correlation Comparison') + ax3.set_ylabel('Correlation') + ax3.tick_params(axis='x', rotation=45) + + # 4. Bias comparison + ax4 = fig.add_subplot(gs[1, 0]) + bias_methods = [] + bias_values = [] + for key, value in metrics.items(): + if 'bias' in key.lower(): + bias_methods.append(key.replace('_bias', '').replace('avg_', '').title()) + bias_values.append(value) + + if bias_methods: + ax4.bar(bias_methods, bias_values, color=['lightblue', 'pink']) + ax4.set_title('Bias Comparison') + ax4.set_ylabel('Bias') + ax4.tick_params(axis='x', rotation=45) + ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5) + + # 5. Improvement metrics + ax5 = fig.add_subplot(gs[1, 1]) + improvement_metrics = [] + improvement_values = [] + for key, value in metrics.items(): + if 'improvement' in key.lower(): + improvement_metrics.append(key.replace('avg_', '').replace('_pct', '%').title()) + improvement_values.append(value) + + if improvement_metrics: + bars = ax5.bar(improvement_metrics, improvement_values, + color=['lightgreen' if v > 0 else 'lightcoral' for v in improvement_values]) + ax5.set_title('Improvement Metrics') + ax5.set_ylabel('Improvement (%)') + ax5.tick_params(axis='x', rotation=45) + ax5.axhline(y=0, color='black', linestyle='--', alpha=0.5) + + # Add value labels + for bar, value in zip(bars, improvement_values): + height = bar.get_height() + ax5.text(bar.get_x() + bar.get_width()/2, height + (max(improvement_values)*0.01 if max(improvement_values) > 0 else min(improvement_values)*0.01), + f'{value:.1f}%', ha='center', va='bottom' if value >= 0 else 'top') + + # 6. Information gain + if 'avg_information_gain' in metrics: + ax6 = fig.add_subplot(gs[1, 2]) + info_gain = metrics['avg_information_gain'] + ax6.bar(['Information Gain'], [info_gain], color='mediumpurple') + ax6.set_title('Information Gain') + ax6.set_ylabel('Gain (%)') + ax6.text(0, info_gain + max(info_gain*0.01, 0.1), f'{info_gain:.1f}%', + ha='center', va='bottom') + + # 7. Parameter norms (if model available) + if 'model' in metrics: + ax7 = fig.add_subplot(gs[2, :]) + param_norms = [] + param_names = [] + for name, param in metrics['model'].named_parameters(): + if param.requires_grad: + param_norm = param.data.norm().item() + param_norms.append(param_norm) + param_names.append(name.split('.')[-1][:10]) # Shorten names + + if param_norms: # Only plot if there are parameters to show + # Only show top parameters to avoid overcrowding + top_indices = np.argsort(param_norms)[-10:][::-1] # Top 10 largest + top_norms = [param_norms[i] for i in top_indices] + top_names = [param_names[i] for i in top_indices] + + ax7.bar(top_names, top_norms, color='lightsteelblue') + ax7.set_title('Top 10 Parameter Norms') + ax7.set_ylabel('L2 Norm') + ax7.tick_params(axis='x', rotation=45) + else: + ax7.text(0.5, 0.5, 'No parameters to display', horizontalalignment='center', + verticalalignment='center', transform=ax7.transAxes, fontsize=14) + ax7.set_title('Parameter Norms') + ax7.set_xticks([]) + ax7.set_yticks([]) + + plt.suptitle('Data Assimilation Results Dashboard', fontsize=16) + plt.show() + + +def visualize_observation_locations(observations, obs_mask, title="Observation Locations"): + """ + Visualize where observations are available + + Args: + observations: Observation tensor + obs_mask: Boolean mask indicating observation locations + title: Title for the plot + """ + plt.figure(figsize=(8, 6)) + + # Create a visualization where observed locations are highlighted + obs_visual = torch.zeros_like(observations[0, 0]) if len(observations.shape) > 2 else torch.zeros_like(observations[0]) + obs_visual[obs_mask] = 1 + + plt.imshow(obs_visual.cpu().numpy(), cmap='viridis', interpolation='none') + plt.title(title) + plt.colorbar(label='Observation Present (1) / Missing (0)') + plt.show() \ No newline at end of file diff --git a/graph_weather/test_data_assimilation.py b/graph_weather/test_data_assimilation.py new file mode 100644 index 00000000..13eaeee3 --- /dev/null +++ b/graph_weather/test_data_assimilation.py @@ -0,0 +1,322 @@ +""" +Test script for the complete self-supervised data assimilation pipeline +""" + +import torch +import numpy as np +from graph_weather.graph_weather.models.data_assimilation import ( + DataAssimilationModel, + SimpleDataAssimilationModel, + ThreeDVarLoss, + generate_synthetic_data +) +from graph_weather.graph_weather.data.assimilation_dataloader import ( + AssimilationDataModule, + create_synthetic_assimilation_dataset +) +from graph_weather.graph_weather.models.training_loop import ( + DataAssimilationTrainer, + train_data_assimilation_model, + compare_with_baselines +) +from graph_weather.graph_weather.models.evaluation import ( + DataAssimilationEvaluator, + compare_methods, + compute_rmse +) +from graph_weather.graph_weather.models.visualization import ( + plot_training_curves, + plot_comparison_grid, + plot_error_maps, + plot_rmse_comparison, + create_summary_dashboard +) + + +def test_basic_3dvar_loss(): + """Test the basic 3D-Var loss function""" + print("Testing 3D-Var loss function...") + + # Create sample data + batch_size, grid_size = 4, (5, 5) + background = torch.randn(batch_size, 1, *grid_size) + observations = torch.randn(batch_size, 1, *grid_size) + analysis = torch.randn(batch_size, 1, *grid_size) + + # Initialize loss function + loss_fn = ThreeDVarLoss() + + # Compute loss + loss = loss_fn(analysis, background, observations) + print(f"3D-Var loss: {loss.item():.4f}") + + # Test with custom covariances + B = torch.eye(grid_size[0] * grid_size[1]) * 0.5 + R = torch.eye(grid_size[0] * grid_size[1]) * 0.3 + loss_fn_custom = ThreeDVarLoss( + background_error_covariance=B, + observation_error_covariance=R + ) + + loss_custom = loss_fn_custom(analysis, background, observations) + print(f"3D-Var loss with custom covariances: {loss_custom.item():.4f}") + + print("✓ 3D-Var loss test passed\n") + + +def test_data_assimilation_model(): + """Test the data assimilation model""" + print("Testing data assimilation model...") + + # Test simple FC model + input_dim = 50 # 5x5 grid with 2 channels (bg + obs) + model = DataAssimilationModel(input_dim=input_dim) + + batch_size = 4 + background = torch.randn(batch_size, input_dim // 2) + observations = torch.randn(batch_size, input_dim // 2) + + analysis = model(background, observations) + print(f"FC Model - Input shape: {background.shape}, Output shape: {analysis.shape}") + + # Test convolutional model + grid_size = (5, 5) + model_conv = SimpleDataAssimilationModel(grid_size=grid_size, num_channels=1) + + background_conv = torch.randn(batch_size, 1, *grid_size) + observations_conv = torch.randn(batch_size, 1, *grid_size) + + analysis_conv = model_conv(background_conv, observations_conv) + print(f"Conv Model - Input shape: {background_conv.shape}, Output shape: {analysis_conv.shape}") + + print("✓ Data assimilation model test passed\n") + + +def test_training_pipeline(): + """Test the complete training pipeline""" + print("Testing training pipeline...") + + # Create synthetic data + data_module = AssimilationDataModule( + num_samples=200, + grid_size=(8, 8), + num_channels=1, + bg_error_std=0.5, + obs_error_std=0.3, + obs_fraction=0.6, + batch_size=16 + ) + data_module.setup() + + # Initialize model and loss + model = SimpleDataAssimilationModel( + grid_size=(8, 8), + num_channels=1, + hidden_dim=32, + num_layers=2 + ) + + # Train the model + trainer, results = train_data_assimilation_model( + model=model, + train_loader=data_module.train_dataloader(), + val_loader=data_module.val_dataloader(), + epochs=10, # Small number for testing + lr=1e-3, + device='cpu' + ) + + print(f"Final training loss: {results['train_losses'][-1]:.4f}") + print(f"Final validation loss: {results['val_losses'][-1]:.4f}") + + # Plot training curves + plot_training_curves( + results['train_losses'], + results['val_losses'], + title="Test Training Curves" + ) + + print("✓ Training pipeline test passed\n") + + return trainer, data_module + + +def test_evaluation_pipeline(trainer, data_module): + """Test the evaluation pipeline""" + print("Testing evaluation pipeline...") + + # Evaluate the trained model + eval_results = trainer.evaluate_model(data_module.test_dataloader(), compute_metrics=True) + + print("Evaluation Results:") + for key, value in eval_results.items(): + print(f" {key}: {value:.4f}") + + # Initialize evaluator + evaluator = DataAssimilationEvaluator(trainer.model, device='cpu') + overall_metrics = evaluator.evaluate_dataset(data_module.test_dataloader()) + + print("\nOverall Metrics:") + for key, value in overall_metrics.items(): + print(f" {key}: {value:.4f}") + + print("✓ Evaluation pipeline test passed\n") + + return overall_metrics + + +def test_comparison_with_baselines(trainer, data_module): + """Test comparison with classical baselines""" + print("Testing comparison with baselines...") + + comparison = compare_with_baselines( + trainer.model, + data_module.test_dataloader(), + device='cpu' + ) + + print("Baseline Comparison Results:") + for key, value in comparison.items(): + print(f" {key}: {value:.4f}") + + # Create RMSE comparison plot + rmse_comparison = { + 'Analysis': comparison['avg_analysis_rmse'], + 'Background': comparison['avg_background_rmse'], + 'Observations': comparison['avg_observation_rmse'], + 'Persistence': comparison['avg_persistence_rmse'] + } + + plot_rmse_comparison(rmse_comparison, title="RMSE Comparison with Baselines") + + print("✓ Baseline comparison test passed\n") + + return comparison + + +def test_visualization_pipeline(): + """Test visualization capabilities""" + print("Testing visualization pipeline...") + + # Generate sample data for visualization + batch_size, grid_size = 1, (6, 6) + background, observations, true_state = generate_synthetic_data( + batch_size=batch_size, + grid_size=grid_size, + num_channels=1 + ) + + # Create a simple "analysis" (for demonstration) + analysis = (background + observations) / 2 # Simple average + + # Test comparison grid + plot_comparison_grid( + background, observations, analysis, true_state, + titles=['Background', 'Observations', 'Analysis', 'True State'] + ) + + # Test error maps + plot_error_maps( + background, observations, analysis, true_state, + titles=['Background Error', 'Observation Error', 'Analysis Error'] + ) + + print("✓ Visualization pipeline test passed\n") + + +def run_comprehensive_test(): + """Run a comprehensive test of the entire pipeline""" + print("="*60) + print("COMPREHENSIVE TEST: Self-Supervised Data Assimilation Pipeline") + print("="*60) + + # Test 1: Basic components + test_basic_3dvar_loss() + + # Test 2: Model architecture + test_data_assimilation_model() + + # Test 3: Training pipeline + trainer, data_module = test_training_pipeline() + + # Test 4: Evaluation pipeline + eval_metrics = test_evaluation_pipeline(trainer, data_module) + + # Test 5: Baseline comparison + comparison_results = test_comparison_with_baselines(trainer, data_module) + + # Test 6: Visualization + test_visualization_pipeline() + + # Final summary dashboard + print("Creating summary dashboard...") + summary_metrics = eval_metrics.copy() + summary_metrics.update(comparison_results) + summary_metrics['model'] = trainer.model # Include model for parameter analysis + + create_summary_dashboard(summary_metrics) + + print("\n" + "="*60) + print("ALL TESTS COMPLETED SUCCESSFULLY!") + print("="*60) + + # Print key results + print(f"\nKey Results:") + print(f"- Analysis RMSE: {comparison_results['avg_analysis_rmse']:.4f}") + print(f"- Background RMSE: {comparison_results['avg_background_rmse']:.4f}") + print(f"- Analysis improvement over background: {comparison_results['analysis_improvement_over_bg']:.2f}%") + print(f"- Analysis improvement over observations: {comparison_results['analysis_improvement_over_obs']:.2f}%") + + return True + + +def test_different_training_modes(): + """Test the model under different training conditions""" + print("\n" + "="*60) + print("TESTING DIFFERENT TRAINING MODES") + print("="*60) + + from graph_weather.graph_weather.models.training_loop import train_with_different_modes + + # Test with different configurations + results = train_with_different_modes( + model_class=lambda **kwargs: SimpleDataAssimilationModel( + grid_size=(6, 6), + num_channels=1, + hidden_dim=16, + num_layers=2 + ), + data_module=AssimilationDataModule, + grid_size=(6, 6), + num_channels=1, + epochs=5, # Few epochs for testing + lr=1e-3, + device='cpu' + ) + + print("\nTraining Mode Results:") + for mode, result in results.items(): + eval_res = result['eval_results'] + print(f"\n{mode.upper()} MODE:") + print(f" Analysis RMSE: {eval_res.get('analysis_rmse', 'N/A')}") + print(f" Background RMSE: {eval_res.get('background_rmse', 'N/A')}") + print(f" Improvement over background: {eval_res.get('improvement_over_bg', 'N/A')}%") + + return results + + +if __name__ == "__main__": + # Set random seed for reproducibility + torch.manual_seed(42) + np.random.seed(42) + + # Run comprehensive test + success = run_comprehensive_test() + + # Test different training modes + mode_results = test_different_training_modes() + + print("\n" + "="*60) + print("ALL TESTS COMPLETED SUCCESSFULLY!") + print("Self-Supervised Data Assimilation Pipeline is working correctly.") + print("="*60) \ No newline at end of file From 0974c4d0618702a93d5b06d468167996c4713637 Mon Sep 17 00:00:00 2001 From: SOHAMPAL23 Date: Tue, 6 Jan 2026 16:22:04 +0530 Subject: [PATCH 2/6] Refactor import statements in __init__.py --- graph_weather/data/__init__.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index 052921fa..ff366609 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -1,10 +1,6 @@ """Dataloaders and data processing utilities""" -try: - from .anemoi_dataloader import AnemoiDataset -except ImportError: - # anemoi library not available, skip this import - pass +from .anemoi_dataloader import AnemoiDataset from .nnja_ai import SensorDataset from .weather_station_reader import WeatherStationReader from .assimilation_dataloader import AssimilationDataset, AssimilationDataModule From cfd69a8855025868832f03b3bcbb74d6be8aa70b Mon Sep 17 00:00:00 2001 From: SOHAMPAL23 Date: Sat, 17 Jan 2026 17:12:28 +0530 Subject: [PATCH 3/6] Data Assimilation --- .../models/data_assimilation/__init__.py | 14 +++ .../data_assimilation_base.py | 97 +++++++++++++++++++ .../test_data_assimilation_base.py | 95 ++++++++++++++++++ 3 files changed, 206 insertions(+) create mode 100644 graph_weather/models/data_assimilation/__init__.py create mode 100644 graph_weather/models/data_assimilation/data_assimilation_base.py create mode 100644 tests/models/data_assimilation/test_data_assimilation_base.py diff --git a/graph_weather/models/data_assimilation/__init__.py b/graph_weather/models/data_assimilation/__init__.py new file mode 100644 index 00000000..a0a120be --- /dev/null +++ b/graph_weather/models/data_assimilation/__init__.py @@ -0,0 +1,14 @@ +"""Data assimilation module initialization.""" + +from .data_assimilation_base import DataAssimilationBase, EnsembleGenerator +from .kalman_filter_da import KalmanFilterDA +from .particle_filter_da import ParticleFilterDA +from .variational_da import VariationalDA + +__all__ = [ + 'DataAssimilationBase', + 'EnsembleGenerator', + 'KalmanFilterDA', + 'ParticleFilterDA', + 'VariationalDA' +] \ No newline at end of file diff --git a/graph_weather/models/data_assimilation/data_assimilation_base.py b/graph_weather/models/data_assimilation/data_assimilation_base.py new file mode 100644 index 00000000..75b67174 --- /dev/null +++ b/graph_weather/models/data_assimilation/data_assimilation_base.py @@ -0,0 +1,97 @@ +"""Base classes for data assimilation modules.""" +import abc +from typing import Union, Dict, Any, Optional +import torch +from torch_geometric.data import Data + + +class EnsembleGenerator: + """Class to generate ensemble members from a background state.""" + + def __init__(self, noise_std: float = 0.1, method: str = "gaussian"): + self.noise_std = noise_std + self.method = method + + def generate_ensemble(self, state: Union[torch.Tensor, Data], num_members: int): + if isinstance(state, torch.Tensor): + return self._generate_tensor_ensemble(state, num_members) + elif isinstance(state, Data): + return self._generate_graph_ensemble(state, num_members) + else: + raise TypeError(f"Unsupported state type: {type(state)}") + + def _generate_tensor_ensemble(self, state: torch.Tensor, num_members: int) -> torch.Tensor: + batch_size, nodes, features = state.shape + ensemble = torch.zeros(batch_size, num_members, nodes, features, device=state.device) + + for i in range(num_members): + if self.method == "gaussian": + noise = torch.randn_like(state) * self.noise_std + ensemble[:, i] = state + noise + elif self.method == "dropout": + mask = torch.bernoulli(torch.ones_like(state) * 0.9) # Keep 90% of values + noise = torch.randn_like(state) * self.noise_std * 0.1 + ensemble[:, i] = (state * mask) + noise + elif self.method == "perturbation": + perturbation = torch.randn_like(state) * self.noise_std * torch.linspace(0.1, 1.0, num_members)[i] + ensemble[:, i] = state + perturbation + else: + raise ValueError(f"Unknown ensemble generation method: {self.method}") + + return ensemble + + def _generate_graph_ensemble(self, state: Data, num_members: int) -> Data: + x_expanded = torch.zeros(state.x.shape[0], num_members, state.x.shape[1], device=state.x.device) + + for i in range(num_members): + if self.method == "gaussian": + noise = torch.randn_like(state.x) * self.noise_std + x_expanded[:, i] = state.x + noise + elif self.method == "dropout": + mask = torch.bernoulli(torch.ones_like(state.x) * 0.9) + noise = torch.randn_like(state.x) * self.noise_std * 0.1 + x_expanded[:, i] = (state.x * mask) + noise + elif self.method == "perturbation": + perturbation = torch.randn_like(state.x) * self.noise_std * torch.linspace(0.1, 1.0, num_members)[i] + x_expanded[:, i] = state.x + perturbation + else: + raise ValueError(f"Unknown ensemble generation method: {self.method}") + + new_state = Data( + x=x_expanded, + edge_index=state.edge_index, + edge_attr=getattr(state, 'edge_attr', None), + pos=getattr(state, 'pos', None) + ) + + return new_state + + +class DataAssimilationBase(abc.ABC): + """Abstract base class for data assimilation modules.""" + + def __init__(self, config: Dict[str, Any]): + self.config = config + self.ensemble_generator = EnsembleGenerator( + noise_std=config.get('noise_std', 0.1), + method=config.get('ensemble_method', 'gaussian') + ) + + @abc.abstractmethod + def initialize_ensemble(self, background_state: Union[torch.Tensor, Data], num_members: int): + pass + + @abc.abstractmethod + def assimilate(self, ensemble: Union[torch.Tensor, Data], observations: torch.Tensor): + pass + + @abc.abstractmethod + def _compute_analysis(self, ensemble: Union[torch.Tensor, Data]) -> Union[torch.Tensor, Data]: + pass + + def forward(self, state: Union[torch.Tensor, Data], observations: torch.Tensor, num_ensemble: int = 10): + ensemble = self.initialize_ensemble(state, num_ensemble) + updated_ensemble = self.assimilate(ensemble, observations) + analysis = self._compute_analysis(updated_ensemble) + + return updated_ensemble, analysis \ No newline at end of file diff --git a/tests/models/data_assimilation/test_data_assimilation_base.py b/tests/models/data_assimilation/test_data_assimilation_base.py new file mode 100644 index 00000000..cb71a940 --- /dev/null +++ b/tests/models/data_assimilation/test_data_assimilation_base.py @@ -0,0 +1,95 @@ +import pytest +import torch +from torch_geometric.data import Data + +import sys +sys.path.insert(0, '../../../graph_weather/models/data_assimilation') + +# Execute modules directly to avoid import issues +exec(open('graph_weather/models/data_assimilation/data_assimilation_base.py').read()) + + +class MockDA(DataAssimilationBase): + """Mock implementation of DataAssimilationBase for testing purposes.""" + + def initialize_ensemble(self, background_state, num_members): + return self.ensemble_generator.generate_ensemble(background_state, num_members) + + def assimilate(self, ensemble, observations): + return ensemble # Return unchanged for testing + + def _compute_analysis(self, ensemble): + if isinstance(ensemble, torch.Tensor): + return torch.mean(ensemble, dim=1) + elif isinstance(ensemble, Data): + return ensemble # Return as is for testing + else: + raise TypeError(f"Unsupported ensemble type: {type(ensemble)}") + + +def test_ensemble_generator_tensor(): + """Test ensemble generation for tensor inputs.""" + generator = EnsembleGenerator(noise_std=0.1, method="gaussian") + + # Test tensor input + state = torch.randn(2, 5, 3) # [batch, nodes, features] + ensemble = generator.generate_ensemble(state, 4) + + assert ensemble.shape == (2, 4, 5, 3) # [batch, members, nodes, features] + assert not torch.equal(state, ensemble[:, 0]) # Should have noise added + + +def test_ensemble_generator_graph(): + """Test ensemble generation for graph inputs.""" + generator = EnsembleGenerator(noise_std=0.1, method="gaussian") + + # Test graph input + x = torch.randn(10, 4) # Node features + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) + graph_state = Data(x=x, edge_index=edge_index) + + ensemble = generator.generate_ensemble(graph_state, 3) + + # Check that ensemble preserves structure + assert hasattr(ensemble, 'x') + assert hasattr(ensemble, 'edge_index') + assert ensemble.x.shape[1] == 3 # Ensemble dimension + + +def test_data_assimilation_base_abstract_methods(): + """Test that abstract methods are properly defined.""" + config = {"param": "value"} + da_module = MockDA(config) + + assert da_module.config == config + + # Test ensemble generation + state = torch.randn(2, 5, 3) + ensemble = da_module.initialize_ensemble(state, 4) + assert ensemble.shape == (2, 4, 5, 3) + + +def test_compute_analysis_tensor(): + """Test analysis computation for tensor ensembles.""" + da_module = MockDA({}) + + # Create ensemble: [batch, members, nodes, features] + ensemble = torch.stack([ + torch.ones(2, 5, 3), # First member + 2 * torch.ones(2, 5, 3), # Second member + 3 * torch.ones(2, 5, 3), # Third member + ], dim=1) # Shape: [2, 3, 5, 3] + + analysis = da_module._compute_analysis(ensemble) + + # Mean should be (1 + 2 + 3) / 3 = 2 + expected = 2 * torch.ones(2, 5, 3) + assert torch.allclose(analysis, expected) + + +if __name__ == "__main__": + test_ensemble_generator_tensor() + test_ensemble_generator_graph() + test_data_assimilation_base_abstract_methods() + test_compute_analysis_tensor() + print("All tests passed!") \ No newline at end of file From 7d4791c480801ef5c2bda851e8d8492a5aa62da9 Mon Sep 17 00:00:00 2001 From: SOHAMPAL23 Date: Sun, 18 Jan 2026 16:20:53 +0530 Subject: [PATCH 4/6] Revert "feat: add self-supervised 3D-Var-based AI data assimilation prototype" This reverts commit 15abececb430816f3654e6cab8573a14e95e2a40. --- graph_weather/__init__.py | 1 - graph_weather/data/__init__.py | 1 - graph_weather/data/assimilation_dataloader.py | 288 --------- .../data_assimilation_implementation.md | 165 ----- graph_weather/example_usage.py | 170 ----- graph_weather/models/__init__.py | 1 - graph_weather/models/data_assimilation.py | 364 ----------- graph_weather/models/evaluation.py | 439 ------------- graph_weather/models/training_loop.py | 538 ---------------- graph_weather/models/visualization.py | 582 ------------------ graph_weather/test_data_assimilation.py | 322 ---------- 11 files changed, 2871 deletions(-) delete mode 100644 graph_weather/data/assimilation_dataloader.py delete mode 100644 graph_weather/data_assimilation_implementation.md delete mode 100644 graph_weather/example_usage.py delete mode 100644 graph_weather/models/data_assimilation.py delete mode 100644 graph_weather/models/evaluation.py delete mode 100644 graph_weather/models/training_loop.py delete mode 100644 graph_weather/models/visualization.py delete mode 100644 graph_weather/test_data_assimilation.py diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index 2fabe11e..b33e23cd 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -1,7 +1,6 @@ """Main import for the complete models""" from .data.nnja_ai import SensorDataset -from .data.assimilation_dataloader import AssimilationDataset, AssimilationDataModule from .data.weather_station_reader import WeatherStationReader from .models.analysis import GraphWeatherAssimilator from .models.forecast import GraphWeatherForecaster diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index ff366609..d67a79e4 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -3,4 +3,3 @@ from .anemoi_dataloader import AnemoiDataset from .nnja_ai import SensorDataset from .weather_station_reader import WeatherStationReader -from .assimilation_dataloader import AssimilationDataset, AssimilationDataModule diff --git a/graph_weather/data/assimilation_dataloader.py b/graph_weather/data/assimilation_dataloader.py deleted file mode 100644 index 412f43f1..00000000 --- a/graph_weather/data/assimilation_dataloader.py +++ /dev/null @@ -1,288 +0,0 @@ -""" -Data loader for self-supervised data assimilation framework -""" - -import torch -from torch.utils.data import Dataset, DataLoader -import numpy as np - - -class AssimilationDataset(Dataset): - """ - Dataset for self-supervised data assimilation - Each sample contains background state and observations - """ - - def __init__(self, background_states, observations, true_states=None): - """ - Initialize the assimilation dataset - - Args: - background_states: Background states (x_b) - observations: Observations (y) - true_states: True states (for evaluation only, not used in training) - """ - self.background_states = background_states - self.observations = observations - self.true_states = true_states - - assert len(background_states) == len(observations), \ - "Background and observation arrays must have same length" - - if true_states is not None: - assert len(true_states) == len(background_states), \ - "True states must have same length as background states" - - def __len__(self): - return len(self.background_states) - - def __getitem__(self, idx): - bg = self.background_states[idx] - obs = self.observations[idx] - - sample = { - 'background': bg, - 'observations': obs - } - - if self.true_states is not None: - sample['true_state'] = self.true_states[idx] - - return sample - - -def create_synthetic_assimilation_dataset( - num_samples=1000, - grid_size=(10, 10), - num_channels=1, - bg_error_std=0.5, - obs_error_std=0.3, - obs_fraction=0.5 -): - """ - Create a synthetic dataset for data assimilation experiments - - Args: - num_samples: Number of samples to generate - grid_size: Size of spatial grid - num_channels: Number of variables/channels - bg_error_std: Standard deviation of background errors - obs_error_std: Standard deviation of observation errors - obs_fraction: Fraction of grid points that have observations - - Returns: - dataset: AssimilationDataset object - """ - total_size = np.prod(grid_size) if isinstance(grid_size, (tuple, list)) else grid_size - - # Generate true states with spatial correlation - true_states = torch.randn(num_samples, num_channels, *grid_size) - - # Apply spatial smoothing to create realistic fields - if len(grid_size) == 2: - # Create a Gaussian smoothing kernel - kernel_size = 5 - sigma = 1.0 - kernel = torch.zeros(kernel_size, kernel_size) - center = kernel_size // 2 - - for i in range(kernel_size): - for j in range(kernel_size): - x, y = i - center, j - center - kernel[i, j] = np.exp(-(x**2 + y**2) / (2 * sigma**2)) - - kernel = kernel / kernel.sum() - kernel = kernel.unsqueeze(0).unsqueeze(0).repeat(num_channels, 1, 1, 1) - - # Apply smoothing to each sample and channel - for i in range(num_samples): - for c in range(num_channels): - smoothed = torch.nn.functional.conv2d( - true_states[i:i+1, c:c+1], - kernel, - padding=kernel_size//2, - groups=1 - ) - true_states[i, c:c+1] = smoothed - - # Create background states with errors - bg_errors = torch.randn_like(true_states) * bg_error_std - background_states = true_states + bg_errors - - # Create observations with errors - obs_errors = torch.randn_like(true_states) * obs_error_std - observations = true_states + obs_errors - - # Optionally mask some observations based on obs_fraction - if obs_fraction < 1.0: - mask = torch.rand_like(observations) < obs_fraction - observations = observations * mask - - dataset = AssimilationDataset(background_states, observations, true_states) - return dataset - - -def get_assimilation_data_loaders( - dataset, - batch_size=32, - train_ratio=0.7, - val_ratio=0.2, - test_ratio=0.1, - shuffle=True -): - """ - Create train/validation/test data loaders from dataset - - Args: - dataset: AssimilationDataset object - batch_size: Size of batches - train_ratio: Fraction of data for training - val_ratio: Fraction of data for validation - test_ratio: Fraction of data for testing - shuffle: Whether to shuffle the data - - Returns: - train_loader, val_loader, test_loader: Data loaders - """ - total_size = len(dataset) - train_size = int(train_ratio * total_size) - val_size = int(val_ratio * total_size) - test_size = total_size - train_size - val_size - - # Split the dataset - train_dataset, val_dataset, test_dataset = torch.utils.data.random_split( - dataset, [train_size, val_size, test_size] - ) - - # Create data loaders - train_loader = DataLoader( - train_dataset, - batch_size=batch_size, - shuffle=shuffle - ) - - val_loader = DataLoader( - val_dataset, - batch_size=batch_size, - shuffle=False - ) - - test_loader = DataLoader( - test_dataset, - batch_size=batch_size, - shuffle=False - ) - - return train_loader, val_loader, test_loader - - -def create_observation_mask(grid_size, obs_fraction=0.5, seed=None): - """ - Create a mask indicating which grid points have observations - - Args: - grid_size: Size of the grid (can be int for 1D or tuple for 2D) - obs_fraction: Fraction of grid points that have observations - seed: Random seed for reproducibility - - Returns: - mask: Boolean mask indicating observation locations - """ - if seed is not None: - np.random.seed(seed) - torch.manual_seed(seed) - - total_size = np.prod(grid_size) if isinstance(grid_size, (tuple, list)) else grid_size - num_obs = int(total_size * obs_fraction) - - # Create random indices for observation locations - obs_indices = np.random.choice(total_size, size=num_obs, replace=False) - - # Create mask - mask_flat = torch.zeros(total_size, dtype=torch.bool) - mask_flat[obs_indices] = True - - if isinstance(grid_size, (tuple, list)): - mask = mask_flat.view(grid_size) - else: - mask = mask_flat - - return mask - - -def apply_observation_operator(data, obs_mask): - """ - Apply observation operator to extract observed values from state - - Args: - data: Full state data - obs_mask: Boolean mask indicating observation locations - - Returns: - observed_data: Data at observation locations only - """ - if len(data.shape) > 2: # Spatial data - batch_size = data.size(0) - reshaped_data = data.view(batch_size, -1) # Flatten spatial dimensions - observed_flat = reshaped_data * obs_mask.view(-1).float() - return observed_flat.view_as(data) - else: - return data * obs_mask.float() - - -class AssimilationDataModule: - """ - A PyTorch Lightning-style data module for assimilation data - """ - - def __init__( - self, - num_samples=1000, - grid_size=(10, 10), - num_channels=1, - bg_error_std=0.5, - obs_error_std=0.3, - obs_fraction=0.5, - batch_size=32, - train_ratio=0.7, - val_ratio=0.2, - test_ratio=0.1 - ): - self.num_samples = num_samples - self.grid_size = grid_size - self.num_channels = num_channels - self.bg_error_std = bg_error_std - self.obs_error_std = obs_error_std - self.obs_fraction = obs_fraction - self.batch_size = batch_size - self.train_ratio = train_ratio - self.val_ratio = val_ratio - self.test_ratio = test_ratio - - def setup(self, stage=None): - """Setup the dataset""" - self.dataset = create_synthetic_assimilation_dataset( - num_samples=self.num_samples, - grid_size=self.grid_size, - num_channels=self.num_channels, - bg_error_std=self.bg_error_std, - obs_error_std=self.obs_error_std, - obs_fraction=self.obs_fraction - ) - - self.train_loader, self.val_loader, self.test_loader = get_assimilation_data_loaders( - self.dataset, - batch_size=self.batch_size, - train_ratio=self.train_ratio, - val_ratio=self.val_ratio, - test_ratio=self.test_ratio - ) - - def train_dataloader(self): - return self.train_loader - - def val_dataloader(self): - return self.val_loader - - def test_dataloader(self): - return self.test_loader \ No newline at end of file diff --git a/graph_weather/data_assimilation_implementation.md b/graph_weather/data_assimilation_implementation.md deleted file mode 100644 index 7c68c35e..00000000 --- a/graph_weather/data_assimilation_implementation.md +++ /dev/null @@ -1,165 +0,0 @@ -# Self-Supervised Data Assimilation Framework with 3D-Var Loss - -## Overview - -This implementation provides a complete self-supervised data assimilation framework that learns to produce analysis states by minimizing the 3D-Var cost function without using ground-truth labels. The system consists of neural networks that take background states and observations as input and produce optimal analysis states. - -## Core Components - -### 1. 3D-Var Loss Function (`data_assimilation.py`) - -The `ThreeDVarLoss` class implements the core 3D-Var objective function: - -``` -J(x) = (x - x_b)^T B^{-1} (x - x_b) + (y - Hx)^T R^{-1} (y - Hx) -``` - -Where: -- `x`: analysis state (model output) -- `x_b`: background state (first guess) -- `y`: observations -- `B`: background error covariance -- `R`: observation error covariance -- `H`: observation operator - -Key features: -- Supports custom background and observation error covariances -- Handles different observation operators -- Works with both fully connected and convolutional models -- Self-supervised (no ground-truth required) - -### 2. Data Assimilation Models (`data_assimilation.py`) - -Two model architectures are provided: - -#### DataAssimilationModel -- Fully connected neural network -- Takes concatenated background and observations as input -- Produces analysis state as output -- Configurable hidden dimensions and layers - -#### SimpleDataAssimilationModel -- Convolutional neural network for spatial data -- Works with 1D/2D grid data -- Preserves spatial relationships -- Efficient for gridded meteorological data - -### 3. Data Pipeline (`assimilation_dataloader.py`) - -- `AssimilationDataset`: Dataset class for background/observation pairs -- `AssimilationDataModule`: PyTorch Lightning-style data module -- Synthetic data generation with spatial correlation -- Observation masking and operator creation -- Train/validation/test splitting - -### 4. Training Framework (`training_loop.py`) - -- `DataAssimilationTrainer`: Complete training loop with validation -- Self-supervised training using 3D-Var loss -- Learning rate scheduling -- Model checkpointing -- Multi-mode training (good/poor background, sparse observations) - -### 5. Evaluation Metrics (`evaluation.py`) - -Comprehensive evaluation including: -- RMSE, MAE, bias calculations -- Correlation coefficients -- Spatial metrics -- Information gain -- Baseline comparisons -- Cross-validation - -### 6. Visualization Tools (`visualization.py`) - -- Training curves plotting -- Comparison grids (background, observations, analysis, true state) -- Error maps visualization -- RMSE comparisons -- Heatmaps and scatter plots -- Comprehensive dashboard - -## Key Features - -### Self-Supervised Learning -- No ground-truth labels required -- Physics-based loss function -- Learns optimal combination of background and observations - -### Flexible Architecture -- Works with different grid sizes -- Supports multiple channels/variables -- Configurable network depth and width -- Multiple activation functions - -### Multiple Training Modes -- With good first guess (low background error) -- With poor first guess (cold start) -- With varying observation densities -- Different error covariance specifications - -### Comprehensive Evaluation -- Comparison with classical baselines -- Improvement metrics -- Spatial analysis -- Statistical validation - -## Usage Example - -```python -from graph_weather.graph_weather.models.data_assimilation import SimpleDataAssimilationModel, ThreeDVarLoss -from graph_weather.graph_weather.data.assimilation_dataloader import AssimilationDataModule -from graph_weather.graph_weather.models.training_loop import train_data_assimilation_model - -# Create data module -data_module = AssimilationDataModule( - grid_size=(16, 16), - num_channels=1, - bg_error_std=0.5, - obs_error_std=0.3, - obs_fraction=0.6 -) -data_module.setup() - -# Initialize model -model = SimpleDataAssimilationModel( - grid_size=(16, 16), - num_channels=1, - hidden_dim=64, - num_layers=3 -) - -# Train model -trainer, results = train_data_assimilation_model( - model=model, - train_loader=data_module.train_dataloader(), - val_loader=data_module.val_dataloader(), - epochs=100, - lr=1e-3 -) -``` - -## Mathematical Foundation - -The 3D-Var cost function is based on Bayesian estimation theory: - -``` -J(x) = (x - x_b)^T B^{-1} (x - x_b) + (y - Hx)^T R^{-1} (y - Hx) -``` - -Where the first term represents the background constraint and the second term represents the observation constraint. The neural network learns to find the optimal balance between these constraints without explicit supervision. - -## Advantages Over Classical Methods - -1. **Learned Error Covariances**: The neural network can learn complex, non-linear relationships -2. **Adaptive Combination**: Automatically adjusts weighting based on data quality -3. **Scalability**: Can handle high-dimensional state spaces efficiently -4. **End-to-End Learning**: Optimizes the complete assimilation process - -## Applications - -This framework is suitable for: -- Weather forecasting data assimilation -- Climate model state estimation -- Oceanographic data assimilation -- Any physical system with background models and observations \ No newline at end of file diff --git a/graph_weather/example_usage.py b/graph_weather/example_usage.py deleted file mode 100644 index 8ab17381..00000000 --- a/graph_weather/example_usage.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -Example usage of the self-supervised data assimilation framework -""" - -import torch -import numpy as np -from graph_weather.graph_weather.models.data_assimilation import ( - SimpleDataAssimilationModel, - ThreeDVarLoss -) -from graph_weather.graph_weather.data.assimilation_dataloader import ( - AssimilationDataModule -) -from graph_weather.graph_weather.models.training_loop import ( - train_data_assimilation_model -) -from graph_weather.graph_weather.models.evaluation import ( - DataAssimilationEvaluator -) -from graph_weather.graph_weather.models.visualization import ( - plot_training_curves, - plot_comparison_grid, - plot_error_maps -) - - -def main(): - print("Self-Supervised Data Assimilation Example") - print("="*50) - - # Set random seed for reproducibility - torch.manual_seed(42) - np.random.seed(42) - - # 1. Define the problem setup - print("\n1. Setting up the problem...") - grid_size = (12, 12) # 12x12 spatial grid - num_channels = 1 # Single variable (e.g., temperature) - batch_size = 16 - epochs = 20 # Small number for demo - - print(f"Grid size: {grid_size}") - print(f"Number of channels: {num_channels}") - print(f"Batch size: {batch_size}") - print(f"Training epochs: {epochs}") - - # 2. Create data module - print("\n2. Creating data module...") - data_module = AssimilationDataModule( - num_samples=500, # Number of training samples - grid_size=grid_size, - num_channels=num_channels, - bg_error_std=0.5, # Background error standard deviation - obs_error_std=0.3, # Observation error standard deviation - obs_fraction=0.6, # 60% of grid points have observations - batch_size=batch_size - ) - data_module.setup() - - print(f"Training samples: {len(data_module.train_dataloader().dataset)}") - print(f"Validation samples: {len(data_module.val_dataloader().dataset)}") - print(f"Test samples: {len(data_module.test_dataloader().dataset)}") - - # 3. Initialize the model - print("\n3. Initializing the model...") - model = SimpleDataAssimilationModel( - grid_size=grid_size, - num_channels=num_channels, - hidden_dim=32, # Hidden dimension for conv layers - num_layers=2 # Number of processing layers - ) - - print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") - - # 4. Train the model - print("\n4. Training the model...") - print("This will minimize the 3D-Var cost function without ground truth!") - - trainer, results = train_data_assimilation_model( - model=model, - train_loader=data_module.train_dataloader(), - val_loader=data_module.val_dataloader(), - epochs=epochs, - lr=1e-3, - device='cpu' # Using CPU for this example - ) - - print(f"Final training loss: {results['train_losses'][-1]:.6f}") - print(f"Final validation loss: {results['val_losses'][-1]:.6f}") - - # 5. Plot training curves - print("\n5. Plotting training curves...") - plot_training_curves( - results['train_losses'], - results['val_losses'], - title="Data Assimilation Training Curves" - ) - - # 6. Evaluate the model - print("\n6. Evaluating the model...") - evaluator = DataAssimilationEvaluator(model, device='cpu') - eval_metrics = evaluator.evaluate_dataset(data_module.test_dataloader()) - - print("Evaluation Metrics:") - for key, value in eval_metrics.items(): - if 'avg_' in key and ('rmse' in key or 'mae' in key or 'bias' in key or 'correlation' in key): - print(f" {key}: {value:.4f}") - - # 7. Visualize results - print("\n7. Visualizing results...") - - # Get a batch from test data for visualization - test_iter = iter(data_module.test_dataloader()) - batch = next(test_iter) - - background = batch['background'] - observations = batch['observations'] - - # Generate analysis using the trained model - model.eval() - with torch.no_grad(): - analysis = model(background, observations) - - # If true state is available, visualize comparison - if 'true_state' in batch: - true_state = batch['true_state'] - - print("Creating comparison visualization...") - plot_comparison_grid( - background, observations, analysis, true_state, - titles=['Background', 'Observations', 'Analysis', 'True State'] - ) - - print("Creating error maps...") - plot_error_maps( - background, observations, analysis, true_state, - titles=['Background Error', 'Observation Error', 'Analysis Error'] - ) - - # 8. Compare with baselines - print("\n8. Comparing with baselines...") - from graph_weather.graph_weather.models.training_loop import compare_with_baselines - - comparison = compare_with_baselines( - model, - data_module.test_dataloader(), - device='cpu' - ) - - print("Baseline Comparison:") - for key, value in comparison.items(): - print(f" {key}: {value:.4f}") - - # 9. Summary - print("\n" + "="*50) - print("SUMMARY") - print("="*50) - print(f"✓ Successfully trained a self-supervised data assimilation model") - print(f"✓ Model learned to minimize 3D-Var cost function without ground truth") - print(f"✓ Analysis RMSE: {comparison['avg_analysis_rmse']:.4f}") - print(f"✓ Background RMSE: {comparison['avg_background_rmse']:.4f}") - print(f"✓ Analysis improvement over background: {comparison['analysis_improvement_over_bg']:.2f}%") - print(f"✓ Analysis improvement over observations: {comparison['analysis_improvement_over_obs']:.2f}%") - - print(f"\nThe model successfully learned to combine background and observations") - print(f"optimally to produce better analysis states than either input alone!") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index 307e88d2..3710e24a 100755 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -13,4 +13,3 @@ from .layers.encoder import Encoder from .layers.processor import Processor from .layers.stochastic_decomposition import StochasticDecompositionLayer -from .data_assimilation import DataAssimilationModel, ThreeDVarLoss diff --git a/graph_weather/models/data_assimilation.py b/graph_weather/models/data_assimilation.py deleted file mode 100644 index 3c463b99..00000000 --- a/graph_weather/models/data_assimilation.py +++ /dev/null @@ -1,364 +0,0 @@ -""" -Self-Supervised Data Assimilation Framework with 3D-Var Loss - -Implements a neural network that learns to produce analysis states by minimizing -the 3D-Var cost function without using ground-truth labels. -""" - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - - -class ThreeDVarLoss(nn.Module): - """ - Implements the 3D-Var cost function as a self-supervised loss: - - J(x) = (x - x_b)^T B^{-1} (x - x_b) + (y - Hx)^T R^{-1} (y - Hx) - - Where: - - x: analysis state (model output) - - x_b: background state (first guess) - - y: observations - - B: background error covariance - - R: observation error covariance - - H: observation operator - """ - - def __init__(self, - background_error_covariance=None, - observation_error_covariance=None, - observation_operator=None, - bg_weight=1.0, - obs_weight=1.0): - """ - Initialize the 3D-Var loss function - - Args: - background_error_covariance: B matrix (background error covariance) - observation_error_covariance: R matrix (observation error covariance) - observation_operator: H matrix (observation operator) - bg_weight: Weight for background term - obs_weight: Weight for observation term - """ - super(ThreeDVarLoss, self).__init__() - - self.bg_weight = bg_weight - self.obs_weight = obs_weight - - # Initialize background error covariance B - if background_error_covariance is None: - # Default to identity matrix (diagonal with unit variance) - self.B_inv = None # Will be computed as identity when needed - else: - if isinstance(background_error_covariance, torch.Tensor): - self.B_inv = torch.inverse(background_error_covariance) - else: - self.B_inv = torch.inverse(torch.tensor(background_error_covariance)) - - # Initialize observation error covariance R - if observation_error_covariance is None: - # Default to identity matrix (diagonal with unit variance) - self.R_inv = None # Will be computed as identity when needed - else: - if isinstance(observation_error_covariance, torch.Tensor): - self.R_inv = torch.inverse(observation_error_covariance) - else: - self.R_inv = torch.inverse(torch.tensor(observation_error_covariance)) - - # Initialize observation operator H - if observation_operator is None: - # Default to identity (direct observation of state variables) - self.H = None # Will be treated as identity when needed - else: - if isinstance(observation_operator, torch.Tensor): - self.H = observation_operator - else: - self.H = torch.tensor(observation_operator) - - def forward(self, analysis, background, observations): - """ - Compute the 3D-Var loss - - Args: - analysis: Model output (analysis state x) - background: Background state (x_b) - observations: Observations (y) - - Returns: - Total loss value - """ - batch_size = analysis.size(0) - - # Background term: (x - x_b)^T B^{-1} (x - x_b) - bg_diff = analysis - background - - if self.B_inv is None: - # Use identity matrix for B^{-1} - bg_term = torch.sum(bg_diff * bg_diff, dim=-1) # Element-wise square and sum - else: - # Compute quadratic form (x - x_b)^T B^{-1} (x - x_b) - bg_term = torch.sum(bg_diff * torch.matmul(bg_diff.unsqueeze(-2), self.B_inv).squeeze(-2), dim=-1) - - bg_term = self.bg_weight * torch.mean(bg_term) - - # Observation term: (y - Hx)^T R^{-1} (y - Hx) - if self.H is None: - # H is identity, so Hx = x - hx = analysis - else: - # Apply observation operator: Hx - if len(analysis.shape) == 2: - # 2D case: [batch, features] - hx = torch.matmul(analysis, self.H.T) - else: - # For multi-dimensional case, we might need to reshape - original_shape = analysis.shape - analysis_flat = analysis.view(batch_size, -1) - hx_flat = torch.matmul(analysis_flat, self.H.T) - hx = hx_flat.view(original_shape) - - obs_diff = observations - hx - - if self.R_inv is None: - # Use identity matrix for R^{-1} - obs_term = torch.sum(obs_diff * obs_diff, dim=-1) # Element-wise square and sum - else: - # Compute quadratic form (y - Hx)^T R^{-1} (y - Hx) - obs_term = torch.sum(obs_diff * torch.matmul(obs_diff.unsqueeze(-2), self.R_inv).squeeze(-2), dim=-1) - - obs_term = self.obs_weight * torch.mean(obs_term) - - # Total 3D-Var cost - total_loss = bg_term + obs_term - - return total_loss - - -class DataAssimilationModel(nn.Module): - """ - Neural network model for self-supervised data assimilation. - - Takes background state and observations as input and produces an analysis state - that minimizes the 3D-Var cost function. - """ - - def __init__(self, - input_dim, - hidden_dim=256, - num_layers=3, - dropout=0.1, - activation='relu'): - """ - Initialize the data assimilation model - - Args: - input_dim: Dimension of the input state - hidden_dim: Hidden layer dimension - num_layers: Number of hidden layers - dropout: Dropout rate - activation: Activation function ('relu', 'tanh', 'gelu') - """ - super(DataAssimilationModel, self).__init__() - - self.input_dim = input_dim - self.hidden_dim = hidden_dim - self.num_layers = num_layers - - # Define activation function - if activation == 'relu': - self.activation = nn.ReLU() - elif activation == 'tanh': - self.activation = nn.Tanh() - elif activation == 'gelu': - self.activation = nn.GELU() - else: - raise ValueError(f"Unsupported activation: {activation}") - - # Encoder to combine background and observations - layers = [] - layers.append(nn.Linear(input_dim * 2, hidden_dim)) # bg + obs - layers.append(self.activation) - layers.append(nn.Dropout(dropout)) - - for _ in range(num_layers - 1): - layers.append(nn.Linear(hidden_dim, hidden_dim)) - layers.append(self.activation) - layers.append(nn.Dropout(dropout)) - - # Output layer to produce analysis - layers.append(nn.Linear(hidden_dim, input_dim)) - - self.network = nn.Sequential(*layers) - - def forward(self, background, observations): - """ - Forward pass of the data assimilation model - - Args: - background: Background state (x_b) - observations: Observations (y) - - Returns: - analysis: Analysis state (x) - """ - # Concatenate background and observations along the feature dimension - combined_input = torch.cat([background, observations], dim=-1) - - # Pass through the network to get analysis - analysis = self.network(combined_input) - - return analysis - - -class SimpleDataAssimilationModel(nn.Module): - """ - Simplified version that works with 1D/2D spatial grids - """ - - def __init__(self, - grid_size, - num_channels=1, - hidden_dim=64, - num_layers=2): - """ - Initialize a simple data assimilation model for grid data - - Args: - grid_size: Size of the spatial grid (height, width) or (size,) - num_channels: Number of channels/variables - hidden_dim: Hidden dimension for processing - num_layers: Number of processing layers - """ - super(SimpleDataAssimilationModel, self).__init__() - - if isinstance(grid_size, (tuple, list)): - self.grid_shape = grid_size - self.grid_size = np.prod(grid_size) - else: - self.grid_shape = (grid_size,) - self.grid_size = grid_size - - self.num_channels = num_channels - self.input_features = self.grid_size * num_channels - - # Simple CNN-based architecture for spatial data - layers = [] - layers.append(nn.Conv1d(2 * num_channels, hidden_dim, kernel_size=3, padding=1)) - layers.append(nn.ReLU()) - - for _ in range(num_layers - 1): - layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)) - layers.append(nn.ReLU()) - - layers.append(nn.Conv1d(hidden_dim, num_channels, kernel_size=3, padding=1)) - - self.conv_layers = nn.Sequential(*layers) - - def forward(self, background, observations): - """ - Forward pass for grid data - - Args: - background: Background state [batch, channels, ...spatial_dims] - observations: Observations [batch, channels, ...spatial_dims] - - Returns: - analysis: Analysis state [batch, channels, ...spatial_dims] - """ - batch_size = background.size(0) - - # Reshape for 1D convolution if needed - if len(background.shape) > 3: # [batch, channels, height, width] - bg_flat = background.view(batch_size, self.num_channels, -1) - obs_flat = observations.view(batch_size, self.num_channels, -1) - else: # [batch, channels, length] - bg_flat = background - obs_flat = observations - - # Concatenate along channel dimension - combined = torch.cat([bg_flat, obs_flat], dim=1) # [batch, 2*channels, spatial] - - # Process through convolutional layers - analysis_flat = self.conv_layers(combined) - - # Reshape back to original spatial dimensions - if len(background.shape) > 3: - analysis = analysis_flat.view(batch_size, self.num_channels, *self.grid_shape) - else: - analysis = analysis_flat - - return analysis - - -def create_observation_operator(grid_size, obs_fraction=0.5, obs_locations=None): - """ - Create a simple observation operator H that selects a subset of grid points - - Args: - grid_size: Size of the grid (can be int for 1D or tuple for 2D) - obs_fraction: Fraction of grid points that have observations - obs_locations: Specific locations of observations (optional) - - Returns: - H: Observation operator matrix - """ - total_size = np.prod(grid_size) if isinstance(grid_size, (tuple, list)) else grid_size - - if obs_locations is None: - # Randomly select observation locations - num_obs = int(total_size * obs_fraction) - obs_indices = np.random.choice(total_size, size=num_obs, replace=False) - else: - obs_indices = obs_locations - num_obs = len(obs_indices) - - # Create H matrix (num_obs x total_size) - H = torch.zeros(num_obs, total_size) - for i, idx in enumerate(obs_indices): - H[i, idx] = 1.0 - - return H - - -def generate_synthetic_data(batch_size=32, grid_size=(10, 10), num_channels=1): - """ - Generate synthetic background and observation data for testing - - Args: - batch_size: Number of samples in batch - grid_size: Size of spatial grid - num_channels: Number of variables/channels - - Returns: - background: Background state - observations: Observations - true_state: True state (for evaluation only, not used in training) - """ - total_size = np.prod(grid_size) if isinstance(grid_size, (tuple, list)) else grid_size - - # Generate a true state with some spatial correlation - true_state = torch.randn(batch_size, num_channels, *grid_size) * 2 - # Apply some smoothing to create spatial correlation - if len(grid_size) == 2: - # Apply a simple smoothing kernel - kernel = torch.ones(1, 1, 3, 3) / 9 - for c in range(num_channels): - smoothed = F.conv2d( - true_state[:, c:c+1], - kernel, - padding=1, - groups=1 - ) - true_state[:, c:c+1] = smoothed - - # Create background as true state with some error - background_error = torch.randn_like(true_state) * 0.5 - background = true_state + background_error - - # Create observations with observation error - observation_error = torch.randn_like(true_state) * 0.3 - observations = true_state + observation_error - - return background, observations, true_state \ No newline at end of file diff --git a/graph_weather/models/evaluation.py b/graph_weather/models/evaluation.py deleted file mode 100644 index 01e28f92..00000000 --- a/graph_weather/models/evaluation.py +++ /dev/null @@ -1,439 +0,0 @@ -""" -Evaluation metrics for self-supervised data assimilation -""" - -import torch -import numpy as np -from sklearn.metrics import mean_squared_error, mean_absolute_error -import scipy.stats as stats - - -def compute_rmse(predictions, targets): - """ - Compute Root Mean Square Error - - Args: - predictions: Predicted values - targets: Target values - - Returns: - rmse: Root mean square error - """ - return torch.sqrt(torch.mean((predictions - targets) ** 2)).item() - - -def compute_mae(predictions, targets): - """ - Compute Mean Absolute Error - - Args: - predictions: Predicted values - targets: Target values - - Returns: - mae: Mean absolute error - """ - return torch.mean(torch.abs(predictions - targets)).item() - - -def compute_bias(predictions, targets): - """ - Compute bias (mean error) - - Args: - predictions: Predicted values - targets: Target values - - Returns: - bias: Mean error - """ - return torch.mean(predictions - targets).item() - - -def compute_correlation(predictions, targets): - """ - Compute Pearson correlation coefficient - - Args: - predictions: Predicted values - targets: Target values - - Returns: - correlation: Pearson correlation coefficient - """ - pred_flat = predictions.view(-1) - target_flat = targets.view(-1) - - # Center the data - pred_centered = pred_flat - torch.mean(pred_flat) - target_centered = target_flat - torch.mean(target_flat) - - # Compute correlation - numerator = torch.sum(pred_centered * target_centered) - denominator = torch.sqrt(torch.sum(pred_centered ** 2) * torch.sum(target_centered ** 2)) - - if denominator == 0: - return 0.0 - - return (numerator / denominator).item() - - -def compute_spatial_metrics(predictions, targets): - """ - Compute spatial metrics for gridded data - - Args: - predictions: Predicted values [batch, channels, height, width] - targets: Target values [batch, channels, height, width] - - Returns: - metrics: Dictionary with spatial metrics - """ - batch_size, channels = predictions.shape[0], predictions.shape[1] - - rmse_spatial = [] - correlation_spatial = [] - - for b in range(batch_size): - for c in range(channels): - pred_channel = predictions[b, c].flatten() - target_channel = targets[b, c].flatten() - - rmse = torch.sqrt(torch.mean((pred_channel - target_channel) ** 2)).item() - rmse_spatial.append(rmse) - - # Compute correlation - pred_centered = pred_channel - torch.mean(pred_channel) - target_centered = target_channel - torch.mean(target_channel) - numerator = torch.sum(pred_centered * target_centered) - denominator = torch.sqrt(torch.sum(pred_centered ** 2) * torch.sum(target_centered ** 2)) - - if denominator != 0: - corr = (numerator / denominator).item() - else: - corr = 0.0 - correlation_spatial.append(corr) - - return { - 'avg_rmse_spatial': np.mean(rmse_spatial), - 'std_rmse_spatial': np.std(rmse_spatial), - 'avg_correlation_spatial': np.mean(correlation_spatial), - 'std_correlation_spatial': np.std(correlation_spatial) - } - - -def compute_information_gain(analysis, background, true_state): - """ - Compute information gain from data assimilation - Measures how much better the analysis is compared to background - - Args: - analysis: Analysis state from model - background: Background state (first guess) - true_state: True state (for evaluation) - - Returns: - info_gain: Information gain metric - """ - bg_error = torch.mean((background - true_state) ** 2) - analysis_error = torch.mean((analysis - true_state) ** 2) - - # Information gain as reduction in error variance - info_gain = (bg_error - analysis_error) / bg_error * 100 if bg_error > 0 else 0 - - return info_gain.item() - - -class DataAssimilationEvaluator: - """ - Comprehensive evaluator for data assimilation models - """ - - def __init__(self, model, device='cpu'): - self.model = model - self.device = device - - def evaluate_batch(self, batch): - """ - Evaluate a single batch - - Args: - batch: Dictionary with 'background', 'observations', 'true_state' - - Returns: - metrics: Dictionary with evaluation metrics for the batch - """ - self.model.eval() - - with torch.no_grad(): - background = batch['background'].to(self.device) - observations = batch['observations'].to(self.device) - true_state = batch['true_state'].to(self.device) - - # Get model analysis - analysis = self.model(background, observations) - - # Compute metrics - metrics = { - 'analysis_rmse': compute_rmse(analysis, true_state), - 'background_rmse': compute_rmse(background, true_state), - 'observations_rmse': compute_rmse(observations, true_state), - 'analysis_mae': compute_mae(analysis, true_state), - 'background_mae': compute_mae(background, true_state), - 'analysis_bias': compute_bias(analysis, true_state), - 'background_bias': compute_bias(background, true_state), - 'analysis_correlation': compute_correlation(analysis, true_state), - 'background_correlation': compute_correlation(background, true_state), - 'information_gain': compute_information_gain(analysis, background, true_state) - } - - # Add spatial metrics if data is gridded - if len(analysis.shape) > 2: # Has spatial dimensions - spatial_metrics = compute_spatial_metrics(analysis, true_state) - metrics.update(spatial_metrics) - - return metrics - - def evaluate_dataset(self, data_loader): - """ - Evaluate the model on an entire dataset - - Args: - data_loader: DataLoader with test data - - Returns: - overall_metrics: Dictionary with overall evaluation metrics - """ - all_metrics = { - 'analysis_rmse': [], - 'background_rmse': [], - 'observations_rmse': [], - 'analysis_mae': [], - 'background_mae': [], - 'analysis_bias': [], - 'background_bias': [], - 'analysis_correlation': [], - 'background_correlation': [], - 'information_gain': [] - } - - spatial_metrics_list = [] - - for batch in data_loader: - batch_metrics = self.evaluate_batch(batch) - - # Collect metrics - for key in all_metrics.keys(): - if key in batch_metrics: - all_metrics[key].append(batch_metrics[key]) - - # Collect spatial metrics if available - if 'avg_rmse_spatial' in batch_metrics: - spatial_metrics_list.append({ - 'avg_rmse_spatial': batch_metrics['avg_rmse_spatial'], - 'avg_correlation_spatial': batch_metrics['avg_correlation_spatial'] - }) - - # Compute overall metrics - overall_metrics = {} - for key, values in all_metrics.items(): - if values: # Only compute if we have values - overall_metrics[f'avg_{key}'] = np.mean(values) - overall_metrics[f'std_{key}'] = np.std(values) - - # Compute spatial metrics - if spatial_metrics_list: - spatial_rmse_values = [m['avg_rmse_spatial'] for m in spatial_metrics_list] - spatial_corr_values = [m['avg_correlation_spatial'] for m in spatial_metrics_list] - - overall_metrics['avg_spatial_rmse'] = np.mean(spatial_rmse_values) - overall_metrics['std_spatial_rmse'] = np.std(spatial_rmse_values) - overall_metrics['avg_spatial_correlation'] = np.mean(spatial_corr_values) - overall_metrics['std_spatial_correlation'] = np.std(spatial_corr_values) - - return overall_metrics - - -def compare_methods(model_analysis, background, observations, true_state): - """ - Compare different methods: model analysis, background, observations - - Args: - model_analysis: Analysis from the trained model - background: Background state - observations: Observations - true_state: True state for comparison - - Returns: - comparison: Dictionary with comparison results - """ - results = {} - - # Compute metrics for each method - methods = { - 'analysis': model_analysis, - 'background': background, - 'observations': observations - } - - for method_name, method_output in methods.items(): - results[f'{method_name}_rmse'] = compute_rmse(method_output, true_state) - results[f'{method_name}_mae'] = compute_mae(method_output, true_state) - results[f'{method_name}_bias'] = compute_bias(method_output, true_state) - results[f'{method_name}_correlation'] = compute_correlation(method_output, true_state) - - # Compute improvements - bg_rmse = results['background_rmse'] - obs_rmse = results['observations_rmse'] - analysis_rmse = results['analysis_rmse'] - - results['analysis_improvement_over_bg_pct'] = ( - (bg_rmse - analysis_rmse) / bg_rmse * 100 - ) if bg_rmse > 0 else 0 - - results['analysis_improvement_over_obs_pct'] = ( - (obs_rmse - analysis_rmse) / obs_rmse * 100 - ) if obs_rmse > 0 else 0 - - results['bg_improvement_over_obs_pct'] = ( - (obs_rmse - bg_rmse) / obs_rmse * 100 - ) if obs_rmse > 0 else 0 - - return results - - -def classical_3dvar_analysis(background, observations, H, B, R): - """ - Classical 3D-Var analysis for comparison - - Args: - background: Background state - observations: Observations - H: Observation operator - B: Background error covariance - R: Observation error covariance - - Returns: - analysis: Classical 3D-Var analysis - """ - # Reshape for matrix operations - batch_size = background.shape[0] - state_size = background[0].numel() - obs_size = observations[0].numel() - - # Convert to appropriate shapes - xb = background.view(batch_size, -1) # [batch, state_size] - y = observations.view(batch_size, -1) # [batch, obs_size] - - analysis_results = [] - - for i in range(batch_size): - xb_i = xb[i:i+1].T # [state_size, 1] - y_i = y[i:i+1].T # [obs_size, 1] - - # Compute Kalman gain: K = B * H^T * (H * B * H^T + R)^(-1) - # For simplicity, using diagonal approximations - if B is None: - B_i = torch.eye(state_size, device=xb.device) - else: - B_i = B - - if R is None: - R_i = torch.eye(obs_size, device=y.device) - else: - R_i = R - - if H is None: - H_i = torch.eye(min(state_size, obs_size), device=xb.device)[:obs_size, :state_size] - else: - H_i = H - - # Calculate terms - HBHT_R = torch.matmul(torch.matmul(H_i, B_i), H_i.T) + R_i - K = torch.matmul(torch.matmul(B_i, H_i.T), torch.inverse(HBHT_R)) - - # Compute analysis: xa = xb + K * (y - H * xb) - innovation = y_i - torch.matmul(H_i, xb_i) - correction = torch.matmul(K, innovation) - xa_i = xb_i + correction - - analysis_results.append(xa_i.T) - - analysis = torch.cat(analysis_results, dim=0) - return analysis.view_as(background) - - -def compute_cross_validation_score(model, data_loader, k_folds=5): - """ - Compute cross-validation score for the model - - Args: - model: Data assimilation model - data_loader: Data loader - k_folds: Number of folds for cross-validation - - Returns: - cv_scores: List of scores for each fold - """ - # For simplicity, using a basic approach to simulate cross-validation - # In practice, you'd split your dataset into k folds - model.eval() - - all_rmse = [] - batch_count = 0 - - with torch.no_grad(): - for batch in data_loader: - background = batch['background'].to(model.device if hasattr(model, 'device') else 'cpu') - observations = batch['observations'].to(model.device if hasattr(model, 'device') else 'cpu') - - if 'true_state' in batch: - true_state = batch['true_state'].to(model.device if hasattr(model, 'device') else 'cpu') - - analysis = model(background, observations) - rmse = compute_rmse(analysis, true_state) - all_rmse.append(rmse) - batch_count += 1 - - # Limit for efficiency - if batch_count >= k_folds: - break - - return all_rmse - - -def compute_gradient_norm(model): - """ - Compute the norm of gradients for the model - - Args: - model: PyTorch model - - Returns: - total_norm: Total gradient norm - """ - total_norm = 0 - for p in model.parameters(): - if p.grad is not None: - param_norm = p.grad.data.norm(2) - total_norm += param_norm.item() ** 2 - total_norm = total_norm ** (1. / 2) - return total_norm - - -def compute_parameter_norm(model): - """ - Compute the norm of parameters for the model - - Args: - model: PyTorch model - - Returns: - total_norm: Total parameter norm - """ - total_norm = 0 - for p in model.parameters(): - param_norm = p.data.norm(2) - total_norm += param_norm.item() ** 2 - total_norm = total_norm ** (1. / 2) - return total_norm \ No newline at end of file diff --git a/graph_weather/models/training_loop.py b/graph_weather/models/training_loop.py deleted file mode 100644 index d5a6f4c6..00000000 --- a/graph_weather/models/training_loop.py +++ /dev/null @@ -1,538 +0,0 @@ -""" -Training loop for self-supervised data assimilation with 3D-Var loss -""" - -import torch -import torch.nn as nn -from torch.optim import Adam, SGD -from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau -import numpy as np -from tqdm import tqdm -import matplotlib.pyplot as plt -from .data_assimilation import DataAssimilationModel, ThreeDVarLoss, SimpleDataAssimilationModel - - -class DataAssimilationTrainer: - """ - Trainer for the self-supervised data assimilation model - """ - - def __init__( - self, - model, - loss_fn, - optimizer=None, - lr=1e-3, - device='cpu', - scheduler=None - ): - """ - Initialize the trainer - - Args: - model: Data assimilation model - loss_fn: 3D-Var loss function - optimizer: Optimizer (default: Adam) - lr: Learning rate - device: Device to train on - scheduler: Learning rate scheduler - """ - self.model = model.to(device) - self.loss_fn = loss_fn.to(device) - self.device = device - - if optimizer is None: - self.optimizer = Adam(model.parameters(), lr=lr) - else: - self.optimizer = optimizer - - self.scheduler = scheduler - self.train_losses = [] - self.val_losses = [] - - def train_step(self, background, observations): - """ - Perform a single training step - - Args: - background: Background state - observations: Observations - - Returns: - loss: Training loss value - """ - self.model.train() - self.optimizer.zero_grad() - - # Move data to device - background = background.to(self.device) - observations = observations.to(self.device) - - # Forward pass - analysis = self.model(background, observations) - - # Compute loss - loss = self.loss_fn(analysis, background, observations) - - # Backward pass - loss.backward() - - # Gradient clipping - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - - # Update parameters - self.optimizer.step() - - return loss.item() - - def validation_step(self, background, observations): - """ - Perform a validation step - - Args: - background: Background state - observations: Observations - - Returns: - loss: Validation loss value - """ - self.model.eval() - - with torch.no_grad(): - # Move data to device - background = background.to(self.device) - observations = observations.to(self.device) - - # Forward pass - analysis = self.model(background, observations) - - # Compute loss - loss = self.loss_fn(analysis, background, observations) - - return loss.item() - - def train_epoch(self, train_loader): - """ - Train for one epoch - - Args: - train_loader: Training data loader - - Returns: - avg_loss: Average training loss for the epoch - """ - total_loss = 0 - num_batches = 0 - - for batch in tqdm(train_loader, desc="Training", leave=False): - background = batch['background'] - observations = batch['observations'] - - loss = self.train_step(background, observations) - total_loss += loss - num_batches += 1 - - avg_loss = total_loss / num_batches - return avg_loss - - def validate_epoch(self, val_loader): - """ - Validate for one epoch - - Args: - val_loader: Validation data loader - - Returns: - avg_loss: Average validation loss for the epoch - """ - total_loss = 0 - num_batches = 0 - - for batch in val_loader: - background = batch['background'] - observations = batch['observations'] - - loss = self.validation_step(background, observations) - total_loss += loss - num_batches += 1 - - avg_loss = total_loss / num_batches - return avg_loss - - def fit( - self, - train_loader, - val_loader, - epochs=100, - verbose=True, - save_best_model=True, - model_save_path="best_assimilation_model.pth" - ): - """ - Train the model - - Args: - train_loader: Training data loader - val_loader: Validation data loader - epochs: Number of training epochs - verbose: Whether to print progress - save_best_model: Whether to save the best model - model_save_path: Path to save the best model - - Returns: - train_losses: Training losses for each epoch - val_losses: Validation losses for each epoch - """ - best_val_loss = float('inf') - patience_counter = 0 - - for epoch in range(epochs): - # Training - train_loss = self.train_epoch(train_loader) - self.train_losses.append(train_loss) - - # Validation - val_loss = self.validate_epoch(val_loader) - self.val_losses.append(val_loss) - - # Learning rate scheduling - if self.scheduler is not None: - if isinstance(self.scheduler, ReduceLROnPlateau): - self.scheduler.step(val_loss) - else: - self.scheduler.step() - - if save_best_model and val_loss < best_val_loss: - best_val_loss = val_loss - torch.save(self.model.state_dict(), model_save_path) - patience_counter = 0 - else: - patience_counter += 1 - - if verbose and (epoch + 1) % 10 == 0: - print(f"Epoch [{epoch+1}/{epochs}] - " - f"Train Loss: {train_loss:.6f}, " - f"Val Loss: {val_loss:.6f}") - - if save_best_model: - self.model.load_state_dict(torch.load(model_save_path)) - - return self.train_losses, self.val_losses - - def evaluate_model(self, test_loader, compute_metrics=True): - """ - Evaluate the model on test data - - Args: - test_loader: Test data loader - compute_metrics: Whether to compute additional metrics - - Returns: - results: Dictionary with evaluation results - """ - self.model.eval() - total_loss = 0 - num_batches = 0 - - all_analysis = [] - all_background = [] - all_observations = [] - all_true = [] - - with torch.no_grad(): - for batch in test_loader: - background = batch['background'].to(self.device) - observations = batch['observations'].to(self.device) - - analysis = self.model(background, observations) - loss = self.loss_fn(analysis, background, observations) - - total_loss += loss.item() - num_batches += 1 - - if compute_metrics: - all_analysis.append(analysis.cpu()) - all_background.append(background.cpu()) - all_observations.append(observations.cpu()) - - if 'true_state' in batch: - all_true.append(batch['true_state']) - - avg_loss = total_loss / num_batches - results = {'loss': avg_loss} - - if compute_metrics and all_true: - # Compute additional metrics - all_analysis = torch.cat(all_analysis, dim=0) - all_background = torch.cat(all_background, dim=0) - all_observations = torch.cat(all_observations, dim=0) - all_true = torch.cat(all_true, dim=0) - - # RMSE metrics - analysis_rmse = torch.sqrt(torch.mean((all_analysis - all_true) ** 2)).item() - bg_rmse = torch.sqrt(torch.mean((all_background - all_true) ** 2)).item() - obs_rmse = torch.sqrt(torch.mean((all_observations - all_true) ** 2)).item() - - results.update({ - 'analysis_rmse': analysis_rmse, - 'background_rmse': bg_rmse, - 'observations_rmse': obs_rmse, - 'improvement_over_bg': (bg_rmse - analysis_rmse) / bg_rmse * 100 if bg_rmse > 0 else 0, - 'improvement_over_obs': (obs_rmse - analysis_rmse) / obs_rmse * 100 if obs_rmse > 0 else 0 - }) - - return results - - -def train_data_assimilation_model( - model, - train_loader, - val_loader, - bg_error_covariance=None, - obs_error_covariance=None, - obs_operator=None, - epochs=100, - lr=1e-3, - device='cpu' -): - """ - Convenience function to train a data assimilation model - - Args: - model: Data assimilation model - train_loader: Training data loader - val_loader: Validation data loader - bg_error_covariance: Background error covariance - obs_error_covariance: Observation error covariance - obs_operator: Observation operator - epochs: Number of training epochs - lr: Learning rate - device: Device to train on - - Returns: - trainer: Trained trainer object - results: Training results - """ - # Initialize the 3D-Var loss function - loss_fn = ThreeDVarLoss( - background_error_covariance=bg_error_covariance, - observation_error_covariance=obs_error_covariance, - observation_operator=obs_operator - ) - - # Initialize the trainer - trainer = DataAssimilationTrainer( - model=model, - loss_fn=loss_fn, - lr=lr, - device=device - ) - - # Add learning rate scheduler - scheduler = ReduceLROnPlateau(trainer.optimizer, mode='min', factor=0.5, patience=10) - trainer.scheduler = scheduler - - # Train the model - train_losses, val_losses = trainer.fit( - train_loader=train_loader, - val_loader=val_loader, - epochs=epochs, - verbose=True - ) - - return trainer, {'train_losses': train_losses, 'val_losses': val_losses} - - -def train_with_different_modes( - model_class, - data_module, - input_dim=None, - grid_size=None, - num_channels=1, - epochs=100, - lr=1e-3, - device='cpu' -): - """ - Train the model in different modes: - 1. With good first guess (low background error) - 2. With poor first guess (high background error) - cold start - 3. With varying observation densities - - Args: - model_class: Model class to instantiate - data_module: Data module with different configurations - input_dim: Input dimension for fully connected model - grid_size: Grid size for convolutional model - num_channels: Number of channels - epochs: Number of epochs - lr: Learning rate - device: Device to train on - - Returns: - results: Dictionary with results from different training modes - """ - results = {} - - # Mode 1: With good first guess (low background error) - print("Training with good first guess...") - data_module_good_bg = data_module( - num_samples=1000, - grid_size=grid_size, - num_channels=num_channels, - bg_error_std=0.2, # Low error - obs_error_std=0.3, - obs_fraction=0.5 - ) - data_module_good_bg.setup() - - if input_dim: - model_good_bg = model_class(input_dim=input_dim) - else: - model_good_bg = model_class(grid_size=grid_size, num_channels=num_channels) - - trainer_good_bg, res_good_bg = train_data_assimilation_model( - model=model_good_bg, - train_loader=data_module_good_bg.train_dataloader(), - val_loader=data_module_good_bg.val_dataloader(), - epochs=epochs, - lr=lr, - device=device - ) - - results['good_bg'] = { - 'trainer': trainer_good_bg, - 'results': res_good_bg, - 'eval_results': trainer_good_bg.evaluate_model(data_module_good_bg.test_dataloader()) - } - - # Mode 2: With poor first guess (high background error) - cold start - print("Training with poor first guess (cold start)...") - data_module_poor_bg = data_module( - num_samples=1000, - grid_size=grid_size, - num_channels=num_channels, - bg_error_std=1.0, # High error - obs_error_std=0.3, - obs_fraction=0.5 - ) - data_module_poor_bg.setup() - - if input_dim: - model_poor_bg = model_class(input_dim=input_dim) - else: - model_poor_bg = model_class(grid_size=grid_size, num_channels=num_channels) - - trainer_poor_bg, res_poor_bg = train_data_assimilation_model( - model=model_poor_bg, - train_loader=data_module_poor_bg.train_dataloader(), - val_loader=data_module_poor_bg.val_dataloader(), - epochs=epochs, - lr=lr, - device=device - ) - - results['poor_bg'] = { - 'trainer': trainer_poor_bg, - 'results': res_poor_bg, - 'eval_results': trainer_poor_bg.evaluate_model(data_module_poor_bg.test_dataloader()) - } - - # Mode 3: With sparse observations - print("Training with sparse observations...") - data_module_sparse_obs = data_module( - num_samples=1000, - grid_size=grid_size, - num_channels=num_channels, - bg_error_std=0.5, - obs_error_std=0.3, - obs_fraction=0.2 # Sparse observations - ) - data_module_sparse_obs.setup() - - if input_dim: - model_sparse_obs = model_class(input_dim=input_dim) - else: - model_sparse_obs = model_class(grid_size=grid_size, num_channels=num_channels) - - trainer_sparse_obs, res_sparse_obs = train_data_assimilation_model( - model=model_sparse_obs, - train_loader=data_module_sparse_obs.train_dataloader(), - val_loader=data_module_sparse_obs.val_dataloader(), - epochs=epochs, - lr=lr, - device=device - ) - - results['sparse_obs'] = { - 'trainer': trainer_sparse_obs, - 'results': res_sparse_obs, - 'eval_results': trainer_sparse_obs.evaluate_model(data_module_sparse_obs.test_dataloader()) - } - - return results - - -def compare_with_baselines(model, test_loader, device='cpu'): - """ - Compare the trained model with classical baselines - - Args: - model: Trained assimilation model - test_loader: Test data loader - device: Device to run on - - Returns: - comparison: Dictionary with comparison results - """ - model.eval() - results = { - 'analysis_rmse': [], - 'background_rmse': [], - 'observation_rmse': [], - 'persistence_rmse': [] - } - - with torch.no_grad(): - for batch in test_loader: - background = batch['background'].to(device) - observations = batch['observations'].to(device) - - if 'true_state' in batch: - true_state = batch['true_state'].to(device) - - # Model analysis - analysis = model(background, observations) - - # Compute RMSE for each method - analysis_rmse = torch.sqrt(torch.mean((analysis - true_state) ** 2)).item() - bg_rmse = torch.sqrt(torch.mean((background - true_state) ** 2)).item() - obs_rmse = torch.sqrt(torch.mean((observations - true_state) ** 2)).item() - - # Persistence (assuming observations are closer to truth than background) - # For simplicity, using a weighted average as persistence - persistence = 0.7 * observations + 0.3 * background - persist_rmse = torch.sqrt(torch.mean((persistence - true_state) ** 2)).item() - - results['analysis_rmse'].append(analysis_rmse) - results['background_rmse'].append(bg_rmse) - results['observation_rmse'].append(obs_rmse) - results['persistence_rmse'].append(persist_rmse) - - # Compute averages - comparison = { - 'avg_analysis_rmse': np.mean(results['analysis_rmse']), - 'avg_background_rmse': np.mean(results['background_rmse']), - 'avg_observation_rmse': np.mean(results['observation_rmse']), - 'avg_persistence_rmse': np.mean(results['persistence_rmse']), - 'analysis_improvement_over_bg': ( - (np.mean(results['background_rmse']) - np.mean(results['analysis_rmse'])) / - np.mean(results['background_rmse']) * 100 - ), - 'analysis_improvement_over_obs': ( - (np.mean(results['observation_rmse']) - np.mean(results['analysis_rmse'])) / - np.mean(results['observation_rmse']) * 100 - ) - } - - return comparison \ No newline at end of file diff --git a/graph_weather/models/visualization.py b/graph_weather/models/visualization.py deleted file mode 100644 index 6791e9b2..00000000 --- a/graph_weather/models/visualization.py +++ /dev/null @@ -1,582 +0,0 @@ -""" -Visualization functions for self-supervised data assimilation -""" - -import torch -import numpy as np -import matplotlib.pyplot as plt -import seaborn as sns -from matplotlib.colors import Normalize -import matplotlib.patches as patches -import warnings -warnings.filterwarnings('ignore') - - -def plot_training_curves(train_losses, val_losses, title="Training Curves"): - """ - Plot training and validation loss curves - - Args: - train_losses: List of training losses - val_losses: List of validation losses - title: Title for the plot - """ - plt.figure(figsize=(10, 6)) - plt.plot(train_losses, label='Training Loss', color='blue') - plt.plot(val_losses, label='Validation Loss', color='red') - plt.xlabel('Epoch') - plt.ylabel('Loss') - plt.title(title) - plt.legend() - plt.grid(True, alpha=0.3) - plt.tight_layout() - plt.show() - - -def plot_comparison_grid(background, observations, analysis, true_state=None, - titles=None, figsize=(15, 10)): - """ - Plot a grid comparing background, observations, analysis, and true state - - Args: - background: Background state - observations: Observations - analysis: Analysis from model - true_state: True state (optional) - titles: Titles for each subplot - figsize: Figure size - """ - # Convert to numpy if torch tensors - if torch.is_tensor(background): - background = background.cpu().numpy() - if torch.is_tensor(observations): - observations = observations.cpu().numpy() - if torch.is_tensor(analysis): - analysis = analysis.cpu().numpy() - if true_state is not None and torch.is_tensor(true_state): - true_state = true_state.cpu().numpy() - - if titles is None: - titles = ['Background', 'Observations', 'Analysis'] - if true_state is not None: - titles.append('True State') - - n_plots = 3 if true_state is None else 4 - - fig, axes = plt.subplots(1, n_plots, figsize=figsize) - if n_plots == 1: - axes = [axes] - - # Determine common color scale - all_data = [background, observations, analysis] - if true_state is not None: - all_data.append(true_state) - - # Handle different tensor shapes - def get_data_for_plot(data): - if data.ndim == 4: # [batch, channels, height, width] - return data[0, 0] # Take first sample, first channel - elif data.ndim == 3: # [batch, height, width] or [channels, height, width] - return data[0] if data.shape[0] <= 10 else data[0] # Heuristic for batch vs channels - elif data.ndim == 2: # [height, width] - return data - else: - raise ValueError(f"Unexpected data shape: {data.shape}") - - processed_data = [get_data_for_plot(d) for d in all_data] - vmin = min([d.min() for d in processed_data]) - vmax = max([d.max() for d in processed_data]) - - # Plot each field - im1 = axes[0].imshow(processed_data[0], cmap='RdBu_r', vmin=vmin, vmax=vmax) - axes[0].set_title(titles[0]) - axes[0].axis('off') - plt.colorbar(im1, ax=axes[0]) - - im2 = axes[1].imshow(processed_data[1], cmap='RdBu_r', vmin=vmin, vmax=vmax) - axes[1].set_title(titles[1]) - axes[1].axis('off') - plt.colorbar(im2, ax=axes[1]) - - im3 = axes[2].imshow(processed_data[2], cmap='RdBu_r', vmin=vmin, vmax=vmax) - axes[2].set_title(titles[2]) - axes[2].axis('off') - plt.colorbar(im3, ax=axes[2]) - - if true_state is not None and n_plots > 3: - im4 = axes[3].imshow(processed_data[3], cmap='RdBu_r', vmin=vmin, vmax=vmax) - axes[3].set_title(titles[3]) - axes[3].axis('off') - plt.colorbar(im4, ax=axes[3]) - - plt.tight_layout() - plt.show() - - -def plot_error_maps(background, observations, analysis, true_state, - titles=None, figsize=(18, 5)): - """ - Plot error maps comparing different methods - - Args: - background: Background state - observations: Observations - analysis: Analysis from model - true_state: True state - titles: Titles for each subplot - figsize: Figure size - """ - # Convert to numpy if torch tensors - if torch.is_tensor(background): - background = background.cpu().numpy() - if torch.is_tensor(observations): - observations = observations.cpu().numpy() - if torch.is_tensor(analysis): - analysis = analysis.cpu().numpy() - if torch.is_tensor(true_state): - true_state = true_state.cpu().numpy() - - if titles is None: - titles = ['Background Error', 'Observation Error', 'Analysis Error'] - - fig, axes = plt.subplots(1, 3, figsize=figsize) - - # Calculate errors - def get_first_element(data): - if data.ndim == 4: # [batch, channels, height, width] - return data[0, 0] # Take first sample, first channel - elif data.ndim == 3: # [batch, height, width] - return data[0] - else: - return data - - bg_error = get_first_element(background) - get_first_element(true_state) - obs_error = get_first_element(observations) - get_first_element(true_state) - analysis_error = get_first_element(analysis) - get_first_element(true_state) - - # Determine common color scale for errors (centered at 0) - max_error = max(np.abs(bg_error).max(), - np.abs(obs_error).max(), - np.abs(analysis_error).max()) - - # Plot error maps - im1 = axes[0].imshow(bg_error if bg_error.ndim == 2 else bg_error[0], - cmap='RdBu_r', vmin=-max_error, vmax=max_error) - axes[0].set_title(titles[0]) - axes[0].axis('off') - plt.colorbar(im1, ax=axes[0]) - - im2 = axes[1].imshow(obs_error if obs_error.ndim == 2 else obs_error[0], - cmap='RdBu_r', vmin=-max_error, vmax=max_error) - axes[1].set_title(titles[1]) - axes[1].axis('off') - plt.colorbar(im2, ax=axes[1]) - - im3 = axes[2].imshow(analysis_error if analysis_error.ndim == 2 else analysis_error[0], - cmap='RdBu_r', vmin=-max_error, vmax=max_error) - axes[2].set_title(titles[2]) - axes[2].axis('off') - plt.colorbar(im3, ax=axes[2]) - - plt.tight_layout() - plt.show() - - -def plot_rmse_comparison(metrics_dict, title="RMSE Comparison"): - """ - Plot RMSE comparison between different methods - - Args: - metrics_dict: Dictionary with method names as keys and RMSE values as values - title: Title for the plot - """ - methods = list(metrics_dict.keys()) - rmse_values = list(metrics_dict.values()) - - plt.figure(figsize=(10, 6)) - bars = plt.bar(methods, rmse_values, color=['skyblue', 'lightcoral', 'lightgreen', 'gold']) - plt.ylabel('RMSE') - plt.title(title) - plt.xticks(rotation=45) - - # Add value labels on bars - for bar, value in zip(bars, rmse_values): - plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(rmse_values)*0.01, - f'{value:.3f}', ha='center', va='bottom') - - plt.tight_layout() - plt.show() - - -def plot_improvement_heatmap(improvement_matrix, title="Improvement Heatmap"): - """ - Plot improvement heatmap showing where analysis is better than background - - Args: - improvement_matrix: Matrix showing improvement at each grid point - title: Title for the plot - """ - plt.figure(figsize=(8, 6)) - sns.heatmap(improvement_matrix, annot=True, fmt='.2f', cmap='RdYlGn', center=0, - cbar_kws={'label': 'Improvement'}) - plt.title(title) - plt.tight_layout() - plt.show() - - -def plot_time_series_comparison(time_series_data, labels=None, title="Time Series Comparison"): - """ - Plot time series comparison of metrics - - Args: - time_series_data: List of time series to plot - labels: Labels for each series - title: Title for the plot - """ - plt.figure(figsize=(12, 6)) - - for i, series in enumerate(time_series_data): - label = labels[i] if labels else f'Series {i+1}' - plt.plot(series, label=label, linewidth=2) - - plt.xlabel('Time Steps') - plt.ylabel('Value') - plt.title(title) - plt.legend() - plt.grid(True, alpha=0.3) - plt.tight_layout() - plt.show() - - -def plot_histogram_comparison(true_state, background, analysis, bins=50, - title="Distribution Comparison"): - """ - Plot histogram comparison of distributions - - Args: - true_state: True state values - background: Background state values - analysis: Analysis state values - bins: Number of histogram bins - title: Title for the plot - """ - # Convert to numpy if torch tensors - if torch.is_tensor(true_state): - true_state = true_state.cpu().numpy() - if torch.is_tensor(background): - background = background.cpu().numpy() - if torch.is_tensor(analysis): - analysis = analysis.cpu().numpy() - - plt.figure(figsize=(10, 6)) - - true_flat = true_state.flatten() - bg_flat = background.flatten() - analysis_flat = analysis.flatten() - - plt.hist(true_flat, bins=bins, alpha=0.5, label='True State', density=True) - plt.hist(bg_flat, bins=bins, alpha=0.5, label='Background', density=True) - plt.hist(analysis_flat, bins=bins, alpha=0.5, label='Analysis', density=True) - - plt.xlabel('Value') - plt.ylabel('Density') - plt.title(title) - plt.legend() - plt.grid(True, alpha=0.3) - plt.tight_layout() - plt.show() - - -def plot_scatter_comparison(true_state, background, analysis, - title="Scatter Plot Comparison"): - """ - Plot scatter comparison showing correlation between true and predicted values - - Args: - true_state: True state values - background: Background state values - analysis: Analysis state values - title: Title for the plot - """ - # Convert to numpy if torch tensors - if torch.is_tensor(true_state): - true_state = true_state.cpu().numpy() - if torch.is_tensor(background): - background = background.cpu().numpy() - if torch.is_tensor(analysis): - analysis = analysis.cpu().numpy() - - fig, axes = plt.subplots(1, 2, figsize=(12, 5)) - - true_flat = true_state.flatten() - bg_flat = background.flatten() - analysis_flat = analysis.flatten() - - # Background vs True - axes[0].scatter(true_flat, bg_flat, alpha=0.5) - min_val = min(true_flat.min(), bg_flat.min()) - max_val = max(true_flat.max(), bg_flat.max()) - axes[0].plot([min_val, max_val], [min_val, max_val], 'r--', lw=2) - axes[0].set_xlabel('True State') - axes[0].set_ylabel('Background') - axes[0].set_title('Background vs True') - axes[0].grid(True, alpha=0.3) - - # Analysis vs True - axes[1].scatter(true_flat, analysis_flat, alpha=0.5) - axes[1].plot([min_val, max_val], [min_val, max_val], 'r--', lw=2) - axes[1].set_xlabel('True State') - axes[1].set_ylabel('Analysis') - axes[1].set_title('Analysis vs True') - axes[1].grid(True, alpha=0.3) - - plt.suptitle(title) - plt.tight_layout() - plt.show() - - -def plot_convergence_analysis(train_losses, val_losses, title="Convergence Analysis"): - """ - Plot detailed convergence analysis - - Args: - train_losses: Training losses - val_losses: Validation losses - title: Title for the plot - """ - fig, axes = plt.subplots(2, 2, figsize=(15, 10)) - - epochs = range(1, len(train_losses) + 1) - - # Training and validation loss - axes[0, 0].plot(epochs, train_losses, label='Training Loss', color='blue') - axes[0, 0].plot(epochs, val_losses, label='Validation Loss', color='red') - axes[0, 0].set_xlabel('Epoch') - axes[0, 0].set_ylabel('Loss') - axes[0, 0].set_title('Training and Validation Loss') - axes[0, 0].legend() - axes[0, 0].grid(True, alpha=0.3) - - # Log scale - axes[0, 1].semilogy(epochs, train_losses, label='Training Loss', color='blue') - axes[0, 1].semilogy(epochs, val_losses, label='Validation Loss', color='red') - axes[0, 1].set_xlabel('Epoch') - axes[0, 1].set_ylabel('Loss (log scale)') - axes[0, 1].set_title('Loss (Log Scale)') - axes[0, 1].legend() - axes[0, 1].grid(True, alpha=0.3) - - # Loss difference - loss_diff = np.array(train_losses) - np.array(val_losses) - axes[1, 0].plot(epochs, loss_diff, color='purple') - axes[1, 0].axhline(y=0, color='black', linestyle='--', alpha=0.5) - axes[1, 0].set_xlabel('Epoch') - axes[1, 0].set_ylabel('Training - Validation Loss') - axes[1, 0].set_title('Overfitting Indicator') - axes[1, 0].grid(True, alpha=0.3) - - # Improvement per epoch - improvement = np.diff(train_losses) - axes[1, 1].plot(epochs[1:], improvement, color='green') - axes[1, 1].set_xlabel('Epoch') - axes[1, 1].set_ylabel('Loss Improvement') - axes[1, 1].set_title('Improvement per Epoch') - axes[1, 1].grid(True, alpha=0.3) - - plt.suptitle(title) - plt.tight_layout() - plt.show() - - -def plot_parameter_analysis(model, title="Parameter Analysis"): - """ - Plot analysis of model parameters - - Args: - model: PyTorch model - title: Title for the plot - """ - param_norms = [] - param_names = [] - - for name, param in model.named_parameters(): - if param.requires_grad: - param_norm = param.data.norm().item() - param_norms.append(param_norm) - param_names.append(name) - - if not param_norms: # Handle case where no parameters require gradients - print("No parameters require gradients to visualize") - return - - plt.figure(figsize=(12, 6)) - bars = plt.bar(range(len(param_names)), param_norms) - plt.xlabel('Parameters') - plt.ylabel('L2 Norm') - plt.title(title) - plt.xticks(range(len(param_names)), [name.split('.')[-1] for name in param_names], - rotation=45, ha='right') - - # Add value labels on bars - for bar, value in zip(bars, param_norms): - plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(param_norms)*0.01, - f'{value:.3f}', ha='center', va='bottom', fontsize=8) - - plt.tight_layout() - plt.show() - - -def create_summary_dashboard(metrics, figsize=(16, 12)): - """ - Create a comprehensive dashboard summarizing all results - - Args: - metrics: Dictionary with all evaluation metrics - figsize: Figure size for the dashboard - """ - fig = plt.figure(figsize=figsize) - - # Define grid for subplots - gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3) - - # 1. Training curves (if available) - if 'train_losses' in metrics and 'val_losses' in metrics: - ax1 = fig.add_subplot(gs[0, 0]) - ax1.plot(metrics['train_losses'], label='Train', color='blue') - ax1.plot(metrics['val_losses'], label='Val', color='red') - ax1.set_title('Training Curves') - ax1.set_xlabel('Epoch') - ax1.set_ylabel('Loss') - ax1.legend() - ax1.grid(True, alpha=0.3) - - # 2. RMSE comparison - ax2 = fig.add_subplot(gs[0, 1]) - rmse_methods = [] - rmse_values = [] - for key, value in metrics.items(): - if 'rmse' in key.lower(): - rmse_methods.append(key.replace('_rmse', '').replace('avg_', '').title()) - rmse_values.append(value) - - if rmse_methods: - ax2.bar(rmse_methods, rmse_values, color=['skyblue', 'lightcoral', 'lightgreen']) - ax2.set_title('RMSE Comparison') - ax2.set_ylabel('RMSE') - ax2.tick_params(axis='x', rotation=45) - - # 3. Correlation comparison - ax3 = fig.add_subplot(gs[0, 2]) - corr_methods = [] - corr_values = [] - for key, value in metrics.items(): - if 'correlation' in key.lower(): - corr_methods.append(key.replace('_correlation', '').replace('avg_', '').title()) - corr_values.append(value) - - if corr_methods: - ax3.bar(corr_methods, corr_values, color=['gold', 'orange']) - ax3.set_title('Correlation Comparison') - ax3.set_ylabel('Correlation') - ax3.tick_params(axis='x', rotation=45) - - # 4. Bias comparison - ax4 = fig.add_subplot(gs[1, 0]) - bias_methods = [] - bias_values = [] - for key, value in metrics.items(): - if 'bias' in key.lower(): - bias_methods.append(key.replace('_bias', '').replace('avg_', '').title()) - bias_values.append(value) - - if bias_methods: - ax4.bar(bias_methods, bias_values, color=['lightblue', 'pink']) - ax4.set_title('Bias Comparison') - ax4.set_ylabel('Bias') - ax4.tick_params(axis='x', rotation=45) - ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5) - - # 5. Improvement metrics - ax5 = fig.add_subplot(gs[1, 1]) - improvement_metrics = [] - improvement_values = [] - for key, value in metrics.items(): - if 'improvement' in key.lower(): - improvement_metrics.append(key.replace('avg_', '').replace('_pct', '%').title()) - improvement_values.append(value) - - if improvement_metrics: - bars = ax5.bar(improvement_metrics, improvement_values, - color=['lightgreen' if v > 0 else 'lightcoral' for v in improvement_values]) - ax5.set_title('Improvement Metrics') - ax5.set_ylabel('Improvement (%)') - ax5.tick_params(axis='x', rotation=45) - ax5.axhline(y=0, color='black', linestyle='--', alpha=0.5) - - # Add value labels - for bar, value in zip(bars, improvement_values): - height = bar.get_height() - ax5.text(bar.get_x() + bar.get_width()/2, height + (max(improvement_values)*0.01 if max(improvement_values) > 0 else min(improvement_values)*0.01), - f'{value:.1f}%', ha='center', va='bottom' if value >= 0 else 'top') - - # 6. Information gain - if 'avg_information_gain' in metrics: - ax6 = fig.add_subplot(gs[1, 2]) - info_gain = metrics['avg_information_gain'] - ax6.bar(['Information Gain'], [info_gain], color='mediumpurple') - ax6.set_title('Information Gain') - ax6.set_ylabel('Gain (%)') - ax6.text(0, info_gain + max(info_gain*0.01, 0.1), f'{info_gain:.1f}%', - ha='center', va='bottom') - - # 7. Parameter norms (if model available) - if 'model' in metrics: - ax7 = fig.add_subplot(gs[2, :]) - param_norms = [] - param_names = [] - for name, param in metrics['model'].named_parameters(): - if param.requires_grad: - param_norm = param.data.norm().item() - param_norms.append(param_norm) - param_names.append(name.split('.')[-1][:10]) # Shorten names - - if param_norms: # Only plot if there are parameters to show - # Only show top parameters to avoid overcrowding - top_indices = np.argsort(param_norms)[-10:][::-1] # Top 10 largest - top_norms = [param_norms[i] for i in top_indices] - top_names = [param_names[i] for i in top_indices] - - ax7.bar(top_names, top_norms, color='lightsteelblue') - ax7.set_title('Top 10 Parameter Norms') - ax7.set_ylabel('L2 Norm') - ax7.tick_params(axis='x', rotation=45) - else: - ax7.text(0.5, 0.5, 'No parameters to display', horizontalalignment='center', - verticalalignment='center', transform=ax7.transAxes, fontsize=14) - ax7.set_title('Parameter Norms') - ax7.set_xticks([]) - ax7.set_yticks([]) - - plt.suptitle('Data Assimilation Results Dashboard', fontsize=16) - plt.show() - - -def visualize_observation_locations(observations, obs_mask, title="Observation Locations"): - """ - Visualize where observations are available - - Args: - observations: Observation tensor - obs_mask: Boolean mask indicating observation locations - title: Title for the plot - """ - plt.figure(figsize=(8, 6)) - - # Create a visualization where observed locations are highlighted - obs_visual = torch.zeros_like(observations[0, 0]) if len(observations.shape) > 2 else torch.zeros_like(observations[0]) - obs_visual[obs_mask] = 1 - - plt.imshow(obs_visual.cpu().numpy(), cmap='viridis', interpolation='none') - plt.title(title) - plt.colorbar(label='Observation Present (1) / Missing (0)') - plt.show() \ No newline at end of file diff --git a/graph_weather/test_data_assimilation.py b/graph_weather/test_data_assimilation.py deleted file mode 100644 index 13eaeee3..00000000 --- a/graph_weather/test_data_assimilation.py +++ /dev/null @@ -1,322 +0,0 @@ -""" -Test script for the complete self-supervised data assimilation pipeline -""" - -import torch -import numpy as np -from graph_weather.graph_weather.models.data_assimilation import ( - DataAssimilationModel, - SimpleDataAssimilationModel, - ThreeDVarLoss, - generate_synthetic_data -) -from graph_weather.graph_weather.data.assimilation_dataloader import ( - AssimilationDataModule, - create_synthetic_assimilation_dataset -) -from graph_weather.graph_weather.models.training_loop import ( - DataAssimilationTrainer, - train_data_assimilation_model, - compare_with_baselines -) -from graph_weather.graph_weather.models.evaluation import ( - DataAssimilationEvaluator, - compare_methods, - compute_rmse -) -from graph_weather.graph_weather.models.visualization import ( - plot_training_curves, - plot_comparison_grid, - plot_error_maps, - plot_rmse_comparison, - create_summary_dashboard -) - - -def test_basic_3dvar_loss(): - """Test the basic 3D-Var loss function""" - print("Testing 3D-Var loss function...") - - # Create sample data - batch_size, grid_size = 4, (5, 5) - background = torch.randn(batch_size, 1, *grid_size) - observations = torch.randn(batch_size, 1, *grid_size) - analysis = torch.randn(batch_size, 1, *grid_size) - - # Initialize loss function - loss_fn = ThreeDVarLoss() - - # Compute loss - loss = loss_fn(analysis, background, observations) - print(f"3D-Var loss: {loss.item():.4f}") - - # Test with custom covariances - B = torch.eye(grid_size[0] * grid_size[1]) * 0.5 - R = torch.eye(grid_size[0] * grid_size[1]) * 0.3 - loss_fn_custom = ThreeDVarLoss( - background_error_covariance=B, - observation_error_covariance=R - ) - - loss_custom = loss_fn_custom(analysis, background, observations) - print(f"3D-Var loss with custom covariances: {loss_custom.item():.4f}") - - print("✓ 3D-Var loss test passed\n") - - -def test_data_assimilation_model(): - """Test the data assimilation model""" - print("Testing data assimilation model...") - - # Test simple FC model - input_dim = 50 # 5x5 grid with 2 channels (bg + obs) - model = DataAssimilationModel(input_dim=input_dim) - - batch_size = 4 - background = torch.randn(batch_size, input_dim // 2) - observations = torch.randn(batch_size, input_dim // 2) - - analysis = model(background, observations) - print(f"FC Model - Input shape: {background.shape}, Output shape: {analysis.shape}") - - # Test convolutional model - grid_size = (5, 5) - model_conv = SimpleDataAssimilationModel(grid_size=grid_size, num_channels=1) - - background_conv = torch.randn(batch_size, 1, *grid_size) - observations_conv = torch.randn(batch_size, 1, *grid_size) - - analysis_conv = model_conv(background_conv, observations_conv) - print(f"Conv Model - Input shape: {background_conv.shape}, Output shape: {analysis_conv.shape}") - - print("✓ Data assimilation model test passed\n") - - -def test_training_pipeline(): - """Test the complete training pipeline""" - print("Testing training pipeline...") - - # Create synthetic data - data_module = AssimilationDataModule( - num_samples=200, - grid_size=(8, 8), - num_channels=1, - bg_error_std=0.5, - obs_error_std=0.3, - obs_fraction=0.6, - batch_size=16 - ) - data_module.setup() - - # Initialize model and loss - model = SimpleDataAssimilationModel( - grid_size=(8, 8), - num_channels=1, - hidden_dim=32, - num_layers=2 - ) - - # Train the model - trainer, results = train_data_assimilation_model( - model=model, - train_loader=data_module.train_dataloader(), - val_loader=data_module.val_dataloader(), - epochs=10, # Small number for testing - lr=1e-3, - device='cpu' - ) - - print(f"Final training loss: {results['train_losses'][-1]:.4f}") - print(f"Final validation loss: {results['val_losses'][-1]:.4f}") - - # Plot training curves - plot_training_curves( - results['train_losses'], - results['val_losses'], - title="Test Training Curves" - ) - - print("✓ Training pipeline test passed\n") - - return trainer, data_module - - -def test_evaluation_pipeline(trainer, data_module): - """Test the evaluation pipeline""" - print("Testing evaluation pipeline...") - - # Evaluate the trained model - eval_results = trainer.evaluate_model(data_module.test_dataloader(), compute_metrics=True) - - print("Evaluation Results:") - for key, value in eval_results.items(): - print(f" {key}: {value:.4f}") - - # Initialize evaluator - evaluator = DataAssimilationEvaluator(trainer.model, device='cpu') - overall_metrics = evaluator.evaluate_dataset(data_module.test_dataloader()) - - print("\nOverall Metrics:") - for key, value in overall_metrics.items(): - print(f" {key}: {value:.4f}") - - print("✓ Evaluation pipeline test passed\n") - - return overall_metrics - - -def test_comparison_with_baselines(trainer, data_module): - """Test comparison with classical baselines""" - print("Testing comparison with baselines...") - - comparison = compare_with_baselines( - trainer.model, - data_module.test_dataloader(), - device='cpu' - ) - - print("Baseline Comparison Results:") - for key, value in comparison.items(): - print(f" {key}: {value:.4f}") - - # Create RMSE comparison plot - rmse_comparison = { - 'Analysis': comparison['avg_analysis_rmse'], - 'Background': comparison['avg_background_rmse'], - 'Observations': comparison['avg_observation_rmse'], - 'Persistence': comparison['avg_persistence_rmse'] - } - - plot_rmse_comparison(rmse_comparison, title="RMSE Comparison with Baselines") - - print("✓ Baseline comparison test passed\n") - - return comparison - - -def test_visualization_pipeline(): - """Test visualization capabilities""" - print("Testing visualization pipeline...") - - # Generate sample data for visualization - batch_size, grid_size = 1, (6, 6) - background, observations, true_state = generate_synthetic_data( - batch_size=batch_size, - grid_size=grid_size, - num_channels=1 - ) - - # Create a simple "analysis" (for demonstration) - analysis = (background + observations) / 2 # Simple average - - # Test comparison grid - plot_comparison_grid( - background, observations, analysis, true_state, - titles=['Background', 'Observations', 'Analysis', 'True State'] - ) - - # Test error maps - plot_error_maps( - background, observations, analysis, true_state, - titles=['Background Error', 'Observation Error', 'Analysis Error'] - ) - - print("✓ Visualization pipeline test passed\n") - - -def run_comprehensive_test(): - """Run a comprehensive test of the entire pipeline""" - print("="*60) - print("COMPREHENSIVE TEST: Self-Supervised Data Assimilation Pipeline") - print("="*60) - - # Test 1: Basic components - test_basic_3dvar_loss() - - # Test 2: Model architecture - test_data_assimilation_model() - - # Test 3: Training pipeline - trainer, data_module = test_training_pipeline() - - # Test 4: Evaluation pipeline - eval_metrics = test_evaluation_pipeline(trainer, data_module) - - # Test 5: Baseline comparison - comparison_results = test_comparison_with_baselines(trainer, data_module) - - # Test 6: Visualization - test_visualization_pipeline() - - # Final summary dashboard - print("Creating summary dashboard...") - summary_metrics = eval_metrics.copy() - summary_metrics.update(comparison_results) - summary_metrics['model'] = trainer.model # Include model for parameter analysis - - create_summary_dashboard(summary_metrics) - - print("\n" + "="*60) - print("ALL TESTS COMPLETED SUCCESSFULLY!") - print("="*60) - - # Print key results - print(f"\nKey Results:") - print(f"- Analysis RMSE: {comparison_results['avg_analysis_rmse']:.4f}") - print(f"- Background RMSE: {comparison_results['avg_background_rmse']:.4f}") - print(f"- Analysis improvement over background: {comparison_results['analysis_improvement_over_bg']:.2f}%") - print(f"- Analysis improvement over observations: {comparison_results['analysis_improvement_over_obs']:.2f}%") - - return True - - -def test_different_training_modes(): - """Test the model under different training conditions""" - print("\n" + "="*60) - print("TESTING DIFFERENT TRAINING MODES") - print("="*60) - - from graph_weather.graph_weather.models.training_loop import train_with_different_modes - - # Test with different configurations - results = train_with_different_modes( - model_class=lambda **kwargs: SimpleDataAssimilationModel( - grid_size=(6, 6), - num_channels=1, - hidden_dim=16, - num_layers=2 - ), - data_module=AssimilationDataModule, - grid_size=(6, 6), - num_channels=1, - epochs=5, # Few epochs for testing - lr=1e-3, - device='cpu' - ) - - print("\nTraining Mode Results:") - for mode, result in results.items(): - eval_res = result['eval_results'] - print(f"\n{mode.upper()} MODE:") - print(f" Analysis RMSE: {eval_res.get('analysis_rmse', 'N/A')}") - print(f" Background RMSE: {eval_res.get('background_rmse', 'N/A')}") - print(f" Improvement over background: {eval_res.get('improvement_over_bg', 'N/A')}%") - - return results - - -if __name__ == "__main__": - # Set random seed for reproducibility - torch.manual_seed(42) - np.random.seed(42) - - # Run comprehensive test - success = run_comprehensive_test() - - # Test different training modes - mode_results = test_different_training_modes() - - print("\n" + "="*60) - print("ALL TESTS COMPLETED SUCCESSFULLY!") - print("Self-Supervised Data Assimilation Pipeline is working correctly.") - print("="*60) \ No newline at end of file From 93f2d03a202d4795398c112115a1b4dc9a5689f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Jan 2026 10:54:28 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../models/data_assimilation/__init__.py | 12 ++-- .../data_assimilation_base.py | 63 +++++++++++-------- .../test_data_assimilation_base.py | 44 +++++++------ 3 files changed, 68 insertions(+), 51 deletions(-) diff --git a/graph_weather/models/data_assimilation/__init__.py b/graph_weather/models/data_assimilation/__init__.py index a0a120be..2c40d060 100644 --- a/graph_weather/models/data_assimilation/__init__.py +++ b/graph_weather/models/data_assimilation/__init__.py @@ -6,9 +6,9 @@ from .variational_da import VariationalDA __all__ = [ - 'DataAssimilationBase', - 'EnsembleGenerator', - 'KalmanFilterDA', - 'ParticleFilterDA', - 'VariationalDA' -] \ No newline at end of file + "DataAssimilationBase", + "EnsembleGenerator", + "KalmanFilterDA", + "ParticleFilterDA", + "VariationalDA", +] diff --git a/graph_weather/models/data_assimilation/data_assimilation_base.py b/graph_weather/models/data_assimilation/data_assimilation_base.py index 75b67174..2c88c2a8 100644 --- a/graph_weather/models/data_assimilation/data_assimilation_base.py +++ b/graph_weather/models/data_assimilation/data_assimilation_base.py @@ -1,17 +1,19 @@ """Base classes for data assimilation modules.""" + import abc -from typing import Union, Dict, Any, Optional +from typing import Any, Dict, Union + import torch from torch_geometric.data import Data class EnsembleGenerator: """Class to generate ensemble members from a background state.""" - + def __init__(self, noise_std: float = 0.1, method: str = "gaussian"): self.noise_std = noise_std self.method = method - + def generate_ensemble(self, state: Union[torch.Tensor, Data], num_members: int): if isinstance(state, torch.Tensor): return self._generate_tensor_ensemble(state, num_members) @@ -19,11 +21,11 @@ def generate_ensemble(self, state: Union[torch.Tensor, Data], num_members: int): return self._generate_graph_ensemble(state, num_members) else: raise TypeError(f"Unsupported state type: {type(state)}") - + def _generate_tensor_ensemble(self, state: torch.Tensor, num_members: int) -> torch.Tensor: batch_size, nodes, features = state.shape ensemble = torch.zeros(batch_size, num_members, nodes, features, device=state.device) - + for i in range(num_members): if self.method == "gaussian": noise = torch.randn_like(state) * self.noise_std @@ -33,16 +35,22 @@ def _generate_tensor_ensemble(self, state: torch.Tensor, num_members: int) -> to noise = torch.randn_like(state) * self.noise_std * 0.1 ensemble[:, i] = (state * mask) + noise elif self.method == "perturbation": - perturbation = torch.randn_like(state) * self.noise_std * torch.linspace(0.1, 1.0, num_members)[i] + perturbation = ( + torch.randn_like(state) + * self.noise_std + * torch.linspace(0.1, 1.0, num_members)[i] + ) ensemble[:, i] = state + perturbation else: raise ValueError(f"Unknown ensemble generation method: {self.method}") - + return ensemble - + def _generate_graph_ensemble(self, state: Data, num_members: int) -> Data: - x_expanded = torch.zeros(state.x.shape[0], num_members, state.x.shape[1], device=state.x.device) - + x_expanded = torch.zeros( + state.x.shape[0], num_members, state.x.shape[1], device=state.x.device + ) + for i in range(num_members): if self.method == "gaussian": noise = torch.randn_like(state.x) * self.noise_std @@ -52,46 +60,51 @@ def _generate_graph_ensemble(self, state: Data, num_members: int) -> Data: noise = torch.randn_like(state.x) * self.noise_std * 0.1 x_expanded[:, i] = (state.x * mask) + noise elif self.method == "perturbation": - perturbation = torch.randn_like(state.x) * self.noise_std * torch.linspace(0.1, 1.0, num_members)[i] + perturbation = ( + torch.randn_like(state.x) + * self.noise_std + * torch.linspace(0.1, 1.0, num_members)[i] + ) x_expanded[:, i] = state.x + perturbation else: raise ValueError(f"Unknown ensemble generation method: {self.method}") - + new_state = Data( x=x_expanded, edge_index=state.edge_index, - edge_attr=getattr(state, 'edge_attr', None), - pos=getattr(state, 'pos', None) + edge_attr=getattr(state, "edge_attr", None), + pos=getattr(state, "pos", None), ) - + return new_state class DataAssimilationBase(abc.ABC): """Abstract base class for data assimilation modules.""" - + def __init__(self, config: Dict[str, Any]): self.config = config self.ensemble_generator = EnsembleGenerator( - noise_std=config.get('noise_std', 0.1), - method=config.get('ensemble_method', 'gaussian') + noise_std=config.get("noise_std", 0.1), method=config.get("ensemble_method", "gaussian") ) - + @abc.abstractmethod def initialize_ensemble(self, background_state: Union[torch.Tensor, Data], num_members: int): pass - + @abc.abstractmethod def assimilate(self, ensemble: Union[torch.Tensor, Data], observations: torch.Tensor): pass - + @abc.abstractmethod def _compute_analysis(self, ensemble: Union[torch.Tensor, Data]) -> Union[torch.Tensor, Data]: pass - - def forward(self, state: Union[torch.Tensor, Data], observations: torch.Tensor, num_ensemble: int = 10): + + def forward( + self, state: Union[torch.Tensor, Data], observations: torch.Tensor, num_ensemble: int = 10 + ): ensemble = self.initialize_ensemble(state, num_ensemble) updated_ensemble = self.assimilate(ensemble, observations) analysis = self._compute_analysis(updated_ensemble) - - return updated_ensemble, analysis \ No newline at end of file + + return updated_ensemble, analysis diff --git a/tests/models/data_assimilation/test_data_assimilation_base.py b/tests/models/data_assimilation/test_data_assimilation_base.py index cb71a940..7e20a0f9 100644 --- a/tests/models/data_assimilation/test_data_assimilation_base.py +++ b/tests/models/data_assimilation/test_data_assimilation_base.py @@ -3,10 +3,11 @@ from torch_geometric.data import Data import sys -sys.path.insert(0, '../../../graph_weather/models/data_assimilation') + +sys.path.insert(0, "../../../graph_weather/models/data_assimilation") # Execute modules directly to avoid import issues -exec(open('graph_weather/models/data_assimilation/data_assimilation_base.py').read()) +exec(open("graph_weather/models/data_assimilation/data_assimilation_base.py").read()) class MockDA(DataAssimilationBase): @@ -30,11 +31,11 @@ def _compute_analysis(self, ensemble): def test_ensemble_generator_tensor(): """Test ensemble generation for tensor inputs.""" generator = EnsembleGenerator(noise_std=0.1, method="gaussian") - + # Test tensor input state = torch.randn(2, 5, 3) # [batch, nodes, features] ensemble = generator.generate_ensemble(state, 4) - + assert ensemble.shape == (2, 4, 5, 3) # [batch, members, nodes, features] assert not torch.equal(state, ensemble[:, 0]) # Should have noise added @@ -42,17 +43,17 @@ def test_ensemble_generator_tensor(): def test_ensemble_generator_graph(): """Test ensemble generation for graph inputs.""" generator = EnsembleGenerator(noise_std=0.1, method="gaussian") - + # Test graph input x = torch.randn(10, 4) # Node features edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) graph_state = Data(x=x, edge_index=edge_index) - + ensemble = generator.generate_ensemble(graph_state, 3) - + # Check that ensemble preserves structure - assert hasattr(ensemble, 'x') - assert hasattr(ensemble, 'edge_index') + assert hasattr(ensemble, "x") + assert hasattr(ensemble, "edge_index") assert ensemble.x.shape[1] == 3 # Ensemble dimension @@ -60,9 +61,9 @@ def test_data_assimilation_base_abstract_methods(): """Test that abstract methods are properly defined.""" config = {"param": "value"} da_module = MockDA(config) - + assert da_module.config == config - + # Test ensemble generation state = torch.randn(2, 5, 3) ensemble = da_module.initialize_ensemble(state, 4) @@ -72,16 +73,19 @@ def test_data_assimilation_base_abstract_methods(): def test_compute_analysis_tensor(): """Test analysis computation for tensor ensembles.""" da_module = MockDA({}) - + # Create ensemble: [batch, members, nodes, features] - ensemble = torch.stack([ - torch.ones(2, 5, 3), # First member - 2 * torch.ones(2, 5, 3), # Second member - 3 * torch.ones(2, 5, 3), # Third member - ], dim=1) # Shape: [2, 3, 5, 3] - + ensemble = torch.stack( + [ + torch.ones(2, 5, 3), # First member + 2 * torch.ones(2, 5, 3), # Second member + 3 * torch.ones(2, 5, 3), # Third member + ], + dim=1, + ) # Shape: [2, 3, 5, 3] + analysis = da_module._compute_analysis(ensemble) - + # Mean should be (1 + 2 + 3) / 3 = 2 expected = 2 * torch.ones(2, 5, 3) assert torch.allclose(analysis, expected) @@ -92,4 +96,4 @@ def test_compute_analysis_tensor(): test_ensemble_generator_graph() test_data_assimilation_base_abstract_methods() test_compute_analysis_tensor() - print("All tests passed!") \ No newline at end of file + print("All tests passed!") From b4d996c682a5bfa162a2431977ca73c11808b128 Mon Sep 17 00:00:00 2001 From: SOHAMPAL23 Date: Sun, 22 Feb 2026 13:42:43 +0530 Subject: [PATCH 6/6] Updation in the Data Assimilation --- graph_weather/__init__.py | 31 ++++++++-- graph_weather/models/__init__.py | 62 +++++++++++++++---- .../models/data_assimilation/__init__.py | 6 -- .../test_data_assimilation_base.py | 14 +---- 4 files changed, 78 insertions(+), 35 deletions(-) diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index b33e23cd..454eacd0 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -1,6 +1,29 @@ """Main import for the complete models""" -from .data.nnja_ai import SensorDataset -from .data.weather_station_reader import WeatherStationReader -from .models.analysis import GraphWeatherAssimilator -from .models.forecast import GraphWeatherForecaster +# Using lazy loading to avoid dependency conflicts +def __getattr__(name): + """Lazy loading for all modules to avoid dependency conflicts.""" + if name == "GraphWeatherAssimilator": + from .models.analysis import GraphWeatherAssimilator as GWA + globals()[name] = GWA + return GWA + elif name == "GraphWeatherForecaster": + from .models.forecast import GraphWeatherForecaster as GWF + globals()[name] = GWF + return GWF + elif name == "SensorDataset": + from .data.nnja_ai import SensorDataset as SD + globals()[name] = SD + return SD + elif name == "WeatherStationReader": + from .data.weather_station_reader import WeatherStationReader as WSR + globals()[name] = WSR + return WSR + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + +__all__ = [ + "GraphWeatherAssimilator", + "GraphWeatherForecaster", + "SensorDataset", + "WeatherStationReader", +] diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py index 3710e24a..4415db99 100755 --- a/graph_weather/models/__init__.py +++ b/graph_weather/models/__init__.py @@ -1,15 +1,51 @@ """Models""" -from .fengwu_ghr.layers import ( - ImageMetaModel, - LoRAModule, - MetaModel, - WrapperImageModel, - WrapperMetaModel, -) -from .layers.assimilator_decoder import AssimilatorDecoder -from .layers.assimilator_encoder import AssimilatorEncoder -from .layers.decoder import Decoder -from .layers.encoder import Encoder -from .layers.processor import Processor -from .layers.stochastic_decomposition import StochasticDecompositionLayer +# Using lazy loading to avoid dependency conflicts + +__all__ = [] + +def __getattr__(name): + """Lazy loading for models to avoid dependency conflicts.""" + if name in ['ImageMetaModel', 'LoRAModule', 'MetaModel', 'WrapperImageModel', 'WrapperMetaModel']: + from .fengwu_ghr.layers import ( + ImageMetaModel as IM, + LoRAModule as LM, + MetaModel as MM, + WrapperImageModel as WIM, + WrapperMetaModel as WMM, + ) + result = { + 'ImageMetaModel': IM, + 'LoRAModule': LM, + 'MetaModel': MM, + 'WrapperImageModel': WIM, + 'WrapperMetaModel': WMM, + }[name] + globals()[name] = result + return result + elif name == 'AssimilatorDecoder': + from .layers.assimilator_decoder import AssimilatorDecoder as AD + globals()[name] = AD + return AD + elif name == 'AssimilatorEncoder': + from .layers.assimilator_encoder import AssimilatorEncoder as AE + globals()[name] = AE + return AE + elif name == 'Decoder': + from .layers.decoder import Decoder as D + globals()[name] = D + return D + elif name == 'Encoder': + from .layers.encoder import Encoder as E + globals()[name] = E + return E + elif name == 'Processor': + from .layers.processor import Processor as P + globals()[name] = P + return P + elif name == 'StochasticDecompositionLayer': + from .layers.stochastic_decomposition import StochasticDecompositionLayer as SDL + globals()[name] = SDL + return SDL + else: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/graph_weather/models/data_assimilation/__init__.py b/graph_weather/models/data_assimilation/__init__.py index 2c40d060..53e8c969 100644 --- a/graph_weather/models/data_assimilation/__init__.py +++ b/graph_weather/models/data_assimilation/__init__.py @@ -1,14 +1,8 @@ """Data assimilation module initialization.""" from .data_assimilation_base import DataAssimilationBase, EnsembleGenerator -from .kalman_filter_da import KalmanFilterDA -from .particle_filter_da import ParticleFilterDA -from .variational_da import VariationalDA __all__ = [ "DataAssimilationBase", "EnsembleGenerator", - "KalmanFilterDA", - "ParticleFilterDA", - "VariationalDA", ] diff --git a/tests/models/data_assimilation/test_data_assimilation_base.py b/tests/models/data_assimilation/test_data_assimilation_base.py index 7e20a0f9..c86738d7 100644 --- a/tests/models/data_assimilation/test_data_assimilation_base.py +++ b/tests/models/data_assimilation/test_data_assimilation_base.py @@ -2,12 +2,7 @@ import torch from torch_geometric.data import Data -import sys - -sys.path.insert(0, "../../../graph_weather/models/data_assimilation") - -# Execute modules directly to avoid import issues -exec(open("graph_weather/models/data_assimilation/data_assimilation_base.py").read()) +from graph_weather.models.data_assimilation import DataAssimilationBase, EnsembleGenerator class MockDA(DataAssimilationBase): @@ -91,9 +86,4 @@ def test_compute_analysis_tensor(): assert torch.allclose(analysis, expected) -if __name__ == "__main__": - test_ensemble_generator_tensor() - test_ensemble_generator_graph() - test_data_assimilation_base_abstract_methods() - test_compute_analysis_tensor() - print("All tests passed!") +