diff --git a/graph_weather/models/gencast/graph/graph_builder.py b/graph_weather/models/gencast/graph/graph_builder.py index 68b57678..8a48aeb0 100644 --- a/graph_weather/models/gencast/graph/graph_builder.py +++ b/graph_weather/models/gencast/graph/graph_builder.py @@ -9,23 +9,37 @@ """ import gc +from dataclasses import dataclass import numpy as np import torch +from dacite import from_dict from torch_geometric.data import Data, HeteroData # from torch_geometric.transforms import TwoHop from graph_weather.models.gencast.graph import grid_mesh_connectivity, icosahedral_mesh, model_utils -# Some configs from graphcast: -_spatial_features_kwargs = dict( - add_node_positions=False, - add_node_latitude=True, - add_node_longitude=True, - add_relative_positions=True, - relative_longitude_local_coordinates=True, - relative_latitude_local_coordinates=True, -) + +@dataclass +class SpatialFeaturesConfig: + add_node_positions: bool = False + add_node_latitude: bool = True + add_node_longitude: bool = True + add_relative_positions: bool = True + relative_longitude_local_coordinates: bool = True + relative_latitude_local_coordinates: bool = True + + +config_dict = { + "add_node_positions": False, + "add_node_latitude": True, + "add_node_longitude": True, + "add_relative_positions": True, + "relative_longitude_local_coordinates": True, + "relative_latitude_local_coordinates": True, +} + +_spatial_features_kwargs = from_dict(data_class=SpatialFeaturesConfig, data=config_dict) # radius_query_fraction_edge_length: Scalar that will be multiplied by the # length of the longest edge of the finest mesh to define the radius of diff --git a/graph_weather/models/gencast/train.py b/graph_weather/models/gencast/train.py index a6bdb127..0526c36c 100644 --- a/graph_weather/models/gencast/train.py +++ b/graph_weather/models/gencast/train.py @@ -7,10 +7,14 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "0,3" +from dataclasses import dataclass # noqa: E402 +from typing import List # noqa: E402 + import lightning as L # noqa: E402 import matplotlib.pyplot as plt # noqa: E402 import numpy as np # noqa: E402 import torch # noqa: E402 +from dacite import from_dict # noqa: E402 from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint # noqa: E402 from lightning.pytorch.loggers import WandbLogger # noqa: E402 from torch.utils.data import DataLoader # noqa: E402 @@ -37,7 +41,21 @@ # model configs CHECKPOINT_PATH = "checkpoints/epoch=3-step=10776.ckpt" -CFG = { + + +@dataclass +class Config: + hidden_dims: List[int] + num_blocks: int + num_heads: int + splits: int + num_hops: int + sparse: bool + use_edges_features: bool + scale_factor: float + + +config = { "hidden_dims": [512, 512], "num_blocks": 16, "num_heads": 4, @@ -47,6 +65,7 @@ "use_edges_features": False, "scale_factor": 1.0, } +CFG = from_dict(data_class=Config, data=config) # dataset configs atmospheric_features = [ diff --git a/requirements.txt b/requirements.txt index 79d45e1b..7612a6a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ xarray setuptools pydantic safetensors +dacite