Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions graph_weather/models/gencast/graph/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,37 @@
"""

import gc
from dataclasses import dataclass

import numpy as np
import torch
from dacite import from_dict
from torch_geometric.data import Data, HeteroData

# from torch_geometric.transforms import TwoHop
from graph_weather.models.gencast.graph import grid_mesh_connectivity, icosahedral_mesh, model_utils

# Some configs from graphcast:
_spatial_features_kwargs = dict(
add_node_positions=False,
add_node_latitude=True,
add_node_longitude=True,
add_relative_positions=True,
relative_longitude_local_coordinates=True,
relative_latitude_local_coordinates=True,
)

@dataclass
class SpatialFeaturesConfig:
add_node_positions: bool = False
add_node_latitude: bool = True
add_node_longitude: bool = True
add_relative_positions: bool = True
relative_longitude_local_coordinates: bool = True
relative_latitude_local_coordinates: bool = True


config_dict = {
"add_node_positions": False,
"add_node_latitude": True,
"add_node_longitude": True,
"add_relative_positions": True,
"relative_longitude_local_coordinates": True,
"relative_latitude_local_coordinates": True,
}

_spatial_features_kwargs = from_dict(data_class=SpatialFeaturesConfig, data=config_dict)

# radius_query_fraction_edge_length: Scalar that will be multiplied by the
# length of the longest edge of the finest mesh to define the radius of
Expand Down
21 changes: 20 additions & 1 deletion graph_weather/models/gencast/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,3"

from dataclasses import dataclass # noqa: E402
from typing import List # noqa: E402

import lightning as L # noqa: E402
import matplotlib.pyplot as plt # noqa: E402
import numpy as np # noqa: E402
import torch # noqa: E402
from dacite import from_dict # noqa: E402
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint # noqa: E402
from lightning.pytorch.loggers import WandbLogger # noqa: E402
from torch.utils.data import DataLoader # noqa: E402
Expand All @@ -37,7 +41,21 @@

# model configs
CHECKPOINT_PATH = "checkpoints/epoch=3-step=10776.ckpt"
CFG = {


@dataclass
class Config:
hidden_dims: List[int]
num_blocks: int
num_heads: int
splits: int
num_hops: int
sparse: bool
use_edges_features: bool
scale_factor: float


config = {
"hidden_dims": [512, 512],
"num_blocks": 16,
"num_heads": 4,
Expand All @@ -47,6 +65,7 @@
"use_edges_features": False,
"scale_factor": 1.0,
}
CFG = from_dict(data_class=Config, data=config)

# dataset configs
atmospheric_features = [
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ xarray
setuptools
pydantic
safetensors
dacite