diff --git a/scoringrules/_energy.py b/scoringrules/_energy.py index fdc101a..2ff3a08 100644 --- a/scoringrules/_energy.py +++ b/scoringrules/_energy.py @@ -22,6 +22,7 @@ def es_ensemble( ens_w: "Array" = None, estimator: str = "nrg", backend: "Backend" = None, + **kwargs, ) -> "Array": r"""Compute the Energy Score for a finite multivariate ensemble. @@ -71,18 +72,39 @@ def es_ensemble( Some theoretical background on scoring rules for multivariate forecasts. """ obs, fct = multivariate_array_check(obs, fct, m_axis, v_axis, backend=backend) + + if estimator == "akr_kband": + k = kwargs.get("k", 1) + if ens_w is None: if backend == "numba": estimator_check(estimator, energy.estimator_gufuncs) - return energy.estimator_gufuncs[estimator](obs, fct) + if estimator == "akr_kband": + return energy.estimator_gufuncs[estimator](obs, fct, k) + else: + return energy.estimator_gufuncs[estimator](obs, fct) else: + if estimator == "akr_kband": + return energy.es(obs, fct, estimator=estimator, backend=backend, k=k) return energy.es(obs, fct, estimator=estimator, backend=backend) else: ens_w = multivariate_weight_check(ens_w, fct, m_axis, backend=backend) if backend == "numba": estimator_check(estimator, energy.estimator_gufuncs_w) - return energy.estimator_gufuncs_w[estimator](obs, fct, ens_w) + if estimator == "akr_kband": + return energy.estimator_gufuncs_w[estimator](obs, fct, k, ens_w) + else: + return energy.estimator_gufuncs_w[estimator](obs, fct, ens_w) else: + if estimator == "akr_kband": + return energy.es_w( + obs, + fct, + ens_w, + estimator=estimator, + backend=backend, + k=k, + ) return energy.es_w(obs, fct, ens_w, estimator=estimator, backend=backend) diff --git a/scoringrules/core/energy/_gufuncs.py b/scoringrules/core/energy/_gufuncs.py index 2359760..13be2d5 100644 --- a/scoringrules/core/energy/_gufuncs.py +++ b/scoringrules/core/energy/_gufuncs.py @@ -83,6 +83,22 @@ def _energy_score_akr_circperm_gufunc( out[0] = e_1 / M - 0.5 * 1 / M * e_2 +@guvectorize("(d),(m,d),()->()") +def _energy_score_akr_kband_gufunc( + obs: np.ndarray, fct: np.ndarray, k: int, out: np.ndarray +): + """Compute the Energy Score for a finite ensemble using the AKR with k-band approximation.""" + M = fct.shape[0] + e_1 = 0.0 + e_2 = 0.0 + for i in range(M): + e_1 += float(np.linalg.norm(fct[i] - obs)) + for j in range(1, k + 1): + e_2 += float(np.linalg.norm(fct[i] - fct[(i + j) % M])) + + out[0] = e_1 / M - 0.5 * 1 / (M * k) * e_2 + + @guvectorize("(d),(m,d),(),(m)->()") def _owenergy_score_gufunc( obs: np.ndarray, @@ -133,6 +149,7 @@ def _vrenergy_score_gufunc( estimator_gufuncs = { "akr_circperm": lazy_gufunc_wrapper_mv(_energy_score_akr_circperm_gufunc), + "akr_kband": lazy_gufunc_wrapper_mv(_energy_score_akr_kband_gufunc), "akr": lazy_gufunc_wrapper_mv(_energy_score_akr_gufunc), "fair": lazy_gufunc_wrapper_mv(_energy_score_fair_gufunc), "nrg": lazy_gufunc_wrapper_mv(_energy_score_nrg_gufunc), @@ -142,6 +159,7 @@ def _vrenergy_score_gufunc( __all__ = [ "_energy_score_akr_circperm_gufunc", + "_energy_score_akr_kband_gufunc", "_energy_score_akr_gufunc", "_energy_score_fair_gufunc", "_energy_score_nrg_gufunc", diff --git a/scoringrules/core/energy/_gufuncs_w.py b/scoringrules/core/energy/_gufuncs_w.py index 7553a66..f7797f7 100644 --- a/scoringrules/core/energy/_gufuncs_w.py +++ b/scoringrules/core/energy/_gufuncs_w.py @@ -79,6 +79,27 @@ def _energy_score_akr_circperm_gufunc_w( out[0] = e_1 - 0.5 * e_2 +@guvectorize("(d),(m,d),(),(m)->()") +def _energy_score_akr_kband_gufunc_w( + obs: np.ndarray, fct: np.ndarray, k: int, ens_w: np.ndarray, out: np.ndarray +): + """Compute the Energy Score for a finite ensemble using the AKR with k-band approximation.""" + M = fct.shape[0] + + e_1 = 0.0 + e_2 = 0.0 + for i in range(M): + e_1 += float(np.linalg.norm(fct[i] - obs)) * ens_w[i] + for j in range(1, k + 1): + e_2 += ( + float(np.linalg.norm(fct[i] - fct[(i + j) % M])) + * ens_w[i] + * ens_w[(i + j) % M] + ) + + out[0] = e_1 - 0.5 * 1 / k * e_2 + + @guvectorize("(d),(m,d),(),(m),(m)->()") def _owenergy_score_gufunc_w( obs: np.ndarray, @@ -144,6 +165,7 @@ def _vrenergy_score_gufunc_w( estimator_gufuncs_w = { "akr_circperm": lazy_gufunc_wrapper_mv(_energy_score_akr_circperm_gufunc_w), + "akr_kband": lazy_gufunc_wrapper_mv(_energy_score_akr_kband_gufunc_w), "akr": lazy_gufunc_wrapper_mv(_energy_score_akr_gufunc_w), "fair": lazy_gufunc_wrapper_mv(_energy_score_fair_gufunc_w), "nrg": lazy_gufunc_wrapper_mv(_energy_score_nrg_gufunc_w), @@ -153,6 +175,7 @@ def _vrenergy_score_gufunc_w( __all__ = [ "_energy_score_akr_circperm_gufunc_w", + "_energy_score_akr_kband_gufunc_w", "_energy_score_akr_gufunc_w", "_energy_score_fair_gufunc_w", "_energy_score_nrg_gufunc_w", diff --git a/scoringrules/core/energy/_score.py b/scoringrules/core/energy/_score.py index 1dd6ec5..ec5fda5 100644 --- a/scoringrules/core/energy/_score.py +++ b/scoringrules/core/energy/_score.py @@ -7,7 +7,11 @@ def es_ensemble( - obs: "Array", fct: "Array", estimator: str = "nrg", backend=None + obs: "Array", + fct: "Array", + estimator: str = "nrg", + backend=None, + k: int = 1, ) -> "Array": """ Compute the energy score based on a finite ensemble. @@ -22,9 +26,11 @@ def es_ensemble( out = _es_ensemble_akr(obs, fct, backend=backend) elif estimator == "akr_circperm": out = _es_ensemble_akr_circperm(obs, fct, backend=backend) + elif estimator == "akr_kband": + out = _es_ensemble_akr_kband(obs, fct, k=k, backend=backend) else: raise ValueError( - f"For the energy score, {estimator} must be one of 'nrg', 'fair', 'akr', and 'akr_circperm'." + f"For the energy score, {estimator} must be one of 'nrg', 'fair', 'akr', 'akr_circperm', and 'akr_kband'." ) return out @@ -89,6 +95,28 @@ def _es_ensemble_akr_circperm( return E_1 - 0.5 * E_2 +def _es_ensemble_akr_kband( + obs: "Array", fct: "Array", k: int = 1, backend: "Backend" = None +) -> "Array": + """Compute the Energy Score for a finite ensemble using the AKR with k-band approximation.""" + B = backends.active if backend is None else backends[backend] + M: int = fct.shape[-2] + + if k < 1: + raise ValueError("For estimator='akr_kband', k must be >= 1.") + + err_norm = B.norm(fct - B.expand_dims(obs, -2), -1) + E_1 = B.sum(err_norm, -1) / M + + E_2 = 0.0 + for j in range(1, k + 1): + spread_norm = B.norm(fct - B.roll(fct, shift=-j, axis=-2), -1) + E_2 += 2 * B.sum(spread_norm, -1) + E_2 = E_2 / (M * k) + + return E_1 - 0.5 * E_2 + + def owes_ensemble( obs: "Array", # (... D) fct: "Array", # (... M D) diff --git a/scoringrules/core/energy/_score_w.py b/scoringrules/core/energy/_score_w.py index 2d5d9a9..822581c 100644 --- a/scoringrules/core/energy/_score_w.py +++ b/scoringrules/core/energy/_score_w.py @@ -7,7 +7,12 @@ def es_ensemble_w( - obs: "Array", fct: "Array", ens_w: "Array", estimator: str = "nrg", backend=None + obs: "Array", + fct: "Array", + ens_w: "Array", + estimator: str = "nrg", + backend=None, + k: int = 1, ) -> "Array": """ Compute the energy score based on a finite ensemble. @@ -22,9 +27,11 @@ def es_ensemble_w( out = _es_ensemble_akr_w(obs, fct, ens_w, backend=backend) elif estimator == "akr_circperm": out = _es_ensemble_akr_circperm_w(obs, fct, ens_w, backend=backend) + elif estimator == "akr_kband": + out = _es_ensemble_akr_kband_w(obs, fct, ens_w, k=k, backend=backend) else: raise ValueError( - f"For the energy score, {estimator} must be one of 'nrg', 'fair', 'akr', and 'akr_circperm'." + f"For the energy score, {estimator} must be one of 'nrg', 'fair', 'akr', 'akr_circperm', and 'akr_kband'." ) return out @@ -102,6 +109,32 @@ def _es_ensemble_akr_circperm_w( return E_1 - 0.5 * E_2 +def _es_ensemble_akr_kband_w( + obs: "Array", + fct: "Array", + ens_w: "Array", + k: int = 1, + backend: "Backend" = None, +) -> "Array": + """Compute the weighted Energy Score using the AKR with k-band approximation.""" + B = backends.active if backend is None else backends[backend] + + if k < 1: + raise ValueError("For estimator='akr_kband', k must be >= 1.") + + err_norm = B.norm(fct - B.expand_dims(obs, -2), -1) + E_1 = B.sum(err_norm * ens_w, -1) + + E_2 = 0.0 + for j in range(1, k + 1): + fct_shift = B.roll(fct, shift=-j, axis=-2) + ens_w_shift = B.roll(ens_w, shift=-j, axis=-1) + spread_norm = B.norm(fct - fct_shift, -1) + E_2 += 2 * B.sum(spread_norm * ens_w * ens_w_shift, -1) + + return E_1 - 0.5 * E_2 / k + + def owes_ensemble_w( obs: "Array", # (... D) fct: "Array", # (... M D) diff --git a/tests/test_energy.py b/tests/test_energy.py index 6d7ec1e..206a36c 100644 --- a/tests/test_energy.py +++ b/tests/test_energy.py @@ -7,7 +7,7 @@ N = 20 N_VARS = 3 -ESTIMATORS = ["nrg", "fair", "akr", "akr_circperm"] +ESTIMATORS = ["nrg", "fair", "akr", "akr_circperm", "akr_kband"] @pytest.mark.parametrize("estimator", ESTIMATORS) @@ -18,13 +18,19 @@ def test_energy_score(estimator, backend): with pytest.raises(ValueError): obs = np.random.randn(N, N_VARS) fct = np.random.randn(N, ENSEMBLE_SIZE, N_VARS - 1) - sr.es_ensemble(obs, fct, estimator=estimator, backend=backend) + if estimator == "akr_kband": + sr.es_ensemble(obs, fct, estimator=estimator, backend=backend, k=2) + else: + sr.es_ensemble(obs, fct, estimator=estimator, backend=backend) # undefined estimator with pytest.raises(ValueError): fct = np.random.randn(N, ENSEMBLE_SIZE, N_VARS) est = "undefined_estimator" - sr.es_ensemble(obs, fct, estimator=est, backend=backend) + if estimator == "akr_kband": + sr.es_ensemble(obs, fct, estimator=est, backend=backend, k=2) + else: + sr.es_ensemble(obs, fct, estimator=est, backend=backend) # test output