Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions scoringrules/_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
estimator_check,
mv_weighted_score_chain,
mv_weighted_score_weights,
apply_nan_policy_ens_mv,
)

if tp.TYPE_CHECKING:
from scoringrules.core.typing import Array, ArrayLike, Backend
from scoringrules.core.typing import Array, ArrayLike, Backend, NanPolicy


def es_ensemble(
Expand All @@ -21,6 +22,7 @@ def es_ensemble(
*,
ens_w: "Array" = None,
estimator: str = "nrg",
nan_policy: "NanPolicy" = "propagate",
backend: "Backend" = None,
) -> "Array":
r"""Compute the Energy Score for a finite multivariate ensemble.
Expand Down Expand Up @@ -50,6 +52,12 @@ def es_ensemble(
Default is equal weighting. Weights are normalised so that they sum to one across the ensemble members.
estimator : str
The energy score estimator to be used.
nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate'
How to handle NaN ensemble members (a member is invalid if any of its
variables, or its weight, is NaN). 'propagate' lets NaN flow to a NaN
score; 'raise' errors on any NaN; 'omit' drops invalid members by
zero-weighting them. 'omit' is not implemented for the 'akr' and
'akr_circperm' estimators. NaN in `obs` always propagates.
backend : str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Expand All @@ -71,14 +79,26 @@ def es_ensemble(
Some theoretical background on scoring rules for multivariate forecasts.
"""
obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend)

# apply nan policy before weight normalisation so NaN positions in a
# user-supplied ens_w stay detectable. For 'omit', invalid members become
# zero-weighted and the score is routed through the weighted path; ens_w is
# returned aligned ensemble-last.
obs, fct, ens_w = apply_nan_policy_ens_mv(
obs, fct, nan_policy, ens_w, estimator=estimator, m_axis=m_axis, backend=backend
)

