diff --git a/src/cchdo/hydro/checks.py b/src/cchdo/hydro/checks.py index 055a994..8713b60 100644 --- a/src/cchdo/hydro/checks.py +++ b/src/cchdo/hydro/checks.py @@ -33,7 +33,7 @@ def check_ancillary_variables(ds: xr.Dataset): """Check that everything in an ancillary_variables attribute appears as a variable Check that every variable that is known ancillary appears in at least one ancillary_variable attribute """ - looks_ancillary_suffixes = ("_qc", "_error") + looks_ancillary_suffixes = ("_qc", "_error", "_url") ancillary_variables_attrs = defaultdict(list) looks_ancillary = set() diff --git a/src/cchdo/hydro/core.py b/src/cchdo/hydro/core.py index 8492616..60c6789 100644 --- a/src/cchdo/hydro/core.py +++ b/src/cchdo/hydro/core.py @@ -1,6 +1,8 @@ """Core operations on a CCHDO CF/netCDF file.""" from collections.abc import Hashable +from enum import StrEnum, auto +from typing import cast import numpy as np import numpy.typing as npt @@ -40,38 +42,77 @@ } +class ColumnType(StrEnum): + DATA = auto() + FLAG = auto() + ERROR = auto() + URL = auto() + + def dataarray_factory( param: WHPName, - ctype="data", + ctype: ColumnType = ColumnType.DATA, N_PROF=0, N_LEVELS=0, strlen=0, + url_shape=(), ) -> xr.DataArray: dtype = dtype_map[param.dtype] if param.dtype == "string": dtype = f"U{strlen}" fill = FILLS_MAP[param.dtype] name = param.full_nc_name + scope = param.scope - if ctype == "flag": + if ctype == ColumnType.FLAG: dtype = dtype_map["integer"] fill = FILLS_MAP["integer"] name = param.nc_name_flag - if param.scope == "profile": - arr = np.full((N_PROF), fill_value=fill, dtype=dtype) - if param.scope == "sample": - arr = np.full((N_PROF, N_LEVELS), fill_value=fill, dtype=dtype) + if ctype == ColumnType.ERROR: + name = param.nc_name_error + + if ctype == ColumnType.URL: + dtype = np.dtypes.StringDType + fill = FILLS_MAP["string"] + name = f"{name}_url" # TODO: upstream to cchdo.params + + match url_shape: + case (int(n),) | int(n): + scope = cast( + int, n + ) # TODO: remove cast when ty has support for type narrowing in match statements + case () if param.scope in ("sample", "profile", "cruise"): + scope = "cruise" + case ("N_PROF",) | "N_PROF" if param.scope in ("sample", "profile"): + scope = "profile" + case ("N_PROF", "N_LEVELS") if param.scope in ("sample"): + scope = "sample" + case _: + raise ValueError + + match scope: + case "profile": + arr = np.full((N_PROF), fill_value=fill, dtype=dtype) + var_da = xr.DataArray(arr, dims=DIMS[: arr.ndim], name=name) + case "sample": + arr = np.full((N_PROF, N_LEVELS), fill_value=fill, dtype=dtype) + var_da = xr.DataArray(arr, dims=DIMS[: arr.ndim], name=name) + case "cruise": + arr = np.full((), fill_value=fill, dtype=dtype) + var_da = xr.DataArray(arr, dims=DIMS[: arr.ndim], name=name) + case int(n): + arr = np.full(n, fill_value=fill, dtype=dtype) + var_da = xr.DataArray(arr, dims=f"N_{name}", name=name) attrs = param.get_nc_attrs() if "C_format" in attrs: attrs["C_format_source"] = "database" - if ctype == "error": + if ctype == ColumnType.ERROR: attrs = param.get_nc_attrs(error=True) - name = param.nc_name_error - if ctype == "flag" and param.flag_w in FLAG_SCHEME: + if ctype == ColumnType.FLAG and param.flag_w in FLAG_SCHEME: flag_defs = FLAG_SCHEME[param.flag_w] flag_values = [] flag_meanings = [] @@ -91,8 +132,10 @@ def dataarray_factory( "flag_meanings": " ".join(flag_meanings), "conventions": odv_conventions_map[param.flag_w], } + if ctype == ColumnType.URL: + attrs = {} - var_da = xr.DataArray(arr, dims=DIMS[: arr.ndim], attrs=attrs, name=name) + var_da.attrs.update(attrs) if param.dtype != "decimal": try: @@ -116,7 +159,7 @@ def dataarray_factory( if param.dtype == "integer": var_da = var_da.fillna(-999).astype("int32") - if ctype == "flag": + if ctype == ColumnType.FLAG: var_da.encoding["dtype"] = "int8" var_da.encoding["_FillValue"] = 9 @@ -223,6 +266,7 @@ def add_param( with_flag: bool = False, with_error: bool = False, with_ancillary=None, + with_url=None, ) -> xr.Dataset: """Add a new parameter, and optionally some ancillary associated variables to a dataset. @@ -252,17 +296,16 @@ def add_param( ) vars_to_add.append(var) + ancillary: set[str] = set(var.attrs.get("ancillary_variables", "").split()) + if with_flag and _param.nc_name_flag not in _ds: flag_var = dataarray_factory( _param, N_PROF=ds.sizes["N_PROF"], N_LEVELS=ds.sizes["N_LEVELS"], - ctype="flag", + ctype=ColumnType.FLAG, ) - ancillary = var.attrs.get("ancillary_variables", "").split() - if flag_var.name not in ancillary: - ancillary.append(flag_var.name) - var.attrs["ancillary_variables"] = " ".join(sorted(ancillary)) + ancillary.add(str(flag_var.name)) vars_to_add.append(flag_var) if with_error and _param.full_error_name is None: @@ -273,16 +316,29 @@ def add_param( _param, N_PROF=ds.sizes["N_PROF"], N_LEVELS=ds.sizes["N_LEVELS"], - ctype="error", + ctype=ColumnType.ERROR, ) - ancillary = var.attrs.get("ancillary_variables", "").split() - if error_var.name not in ancillary: - ancillary.append(error_var.name) - var.attrs["ancillary_variables"] = " ".join(sorted(ancillary)) + ancillary.add(str(error_var.name)) vars_to_add.append(error_var) - for var in vars_to_add: - _ds[var.name] = var + if ( + with_url is not None and f"{_param.full_nc_name}_url" not in _ds + ): # TODO: upstream to cchdo.params + url_var = dataarray_factory( + _param, + N_PROF=ds.sizes["N_PROF"], + N_LEVELS=ds.sizes["N_LEVELS"], + ctype=ColumnType.URL, + url_shape=with_url, + ) + ancillary.add(str(url_var.name)) + vars_to_add.append(url_var) + + if ancillary: + var.attrs["ancillary_variables"] = " ".join(sorted(ancillary)) + + for add_var in vars_to_add: + _ds[add_var.name] = add_var _ds = add_cdom_coordinate(_ds) check_ancillary_variables(_ds)