diff --git a/scoringrules/_energy.py b/scoringrules/_energy.py index fdc101a..058169c 100644 --- a/scoringrules/_energy.py +++ b/scoringrules/_energy.py @@ -7,10 +7,11 @@ estimator_check, mv_weighted_score_chain, mv_weighted_score_weights, + apply_nan_policy_ens_mv, ) if tp.TYPE_CHECKING: - from scoringrules.core.typing import Array, ArrayLike, Backend + from scoringrules.core.typing import Array, ArrayLike, Backend, NanPolicy def es_ensemble( @@ -21,6 +22,7 @@ def es_ensemble( *, ens_w: "Array" = None, estimator: str = "nrg", + nan_policy: "NanPolicy" = "propagate", backend: "Backend" = None, ) -> "Array": r"""Compute the Energy Score for a finite multivariate ensemble. @@ -50,6 +52,12 @@ def es_ensemble( Default is equal weighting. Weights are normalised so that they sum to one across the ensemble members. estimator : str The energy score estimator to be used. + nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate' + How to handle NaN ensemble members (a member is invalid if any of its + variables, or its weight, is NaN). 'propagate' lets NaN flow to a NaN + score; 'raise' errors on any NaN; 'omit' drops invalid members by + zero-weighting them. 'omit' is not implemented for the 'akr' and + 'akr_circperm' estimators. NaN in `obs` always propagates. backend : str The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'. @@ -71,6 +79,15 @@ def es_ensemble( Some theoretical background on scoring rules for multivariate forecasts. """ obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend) + + # apply nan policy before weight normalisation so NaN positions in a + # user-supplied ens_w stay detectable. For 'omit', invalid members become + # zero-weighted and the score is routed through the weighted path; ens_w is + # returned aligned ensemble-last. 
+ obs, fct, ens_w = apply_nan_policy_ens_mv( + obs, fct, nan_policy, ens_w, estimator=estimator, m_axis=m_axis, backend=backend + ) + if ens_w is None: if backend == "numba": estimator_check(estimator, energy.estimator_gufuncs) @@ -78,7 +95,10 @@ def es_ensemble( else: return energy.es(obs, fct, estimator=estimator, backend=backend) else: - ens_w = multivariate_weight_check(ens_w, fct, m_axis, backend=backend) + # ens_w is already aligned ensemble-last by the nan-policy helper, so + # m_axis=-2 maps to that last axis (a no-op realignment) while still + # running the shape/non-negativity checks and renormalisation. + ens_w = multivariate_weight_check(ens_w, fct, -2, backend=backend) if backend == "numba": estimator_check(estimator, energy.estimator_gufuncs_w) return energy.estimator_gufuncs_w[estimator](obs, fct, ens_w) @@ -95,6 +115,7 @@ def twes_ensemble( *, ens_w: "Array" = None, estimator: str = "nrg", + nan_policy: "NanPolicy" = "propagate", backend: "Backend" = None, ) -> "Array": r"""Compute the Threshold-Weighted Energy Score (twES) for a finite multivariate ensemble. @@ -129,6 +150,12 @@ def twes_ensemble( Default is equal weighting. Weights are normalised so that they sum to one across the ensemble members. estimator : str The energy score estimator to be used. + nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate' + How to handle NaN ensemble members (a member is invalid if any of its + variables, or its weight, is NaN). 'propagate' lets NaN flow to a NaN + score; 'raise' errors on any NaN; 'omit' drops invalid members by + zero-weighting them. 'omit' is not implemented for the 'akr' and + 'akr_circperm' estimators. NaN in `obs` always propagates. backend : str The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'. 
@@ -145,6 +172,7 @@ def twes_ensemble( v_axis=v_axis, ens_w=ens_w, estimator=estimator, + nan_policy=nan_policy, backend=backend, ) @@ -157,6 +185,7 @@ def owes_ensemble( v_axis: int = -1, *, ens_w: "Array" = None, + nan_policy: "NanPolicy" = "propagate", backend: "Backend" = None, ) -> "Array": r"""Compute the Outcome-Weighted Energy Score (owES) for a finite multivariate ensemble. @@ -192,6 +221,11 @@ def owes_ensemble( ens_w : array_like Weights assigned to the ensemble members. Array with one less dimension than fct (without the v_axis dimension). Default is equal weighting. Weights are normalised so that they sum to one across the ensemble members. + nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate' + How to handle NaN ensemble members (a member is invalid if any of its + variables, or its weight, is NaN). 'propagate' lets NaN flow to a NaN + score; 'raise' errors on any NaN; 'omit' drops invalid members by + zero-weighting them. NaN in `obs` always propagates. backend : str The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'. @@ -201,6 +235,11 @@ def owes_ensemble( The computed Outcome-Weighted Energy Score. """ obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend) + # apply nan policy before deriving the outcome weights, so the (now zeroed) + # NaN members do not leak NaN into obs_w/fct_w; they carry zero ens_w anyway. 
+ obs, fct, ens_w = apply_nan_policy_ens_mv( + obs, fct, nan_policy, ens_w, m_axis=m_axis, backend=backend + ) obs_w, fct_w = mv_weighted_score_weights(obs, fct, w_func=w_func, backend=backend) if ens_w is None: if backend == "numba": @@ -208,7 +247,7 @@ def owes_ensemble( else: return energy.owes(obs, fct, obs_w, fct_w, backend=backend) else: - ens_w = multivariate_weight_check(ens_w, fct, m_axis, backend=backend) + ens_w = multivariate_weight_check(ens_w, fct, -2, backend=backend) if backend == "numba": return energy.estimator_gufuncs_w["ownrg"](obs, fct, obs_w, fct_w, ens_w) else: @@ -221,6 +260,7 @@ def vres_ensemble( w_func: tp.Callable[["ArrayLike"], "ArrayLike"], *, ens_w: "Array" = None, + nan_policy: "NanPolicy" = "propagate", m_axis: int = -2, v_axis: int = -1, backend: "Backend" = None, @@ -259,6 +299,11 @@ def vres_ensemble( ens_w : array_like Weights assigned to the ensemble members. Array with one less dimension than fct (without the v_axis dimension). Default is equal weighting. Weights are normalised so that they sum to one across the ensemble members. + nan_policy : {'propagate', 'omit', 'raise'}, default 'propagate' + How to handle NaN ensemble members (a member is invalid if any of its + variables, or its weight, is NaN). 'propagate' lets NaN flow to a NaN + score; 'raise' errors on any NaN; 'omit' drops invalid members by + zero-weighting them. NaN in `obs` always propagates. backend : str The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'. @@ -268,6 +313,10 @@ def vres_ensemble( The computed Vertically Re-scaled Energy Score. """ obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend) + # apply nan policy before deriving the outcome weights (see owes_ensemble). 
+ obs, fct, ens_w = apply_nan_policy_ens_mv( + obs, fct, nan_policy, ens_w, m_axis=m_axis, backend=backend + ) obs_w, fct_w = mv_weighted_score_weights(obs, fct, w_func=w_func, backend=backend) if ens_w is None: if backend == "numba": @@ -275,7 +324,7 @@ def vres_ensemble( else: return energy.vres(obs, fct, obs_w, fct_w, backend=backend) else: - ens_w = multivariate_weight_check(ens_w, fct, m_axis, backend=backend) + ens_w = multivariate_weight_check(ens_w, fct, -2, backend=backend) if backend == "numba": return energy.estimator_gufuncs_w["vrnrg"](obs, fct, obs_w, fct_w, ens_w) else: diff --git a/scoringrules/core/energy/_gufuncs_w.py b/scoringrules/core/energy/_gufuncs_w.py index 7553a66..2cd8655 100644 --- a/scoringrules/core/energy/_gufuncs_w.py +++ b/scoringrules/core/energy/_gufuncs_w.py @@ -90,7 +90,6 @@ def _owenergy_score_gufunc_w( ): """Compute the Outcome-Weighted Energy Score for a finite ensemble.""" M = fct.shape[0] - ow = ow[0] e_1 = 0.0 e_2 = 0.0 @@ -104,7 +103,7 @@ def _owenergy_score_gufunc_w( * ens_w[j] ) - wbar = np.mean(fw) + wbar = np.sum(fw * ens_w) out[0] = e_1 / (wbar) - 0.5 * e_2 / (wbar**2) @@ -120,7 +119,6 @@ def _vrenergy_score_gufunc_w( ): """Compute the Vertically Re-scaled Energy Score for a finite ensemble.""" M = fct.shape[0] - ow = ow[0] e_1 = 0.0 e_2 = 0.0 diff --git a/scoringrules/core/utils.py b/scoringrules/core/utils.py index 592e8f8..fc99adc 100644 --- a/scoringrules/core/utils.py +++ b/scoringrules/core/utils.py @@ -278,33 +278,86 @@ def apply_nan_policy_ens_uv( ) -def apply_nan_policy_ens_mv(obs, fct, nan_policy="propagate", backend=None): - """Apply NaN policy to multivariate ensemble forecasts (fct shape: ..., M, D). 
+def apply_nan_policy_ens_mv( + obs: "Array", + fct: "Array", + nan_policy: "NanPolicy" = "propagate", + ens_w: "Array | None" = None, + estimator: str | None = None, + m_axis: int = -2, + backend: "Backend" = None, +) -> "tuple[Array, Array, Array | None]": + """Apply a NaN policy to multivariate ensemble forecasts (``fct`` shape ``..., M, D``). - A NaN in any variable of an ensemble member marks the entire member as invalid. + A member is treated as invalid if **any** of its variables is NaN, or if its + weight (``ens_w``) is NaN. NaN in ``obs`` always propagates and is never + omitted. All policies return ``(obs, fct, ens_w)`` so the caller keeps a + fixed signature. When supplied or produced, ``ens_w`` has shape ``..., M`` + with its ensemble axis (``m_axis``) moved to the last position, aligned with + the already-permuted ``fct`` (whose ensemble axis is at ``-2`` and variable + axis at ``-1``). - For 'propagate': no-op, returns (obs, fct, None). - For 'raise': raises ValueError if any NaN in fct or obs. - For 'omit': returns (obs, fct_zeroed, nan_mask) where nan_mask has shape - (..., M) — True for invalid members — and NaN members are replaced with 0.0. + - ``'propagate'``: no-op; NaN members/weights flow through to a NaN result. + - ``'raise'``: raise ``ValueError`` if any NaN is present in ``obs``, ``fct`` + or ``ens_w``. + - ``'omit'``: invalid members are given zero weight and the caller routes the + score through the weighted path. NaN forecast values are replaced with + ``0.0`` purely to keep them out of the arithmetic (their value is + irrelevant once their weight is zero). ``ens_w`` is returned as a 0/1 mask + when no weights were supplied, otherwise with invalid entries zeroed. Not + implemented for the ``'akr'``/``'akr_circperm'`` estimators. + + Note + ---- + A user-supplied ``ens_w`` is assumed to have its ensemble axis at the same + position as in ``fct`` with the variable axis *after* it (the default + ``m_axis=-2, v_axis=-1``). 
Layouts that place the ensemble axis after the + variable axis are not supported for weighted inputs, matching the existing + assumption in :func:`multivariate_weight_check`. """ B = backends.active if backend is None else backends[backend] + if ens_w is not None: + # The user passes ens_w with its ensemble axis at the original m_axis + # (one fewer axis than fct: no variable axis). Mirror multivariate_weight_check: + # convert m_axis to positive relative to fct.ndim, then move that axis to + # the last position of ens_w. For the default (m_axis=-2, v_axis=-1) this + # is a no-op. Cast to float so isnan works on integer weights. + m_axis_p = m_axis if m_axis >= 0 else fct.ndim + m_axis + ens_w = B.moveaxis(B.asarray(ens_w, dtype=fct.dtype), m_axis_p, -1) + + if nan_policy == "propagate": + return obs, fct, ens_w + + nan_mask = B.any(B.isnan(fct), axis=-1) # (..., M) + if ens_w is not None: + nan_mask = nan_mask | B.isnan(ens_w) + if nan_policy == "raise": - if B.any(B.isnan(fct)) or B.any(B.isnan(obs)): + if B.any(nan_mask) or B.any(B.isnan(obs)): raise ValueError( "NaN values encountered in input. " "Use nan_policy='propagate' or nan_policy='omit' to handle NaN values." ) - return obs, fct, None + return obs, fct, ens_w if nan_policy == "omit": - nan_mask = B.any(B.isnan(fct), axis=-1) # shape (..., M) + if estimator in ["akr", "akr_circperm"]: + raise NotImplementedError( + f"NaN handling with nan_policy='omit' is not implemented for estimator '{estimator}'." + ) + # zero only the NaN forecast values; zero-weighted members never + # contribute, so their (now 0.0) value does not affect the score. fct = B.where(nan_mask[..., None], B.asarray(0.0), fct) - return obs, fct, nan_mask + if ens_w is None: + ens_w = B.asarray(~nan_mask, dtype=fct.dtype) + else: + ens_w = B.where(nan_mask, B.asarray(0.0), ens_w) + return obs, fct, ens_w - # propagate - return obs, fct, None + raise ValueError( + f"Invalid nan_policy '{nan_policy}'. 
Must be one of 'propagate', 'omit', 'raise'." + ) def lazy_gufunc_wrapper_uv(func): diff --git a/tests/test_energy.py b/tests/test_energy.py index 6d7ec1e..e9dae17 100644 --- a/tests/test_energy.py +++ b/tests/test_energy.py @@ -86,3 +86,190 @@ def test_energy_score(estimator, backend): obs, fct, ens_w=w_scaled, estimator=estimator, backend=backend ) np.testing.assert_allclose(res_scaled, res, atol=1e-6) + + +OMIT_ESTIMATORS = ["nrg", "fair"] + + +@pytest.mark.parametrize("estimator", ESTIMATORS) +def test_energy_score_nan_policy(estimator, backend): + """es_ensemble with NaN forecast members, no user weights.""" + kwargs = {"estimator": estimator, "backend": backend} + + obs = np.random.randn(N, N_VARS) + fct = np.expand_dims(obs, axis=-2) + np.random.randn(N, ENSEMBLE_SIZE, N_VARS) + + fct_nan = fct.copy() + fct_nan[0, [0, 3, 6], 0] = np.nan + fct_nan[2, [5], 1] = np.nan + nan_positions = np.isnan(fct_nan).any(axis=(1, 2)) + + # default + explicit 'propagate': corrupted samples return NaN + for policy in (None, "propagate"): + pol_kwargs = kwargs if policy is None else {**kwargs, "nan_policy": policy} + res = np.asarray(sr.es_ensemble(obs, fct_nan, **pol_kwargs)) + assert np.all(np.isnan(res[nan_positions])) + assert not np.any(np.isnan(res[~nan_positions])) + + # 'raise' + with pytest.raises(ValueError): + sr.es_ensemble(obs, fct_nan, nan_policy="raise", **kwargs) + + # 'omit' unsupported for akr / akr_circperm + if estimator not in OMIT_ESTIMATORS: + with pytest.raises(NotImplementedError): + sr.es_ensemble(obs, fct_nan, nan_policy="omit", **kwargs) + return + + # 'omit': no NaN in result + res = np.asarray(sr.es_ensemble(obs, fct_nan, nan_policy="omit", **kwargs)) + assert not np.any(np.isnan(res)) + + # 'omit': equivalence with the surviving sub-ensemble, per row + for i in range(fct.shape[0]): + valid = ~np.isnan(fct_nan[i]).any(axis=-1) + fct_clean = fct_nan[i][valid] + res_clean = sr.es_ensemble(obs[i], fct_clean, **kwargs) + res_omit = 
sr.es_ensemble(obs[i], fct_nan[i], nan_policy="omit", **kwargs) + assert np.allclose(res_omit, res_clean) + + # all-members-NaN row: 'omit' yields NaN (documented 0/0 edge case) + fct_allnan = fct.copy() + fct_allnan[3, :, 0] = np.nan + res_allnan = np.asarray( + sr.es_ensemble(obs, fct_allnan, nan_policy="omit", **kwargs) + ) + assert np.isnan(res_allnan[3]) + + # NaN in obs always propagates, even under 'omit' + obs_nan = obs.copy() + obs_nan[1, 0] = np.nan + res = np.asarray(sr.es_ensemble(obs_nan, fct, nan_policy="omit", **kwargs)) + assert np.isnan(res[1]) + + +@pytest.mark.parametrize("estimator", ESTIMATORS) +def test_energy_score_w_ens_nan_policy(estimator, backend): + """es_ensemble with NaN forecast members AND user-supplied ens_w.""" + kwargs = {"estimator": estimator, "backend": backend} + + obs = np.random.randn(N, N_VARS) + fct = np.expand_dims(obs, axis=-2) + np.random.randn(N, ENSEMBLE_SIZE, N_VARS) + uniform_ens_w = np.ones(fct.shape[:-1]) + non_uniform_ens_w = np.random.rand(*fct.shape[:-1]) + + fct_nan = fct.copy() + fct_nan[0, [0, 3, 6], 0] = np.nan + fct_nan[2, [5], 1] = np.nan + nan_positions = np.isnan(fct_nan).any(axis=(1, 2)) + + # 'propagate' (default + explicit) + for policy in (None, "propagate"): + pol_kwargs = kwargs if policy is None else {**kwargs, "nan_policy": policy} + res = np.asarray( + sr.es_ensemble(obs, fct_nan, ens_w=uniform_ens_w, **pol_kwargs) + ) + assert np.all(np.isnan(res[nan_positions])) + + # 'raise' + with pytest.raises(ValueError): + sr.es_ensemble(obs, fct_nan, ens_w=uniform_ens_w, nan_policy="raise", **kwargs) + + # 'omit' unsupported for akr / akr_circperm + if estimator not in OMIT_ESTIMATORS: + with pytest.raises(NotImplementedError): + sr.es_ensemble( + obs, fct_nan, ens_w=uniform_ens_w, nan_policy="omit", **kwargs + ) + return + + # 'omit' with uniform weights: no NaN + res = np.asarray( + sr.es_ensemble(obs, fct_nan, ens_w=uniform_ens_w, nan_policy="omit", **kwargs) + ) + assert not np.any(np.isnan(res)) + 
+ # 'omit' with non-uniform weights: weights actually matter + res = np.asarray( + sr.es_ensemble(obs, fct, ens_w=non_uniform_ens_w, nan_policy="omit", **kwargs) + ) + res_nans = np.asarray( + sr.es_ensemble( + obs, fct_nan, ens_w=non_uniform_ens_w, nan_policy="omit", **kwargs + ) + ) + assert not np.any(np.isnan(res_nans)) + assert not np.allclose(res[nan_positions], res_nans[nan_positions]) + assert np.allclose(res[~nan_positions], res_nans[~nan_positions]) + + # 'omit': equivalence with clean weighted sub-ensemble, per row + for i in range(fct.shape[0]): + valid = ~np.isnan(fct_nan[i]).any(axis=-1) + res_clean = sr.es_ensemble( + obs[i], + fct_nan[i][valid], + ens_w=uniform_ens_w[i][valid], + nan_policy="omit", + **kwargs, + ) + res_omit = sr.es_ensemble( + obs[i], fct_nan[i], ens_w=uniform_ens_w[i], nan_policy="omit", **kwargs + ) + assert np.allclose(res_omit, res_clean) + + # 'omit' with non-uniform weights: equivalence with clean weighted sub-ensemble + for i in range(fct.shape[0]): + valid = ~np.isnan(fct_nan[i]).any(axis=-1) + res_clean = sr.es_ensemble( + obs[i], + fct_nan[i][valid], + ens_w=non_uniform_ens_w[i][valid], + nan_policy="omit", + **kwargs, + ) + res_omit = sr.es_ensemble( + obs[i], fct_nan[i], ens_w=non_uniform_ens_w[i], nan_policy="omit", **kwargs + ) + assert np.allclose(res_omit, res_clean) + + +@pytest.mark.parametrize("estimator", ESTIMATORS) +def test_energy_score_nan_weights(estimator, backend): + """es_ensemble when ens_w itself contains NaN (forecasts are clean).""" + kwargs = {"estimator": estimator, "backend": backend} + + obs = np.random.randn(N, N_VARS) + fct = np.expand_dims(obs, axis=-2) + np.random.randn(N, ENSEMBLE_SIZE, N_VARS) + ens_w = np.random.rand(N, ENSEMBLE_SIZE) + ens_w[0, [0, 3, 6]] = np.nan + ens_w[2, [5]] = np.nan + nan_positions = np.isnan(ens_w).any(axis=1) + + # 'propagate' + res = np.asarray( + sr.es_ensemble(obs, fct, ens_w=ens_w, nan_policy="propagate", **kwargs) + ) + assert 
np.all(np.isnan(res[nan_positions])) + + # 'raise' + with pytest.raises(ValueError): + sr.es_ensemble(obs, fct, ens_w=ens_w, nan_policy="raise", **kwargs) + + # 'omit' unsupported for akr / akr_circperm + if estimator not in OMIT_ESTIMATORS: + with pytest.raises(NotImplementedError): + sr.es_ensemble(obs, fct, ens_w=ens_w, nan_policy="omit", **kwargs) + return + + # 'omit': NaN-weighted members get zero weight (dropped) + res = np.asarray(sr.es_ensemble(obs, fct, ens_w=ens_w, nan_policy="omit", **kwargs)) + assert not np.any(np.isnan(res)) + for i in range(fct.shape[0]): + valid = ~np.isnan(ens_w[i]) + res_clean = sr.es_ensemble( + obs[i], fct[i][valid], ens_w=ens_w[i][valid], nan_policy="omit", **kwargs + ) + res_omit = sr.es_ensemble( + obs[i], fct[i], ens_w=ens_w[i], nan_policy="omit", **kwargs + ) + assert np.allclose(res_omit, res_clean) diff --git a/tests/test_wenergy.py b/tests/test_wenergy.py index 368c4b7..97a62ba 100644 --- a/tests/test_wenergy.py +++ b/tests/test_wenergy.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import scoringrules as sr from scoringrules.backend import backends @@ -82,3 +83,152 @@ def v_func(x): res = sr.twes_ensemble(obs, fct, v_func, backend=backend) np.testing.assert_allclose(res, 0.3345418, rtol=1e-6) + + +@pytest.mark.parametrize("score_fn", [sr.owes_ensemble, sr.vres_ensemble]) +def test_weighted_energy_uniform_ens_w_equivalence(score_fn, backend): + """Uniform ens_w must reproduce the unweighted score. 
This guards the + weighted numba gufuncs (_owenergy_score_gufunc_w / _vrenergy_score_gufunc_w) + against the trusted unweighted kernels.""" + obs = np.random.randn(N, N_VARS) + fct = np.expand_dims(obs, axis=-2) + np.random.randn(N, ENSEMBLE_SIZE, N_VARS) + w_func = lambda x: backends[backend].mean(x) * 0.0 + 1.0 # noqa: E731 + ens_w = np.ones(fct.shape[:-1]) + res = np.asarray(score_fn(obs, fct, w_func, backend=backend)) + res_w = np.asarray(score_fn(obs, fct, w_func, ens_w=ens_w, backend=backend)) + # loose tolerance: jax runs in float32 here, and the two paths accumulate + # differently, so this guards against gross errors (a real bug differs by O(1)). + np.testing.assert_allclose(res, res_w, rtol=1e-4, atol=1e-5) + + +def _nan_policy_check_mv(score_fn, backend): + """Shared NaN-policy assertions for owes/vres (no estimator argument).""" + obs = np.random.randn(N, N_VARS) + fct = np.expand_dims(obs, axis=-2) + np.random.randn(N, ENSEMBLE_SIZE, N_VARS) + # non-constant weight function: this exercises the weighted denominator, + # whereas a constant w_func would hide errors in the owes wbar normalisation. 
+ w_func = lambda x: backends[backend].sum(x**2) # noqa: E731 + + fct_nan = fct.copy() + fct_nan[0, [0, 3, 6], 0] = np.nan + fct_nan[2, [5], 1] = np.nan + nan_positions = np.isnan(fct_nan).any(axis=(1, 2)) + + # 'propagate' (default + explicit) + for policy in (None, "propagate"): + kw = {} if policy is None else {"nan_policy": policy} + res = np.asarray(score_fn(obs, fct_nan, w_func, backend=backend, **kw)) + assert np.all(np.isnan(res[nan_positions])) + + # 'raise' + with pytest.raises(ValueError): + score_fn(obs, fct_nan, w_func, nan_policy="raise", backend=backend) + + # 'omit': no NaN in result + res = np.asarray(score_fn(obs, fct_nan, w_func, nan_policy="omit", backend=backend)) + assert not np.any(np.isnan(res)) + + # 'omit': equivalence with the surviving sub-ensemble, per row + for i in range(fct.shape[0]): + valid = ~np.isnan(fct_nan[i]).any(axis=-1) + res_clean = score_fn(obs[i], fct_nan[i][valid], w_func, backend=backend) + res_omit = score_fn( + obs[i], fct_nan[i], w_func, nan_policy="omit", backend=backend + ) + assert np.allclose(res_omit, res_clean) + + # NaN in obs always propagates, even under 'omit' + obs_nan = obs.copy() + obs_nan[1, 0] = np.nan + res = np.asarray(score_fn(obs_nan, fct, w_func, nan_policy="omit", backend=backend)) + assert np.isnan(res[1]) + + +def test_owes_ensemble_nan_policy(backend): + _nan_policy_check_mv(sr.owes_ensemble, backend) + + +def test_vres_ensemble_nan_policy(backend): + _nan_policy_check_mv(sr.vres_ensemble, backend) + + +def test_twes_ensemble_nan_policy(backend): + obs = np.random.randn(N, N_VARS) + fct = np.expand_dims(obs, axis=-2) + np.random.randn(N, ENSEMBLE_SIZE, N_VARS) + + fct_nan = fct.copy() + fct_nan[0, [0, 3, 6], 0] = np.nan + nan_positions = np.isnan(fct_nan).any(axis=(1, 2)) + + # 'propagate' + res = np.asarray(sr.twes_ensemble(obs, fct_nan, lambda x: x, backend=backend)) + assert np.all(np.isnan(res[nan_positions])) + + # 'raise' + with pytest.raises(ValueError): + sr.twes_ensemble(obs, fct_nan, 
lambda x: x, nan_policy="raise", backend=backend) + + # 'omit': finite, and equals twes over the surviving members (identity v_func) + res = np.asarray( + sr.twes_ensemble(obs, fct_nan, lambda x: x, nan_policy="omit", backend=backend) + ) + assert not np.any(np.isnan(res)) + for i in range(fct.shape[0]): + valid = ~np.isnan(fct_nan[i]).any(axis=-1) + res_clean = sr.twes_ensemble( + obs[i], fct_nan[i][valid], lambda x: x, backend=backend + ) + res_omit = sr.twes_ensemble( + obs[i], fct_nan[i], lambda x: x, nan_policy="omit", backend=backend + ) + assert np.allclose(res_omit, res_clean) + + # NaN in obs always propagates, even under 'omit' + obs_nan = obs.copy() + obs_nan[1, 0] = np.nan + res = np.asarray( + sr.twes_ensemble(obs_nan, fct, lambda x: x, nan_policy="omit", backend=backend) + ) + assert np.isnan(res[1]) + + +@pytest.mark.parametrize("score_fn", [sr.owes_ensemble, sr.vres_ensemble]) +def test_weighted_energy_nan_policy_with_ens_w(score_fn, backend): + """owes/vres under 'omit' WITH user-supplied ens_w. Exercises the weighted + estimators (incl. 
the fixed numba kernels) under the exact scenario that + motivated the wbar / ow-indexing bugfixes.""" + obs = np.random.randn(N, N_VARS) + fct = np.expand_dims(obs, axis=-2) + np.random.randn(N, ENSEMBLE_SIZE, N_VARS) + w_func = lambda x: backends[backend].sum(x**2) # noqa: E731 + ens_w = np.ones(fct.shape[:-1]) + + fct_nan = fct.copy() + fct_nan[0, [0, 3, 6], 0] = np.nan + fct_nan[2, [5], 1] = np.nan + + # 'omit': no NaN in result + res = np.asarray( + score_fn(obs, fct_nan, w_func, ens_w=ens_w, nan_policy="omit", backend=backend) + ) + assert not np.any(np.isnan(res)) + + # 'omit': equivalence with the clean weighted sub-ensemble, per row + for i in range(fct.shape[0]): + valid = ~np.isnan(fct_nan[i]).any(axis=-1) + res_clean = score_fn( + obs[i], + fct_nan[i][valid], + w_func, + ens_w=ens_w[i][valid], + nan_policy="omit", + backend=backend, + ) + res_omit = score_fn( + obs[i], + fct_nan[i], + w_func, + ens_w=ens_w[i], + nan_policy="omit", + backend=backend, + ) + assert np.allclose(res_omit, res_clean)