if ens_w is None:
if backend == "numba":
estimator_check(estimator, energy.estimator_gufuncs)
return energy.estimator_gufuncs[estimator](obs, fct)
else:
return energy.es(obs, fct, estimator=estimator, backend=backend)
else:
ens_w = multivariate_weight_check(ens_w, fct, m_axis, backend=backend)
# ens_w is already aligned ensemble-last by the nan-policy helper, so
# m_axis=-2 maps to that last axis (a no-op realignment) while still
# running the shape/non-negativity checks and renormalisation.
ens_w = multivariate_weight_check(ens_w, fct, -2, backend=backend)
if backend == "numba":
estimator_check(estimator, energy.estimator_gufuncs_w)
return energy.estimator_gufuncs_w[estimator](obs, fct, ens_w)
Expand All @@ -95,6 +115,7 @@ def twes_ensemble(
*,
ens_w: "Array" = None,
estimator: str = "nrg",
nan_policy: "NanPolicy" = "propagate",
backend: "Backend" = None,
) -> "Array":
r"""Compute the Threshold-Weighted Energy Score (twES) for a finite multivariate ensemble.
Expand Down Expand Up @@ -129,6 +150,12 @@ def twes_ensemble(
Default is equal weighting. Weights are normalised so that they sum to one across the ensemble members.
estimator : str
The energy score estimator to be used.
nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate'
How to handle NaN ensemble members (a member is invalid if any of its
variables, or its weight, is NaN). 'propagate' lets NaN flow to a NaN
score; 'raise' errors on any NaN; 'omit' drops invalid members by
zero-weighting them. 'omit' is not implemented for the 'akr' and
'akr_circperm' estimators. NaN in `obs` always propagates.
backend : str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Expand All @@ -145,6 +172,7 @@ def twes_ensemble(
v_axis=v_axis,
ens_w=ens_w,
estimator=estimator,
nan_policy=nan_policy,
backend=backend,
)

Expand All @@ -157,6 +185,7 @@ def owes_ensemble(
v_axis: int = -1,
*,
ens_w: "Array" = None,
nan_policy: "NanPolicy" = "propagate",
backend: "Backend" = None,
) -> "Array":
r"""Compute the Outcome-Weighted Energy Score (owES) for a finite multivariate ensemble.
Expand Down Expand Up @@ -192,6 +221,11 @@ def owes_ensemble(
ens_w : array_like
Weights assigned to the ensemble members. Array with one less dimension than fct (without the v_axis dimension).
Default is equal weighting. Weights are normalised so that they sum to one across the ensemble members.
nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate'
How to handle NaN ensemble members (a member is invalid if any of its
variables, or its weight, is NaN). 'propagate' lets NaN flow to a NaN
score; 'raise' errors on any NaN; 'omit' drops invalid members by
zero-weighting them. NaN in `obs` always propagates.
backend : str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Expand All @@ -201,14 +235,19 @@ def owes_ensemble(
The computed Outcome-Weighted Energy Score.
"""
obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend)
# apply nan policy before deriving the outcome weights, so the (now zeroed)
# NaN members do not leak NaN into obs_w/fct_w; they carry zero ens_w anyway.
obs, fct, ens_w = apply_nan_policy_ens_mv(
obs, fct, nan_policy, ens_w, m_axis=m_axis, backend=backend
)
obs_w, fct_w = mv_weighted_score_weights(obs, fct, w_func=w_func, backend=backend)
if ens_w is None:
if backend == "numba":
return energy.estimator_gufuncs["ownrg"](obs, fct, obs_w, fct_w)
else:
return energy.owes(obs, fct, obs_w, fct_w, backend=backend)
else:
ens_w = multivariate_weight_check(ens_w, fct, m_axis, backend=backend)
ens_w = multivariate_weight_check(ens_w, fct, -2, backend=backend)
if backend == "numba":
return energy.estimator_gufuncs_w["ownrg"](obs, fct, obs_w, fct_w, ens_w)
else:
Expand All @@ -221,6 +260,7 @@ def vres_ensemble(
w_func: tp.Callable[["ArrayLike"], "ArrayLike"],
*,
ens_w: "Array" = None,
nan_policy: "NanPolicy" = "propagate",
m_axis: int = -2,
v_axis: int = -1,
backend: "Backend" = None,
Expand Down Expand Up @@ -259,6 +299,11 @@ def vres_ensemble(
ens_w : array_like
Weights assigned to the ensemble members. Array with one less dimension than fct (without the v_axis dimension).
Default is equal weighting. Weights are normalised so that they sum to one across the ensemble members.
nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate'
How to handle NaN ensemble members (a member is invalid if any of its
variables, or its weight, is NaN). 'propagate' lets NaN flow to a NaN
score; 'raise' errors on any NaN; 'omit' drops invalid members by
zero-weighting them. NaN in `obs` always propagates.
backend : str
The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

Expand All @@ -268,14 +313,18 @@ def vres_ensemble(
The computed Vertically Re-scaled Energy Score.
"""
obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend)
# apply nan policy before deriving the outcome weights (see owes_ensemble).
obs, fct, ens_w = apply_nan_policy_ens_mv(
obs, fct, nan_policy, ens_w, m_axis=m_axis, backend=backend
)
obs_w, fct_w = mv_weighted_score_weights(obs, fct, w_func=w_func, backend=backend)
if ens_w is None:
if backend == "numba":
return energy.estimator_gufuncs["vrnrg"](obs, fct, obs_w, fct_w)
else:
return energy.vres(obs, fct, obs_w, fct_w, backend=backend)
else:
ens_w = multivariate_weight_check(ens_w, fct, m_axis, backend=backend)
ens_w = multivariate_weight_check(ens_w, fct, -2, backend=backend)
if backend == "numba":
return energy.estimator_gufuncs_w["vrnrg"](obs, fct, obs_w, fct_w, ens_w)
else:
Expand Down
4 changes: 1 addition & 3 deletions scoringrules/core/energy/_gufuncs_w.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def _owenergy_score_gufunc_w(
):
"""Compute the Outcome-Weighted Energy Score for a finite ensemble."""
M = fct.shape[0]
ow = ow[0]

e_1 = 0.0
e_2 = 0.0
Expand All @@ -104,7 +103,7 @@ def _owenergy_score_gufunc_w(
* ens_w[j]
)

wbar = np.mean(fw)
wbar = np.sum(fw * ens_w)

out[0] = e_1 / (wbar) - 0.5 * e_2 / (wbar**2)

Expand All @@ -120,7 +119,6 @@ def _vrenergy_score_gufunc_w(
):
"""Compute the Vertically Re-scaled Energy Score for a finite ensemble."""
M = fct.shape[0]
ow = ow[0]

e_1 = 0.0
e_2 = 0.0
Expand Down
79 changes: 66 additions & 13 deletions scoringrules/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,33 +278,86 @@ def apply_nan_policy_ens_uv(
)


def apply_nan_policy_ens_mv(obs, fct, nan_policy="propagate", backend=None):
"""Apply NaN policy to multivariate ensemble forecasts (fct shape: ..., M, D).
def apply_nan_policy_ens_mv(
obs: "Array",
fct: "Array",
nan_policy: "NanPolicy" = "propagate",
ens_w: "Array | None" = None,
estimator: str | None = None,
m_axis: int = -2,
backend: "Backend" = None,
) -> "tuple[Array, Array, Array | None]":
"""Apply a NaN policy to multivariate ensemble forecasts (``fct`` shape ``..., M, D``).

A NaN in any variable of an ensemble member marks the entire member as invalid.
A member is treated as invalid if **any** of its variables is NaN, or if its
weight (``ens_w``) is NaN. NaN in ``obs`` always propagates and is never
omitted. All policies return ``(obs, fct, ens_w)`` so the caller keeps a
fixed signature. When supplied or produced, ``ens_w`` has shape ``..., M``
with its ensemble axis (``m_axis``) moved to the last position, aligned with
the already-permuted ``fct`` (whose ensemble axis is at ``-2`` and variable
axis at ``-1``).

For 'propagate': no-op, returns (obs, fct, None).
For 'raise': raises ValueError if any NaN in fct or obs.
For 'omit': returns (obs, fct_zeroed, nan_mask) where nan_mask has shape
(..., M) — True for invalid members — and NaN members are replaced with 0.0.
- ``'propagate'``: no-op; NaN members/weights flow through to a NaN result.
- ``'raise'``: raise ``ValueError`` if any NaN is present in ``obs``, ``fct``
or ``ens_w``.
- ``'omit'``: invalid members are given zero weight and the caller routes the
score through the weighted path. NaN forecast values are replaced with
``0.0`` purely to keep them out of the arithmetic (their value is
irrelevant once their weight is zero). ``ens_w`` is returned as a 0/1 mask
when no weights were supplied, otherwise with invalid entries zeroed. Not
implemented for the ``'akr'``/``'akr_circperm'`` estimators.

Note
----
A user-supplied ``ens_w`` is assumed to have its ensemble axis at the same
position as in ``fct`` with the variable axis *after* it (the default
``m_axis=-2, v_axis=-1``). Layouts that place the ensemble axis after the
variable axis are not supported for weighted inputs, matching the existing
assumption in :func:`multivariate_weight_check`.
"""
B = backends.active if backend is None else backends[backend]

if ens_w is not None:
# The user passes ens_w with its ensemble axis at the original m_axis
# (one fewer axis than fct: no variable axis). Mirror multivariate_weight_check:
# convert m_axis to positive relative to fct.ndim, then move that axis to
# the last position of ens_w. For the default (m_axis=-2, v_axis=-1) this
# is a no-op. Cast to float so isnan works on integer weights.
m_axis_p = m_axis if m_axis >= 0 else fct.ndim + m_axis
ens_w = B.moveaxis(B.asarray(ens_w, dtype=fct.dtype), m_axis_p, -1)

if nan_policy == "propagate":
return obs, fct, ens_w

nan_mask = B.any(B.isnan(fct), axis=-1) # (..., M)
if ens_w is not None:
nan_mask = nan_mask | B.isnan(ens_w)

if nan_policy == "raise":
if B.any(B.isnan(fct)) or B.any(B.isnan(obs)):
if B.any(nan_mask) or B.any(B.isnan(obs)):
raise ValueError(
"NaN values encountered in input. "
"Use nan_policy='propagate' or nan_policy='omit' to handle NaN values."
)
return obs, fct, None
return obs, fct, ens_w

if nan_policy == "omit":
nan_mask = B.any(B.isnan(fct), axis=-1) # shape (..., M)
if estimator in ["akr", "akr_circperm"]:
raise NotImplementedError(
f"NaN handling with nan_policy='omit' is not implemented for estimator '{estimator}'."
)
# zero only the NaN forecast values; zero-weighted members never
# contribute, so their (now 0.0) value does not affect the score.
fct = B.where(nan_mask[..., None], B.asarray(0.0), fct)
return obs, fct, nan_mask
if ens_w is None:
ens_w = B.asarray(~nan_mask, dtype=fct.dtype)
else:
ens_w = B.where(nan_mask, B.asarray(0.0), ens_w)
return obs, fct, ens_w

# propagate
return obs, fct, None
raise ValueError(
f"Invalid nan_policy '{nan_policy}'. Must be one of 'propagate', 'omit', 'raise'."
)


def lazy_gufunc_wrapper_uv(func):
Expand Down
Loading
Loading