From a4072c2e05e0d2a841244badec2baba1b0a4617c Mon Sep 17 00:00:00 2001 From: ACSE-vg822 Date: Sat, 25 Jan 2025 18:07:03 +0000 Subject: [PATCH 1/2] changed config in two files --- .../models/gencast/graph/graph_builder.py | 32 +++++++++++++------ graph_weather/models/gencast/train.py | 18 ++++++++++- requirements.txt | 1 + 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/graph_weather/models/gencast/graph/graph_builder.py b/graph_weather/models/gencast/graph/graph_builder.py index 68b57678..58b1e71b 100644 --- a/graph_weather/models/gencast/graph/graph_builder.py +++ b/graph_weather/models/gencast/graph/graph_builder.py @@ -14,18 +14,32 @@ import torch from torch_geometric.data import Data, HeteroData +from dataclasses import dataclass +from dacite import from_dict + + # from torch_geometric.transforms import TwoHop from graph_weather.models.gencast.graph import grid_mesh_connectivity, icosahedral_mesh, model_utils -# Some configs from graphcast: -_spatial_features_kwargs = dict( - add_node_positions=False, - add_node_latitude=True, - add_node_longitude=True, - add_relative_positions=True, - relative_longitude_local_coordinates=True, - relative_latitude_local_coordinates=True, -) +@dataclass +class SpatialFeaturesConfig: + add_node_positions: bool = False + add_node_latitude: bool = True + add_node_longitude: bool = True + add_relative_positions: bool = True + relative_longitude_local_coordinates: bool = True + relative_latitude_local_coordinates: bool = True + +config_dict = { + "add_node_positions": False, + "add_node_latitude": True, + "add_node_longitude": True, + "add_relative_positions": True, + "relative_longitude_local_coordinates": True, + "relative_latitude_local_coordinates": True, +} + +_spatial_features_kwargs = from_dict(data_class=SpatialFeaturesConfig, data=config_dict) # radius_query_fraction_edge_length: Scalar that will be multiplied by the # length of the longest edge of the finest mesh to define the radius of diff --git a/graph_weather/models/gencast/train.py b/graph_weather/models/gencast/train.py index a6bdb127..1643dffe 100644 --- a/graph_weather/models/gencast/train.py +++ b/graph_weather/models/gencast/train.py @@ -14,6 +14,9 @@ from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint # noqa: E402 from lightning.pytorch.loggers import WandbLogger # noqa: E402 from torch.utils.data import DataLoader # noqa: E402 +from dataclasses import dataclass # noqa: E402 +from typing import List # noqa: E402 +from dacite import from_dict # noqa: E402 from graph_weather.data.gencast_dataloader import GenCastDataset # noqa: E402 from graph_weather.models.gencast import Denoiser, Sampler, WeightedMSELoss # noqa: E402 @@ -37,7 +40,19 @@ # model configs CHECKPOINT_PATH = "checkpoints/epoch=3-step=10776.ckpt" -CFG = { + +@dataclass +class Config: + hidden_dims: List[int] + num_blocks: int + num_heads: int + splits: int + num_hops: int + sparse: bool + use_edges_features: bool + scale_factor: float + +config = { "hidden_dims": [512, 512], "num_blocks": 16, "num_heads": 4, @@ -47,6 +62,7 @@ "use_edges_features": False, "scale_factor": 1.0, } +CFG = from_dict(data_class=Config, data=config) # dataset configs atmospheric_features = [ diff --git a/requirements.txt b/requirements.txt index 79d45e1b..2bc337dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ xarray setuptools pydantic safetensors +dacite \ No newline at end of file From 41fa2b790611428f94ef532ccfc6e67839ce8066 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 25 Jan 2025 18:12:17 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/models/gencast/graph/graph_builder.py | 8 ++++---- graph_weather/models/gencast/train.py | 9 ++++++--- requirements.txt | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/graph_weather/models/gencast/graph/graph_builder.py b/graph_weather/models/gencast/graph/graph_builder.py index 58b1e71b..8a48aeb0 100644 --- a/graph_weather/models/gencast/graph/graph_builder.py +++ b/graph_weather/models/gencast/graph/graph_builder.py @@ -9,18 +9,17 @@ """ import gc +from dataclasses import dataclass import numpy as np import torch -from torch_geometric.data import Data, HeteroData - -from dataclasses import dataclass from dacite import from_dict - +from torch_geometric.data import Data, HeteroData # from torch_geometric.transforms import TwoHop from graph_weather.models.gencast.graph import grid_mesh_connectivity, icosahedral_mesh, model_utils + @dataclass class SpatialFeaturesConfig: add_node_positions: bool = False @@ -30,6 +29,7 @@ class SpatialFeaturesConfig: relative_longitude_local_coordinates: bool = True relative_latitude_local_coordinates: bool = True + config_dict = { "add_node_positions": False, "add_node_latitude": True, diff --git a/graph_weather/models/gencast/train.py b/graph_weather/models/gencast/train.py index 1643dffe..0526c36c 100644 --- a/graph_weather/models/gencast/train.py +++ b/graph_weather/models/gencast/train.py @@ -7,16 +7,17 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = "0,3" +from dataclasses import dataclass # noqa: E402 +from typing import List # noqa: E402 + import lightning as L # noqa: E402 import matplotlib.pyplot as plt # noqa: E402 import numpy as np # noqa: E402 import torch # noqa: E402 +from dacite import from_dict # noqa: E402 from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint # noqa: E402 from lightning.pytorch.loggers import WandbLogger # noqa: E402 from torch.utils.data import DataLoader # noqa: E402 -from dataclasses import dataclass # noqa: E402 -from typing import List # noqa: E402 -from dacite import from_dict # noqa: E402 from graph_weather.data.gencast_dataloader import GenCastDataset # noqa: E402 from graph_weather.models.gencast import Denoiser, Sampler, WeightedMSELoss # noqa: E402 @@ -41,6 +42,7 @@ # model configs CHECKPOINT_PATH = "checkpoints/epoch=3-step=10776.ckpt" + @dataclass class Config: hidden_dims: List[int] @@ -52,6 +54,7 @@ class Config: use_edges_features: bool scale_factor: float + config = { "hidden_dims": [512, 512], "num_blocks": 16, diff --git a/requirements.txt b/requirements.txt index 2bc337dd..7612a6a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ xarray setuptools pydantic safetensors -dacite \ No newline at end of file +dacite