Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cchdo/hydro/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def check_ancillary_variables(ds: xr.Dataset):
"""Check that everything in an ancillary_variables attribute appears as a variable
Check that every variable that is known ancillary appears in at least one ancillary_variable attribute
"""
looks_ancillary_suffixes = ("_qc", "_error")
looks_ancillary_suffixes = ("_qc", "_error", "_url")

ancillary_variables_attrs = defaultdict(list)
looks_ancillary = set()
Expand Down
102 changes: 79 additions & 23 deletions src/cchdo/hydro/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Core operations on a CCHDO CF/netCDF file."""

from collections.abc import Hashable
from enum import StrEnum, auto
from typing import cast

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -40,38 +42,77 @@
}


class ColumnType(StrEnum):
DATA = auto()
FLAG = auto()
ERROR = auto()
URL = auto()


def dataarray_factory(
param: WHPName,
ctype="data",
ctype: ColumnType = ColumnType.DATA,
N_PROF=0,
N_LEVELS=0,
strlen=0,
url_shape=(),
) -> xr.DataArray:
dtype = dtype_map[param.dtype]
if param.dtype == "string":
dtype = f"U{strlen}"
fill = FILLS_MAP[param.dtype]
name = param.full_nc_name
scope = param.scope

if ctype == "flag":
if ctype == ColumnType.FLAG:
dtype = dtype_map["integer"]
fill = FILLS_MAP["integer"]
name = param.nc_name_flag

if param.scope == "profile":
arr = np.full((N_PROF), fill_value=fill, dtype=dtype)
if param.scope == "sample":
arr = np.full((N_PROF, N_LEVELS), fill_value=fill, dtype=dtype)
if ctype == ColumnType.ERROR:
name = param.nc_name_error

if ctype == ColumnType.URL:
dtype = np.dtypes.StringDType
fill = FILLS_MAP["string"]
name = f"{name}_url" # TODO: upstream to cchdo.params

match url_shape:
case (int(n),) | int(n):
scope = cast(
int, n
) # TODO: remove cast when ty has support for type narrowing in match statements
case () if param.scope in ("sample", "profile", "cruise"):
scope = "cruise"
case ("N_PROF",) | "N_PROF" if param.scope in ("sample", "profile"):
scope = "profile"
case ("N_PROF", "N_LEVELS") if param.scope in ("sample"):
scope = "sample"
case _:
raise ValueError

match scope:
case "profile":
arr = np.full((N_PROF), fill_value=fill, dtype=dtype)
var_da = xr.DataArray(arr, dims=DIMS[: arr.ndim], name=name)
case "sample":
arr = np.full((N_PROF, N_LEVELS), fill_value=fill, dtype=dtype)
var_da = xr.DataArray(arr, dims=DIMS[: arr.ndim], name=name)
case "cruise":
arr = np.full((), fill_value=fill, dtype=dtype)
var_da = xr.DataArray(arr, dims=DIMS[: arr.ndim], name=name)
case int(n):
arr = np.full(n, fill_value=fill, dtype=dtype)
var_da = xr.DataArray(arr, dims=f"N_{name}", name=name)

attrs = param.get_nc_attrs()
if "C_format" in attrs:
attrs["C_format_source"] = "database"

if ctype == "error":
if ctype == ColumnType.ERROR:
attrs = param.get_nc_attrs(error=True)
name = param.nc_name_error

if ctype == "flag" and param.flag_w in FLAG_SCHEME:
if ctype == ColumnType.FLAG and param.flag_w in FLAG_SCHEME:
flag_defs = FLAG_SCHEME[param.flag_w]
flag_values = []
flag_meanings = []
Expand All @@ -91,8 +132,10 @@ def dataarray_factory(
"flag_meanings": " ".join(flag_meanings),
"conventions": odv_conventions_map[param.flag_w],
}
if ctype == ColumnType.URL:
attrs = {}

var_da = xr.DataArray(arr, dims=DIMS[: arr.ndim], attrs=attrs, name=name)
var_da.attrs.update(attrs)

if param.dtype != "decimal":
try:
Expand All @@ -116,7 +159,7 @@ def dataarray_factory(
if param.dtype == "integer":
var_da = var_da.fillna(-999).astype("int32")

if ctype == "flag":
if ctype == ColumnType.FLAG:
var_da.encoding["dtype"] = "int8"
var_da.encoding["_FillValue"] = 9

Expand Down Expand Up @@ -223,6 +266,7 @@ def add_param(
with_flag: bool = False,
with_error: bool = False,
with_ancillary=None,
with_url=None,
) -> xr.Dataset:
"""Add a new parameter, and optionally some ancillary associated variables to a dataset.

Expand Down Expand Up @@ -252,17 +296,16 @@ def add_param(
)
vars_to_add.append(var)

ancillary: set[str] = set(var.attrs.get("ancillary_variables", "").split())

if with_flag and _param.nc_name_flag not in _ds:
flag_var = dataarray_factory(
_param,
N_PROF=ds.sizes["N_PROF"],
N_LEVELS=ds.sizes["N_LEVELS"],
ctype="flag",
ctype=ColumnType.FLAG,
)
ancillary = var.attrs.get("ancillary_variables", "").split()
if flag_var.name not in ancillary:
ancillary.append(flag_var.name)
var.attrs["ancillary_variables"] = " ".join(sorted(ancillary))
ancillary.add(str(flag_var.name))
vars_to_add.append(flag_var)

if with_error and _param.full_error_name is None:
Expand All @@ -273,16 +316,29 @@ def add_param(
_param,
N_PROF=ds.sizes["N_PROF"],
N_LEVELS=ds.sizes["N_LEVELS"],
ctype="error",
ctype=ColumnType.ERROR,
)
ancillary = var.attrs.get("ancillary_variables", "").split()
if error_var.name not in ancillary:
ancillary.append(error_var.name)
var.attrs["ancillary_variables"] = " ".join(sorted(ancillary))
ancillary.add(str(error_var.name))
vars_to_add.append(error_var)

for var in vars_to_add:
_ds[var.name] = var
if (
with_url is not None and f"{_param.full_nc_name}_url" not in _ds
): # TODO: upstream to cchdo.params
url_var = dataarray_factory(
_param,
N_PROF=ds.sizes["N_PROF"],
N_LEVELS=ds.sizes["N_LEVELS"],
ctype=ColumnType.URL,
url_shape=with_url,
)
ancillary.add(str(url_var.name))
vars_to_add.append(url_var)

if ancillary:
var.attrs["ancillary_variables"] = " ".join(sorted(ancillary))

for add_var in vars_to_add:
_ds[add_var.name] = add_var

_ds = add_cdom_coordinate(_ds)
check_ancillary_variables(_ds)
Expand Down
Loading