Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 38 additions & 26 deletions examples/jackknife-covariance.ipynb

Large diffs are not rendered by default.

142 changes: 110 additions & 32 deletions heracles/dices/jackknife.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#
# You should have received a copy of the GNU Lesser General Public
# License along with DICES. If not, see <https://www.gnu.org/licenses/>.
import os
import numpy as np
import itertools
from copy import deepcopy
Expand All @@ -27,6 +28,8 @@
from ..twopoint import angular_power_spectra
from ..unmixing import _naturalspice
from ..transforms import cl2corr, corr2cl
from ..io import write_alms, read_alms, write, read
from ..progress import NoProgress

try:
from copy import replace
Expand All @@ -36,7 +39,15 @@


def jackknife_cls(
data_maps, vis_maps, jk_maps, fields, mask_correction="Fast", unmixed=False, nd=1
data_maps,
vis_maps,
jk_maps,
fields,
mask_correction="Fast",
unmixed=False,
nd=1,
dir="./dices",
progress=None,
):
"""
Compute the Cls of removing 1 Jackknife.
Expand All @@ -47,45 +58,89 @@ def jackknife_cls(
fields (dict): Dictionary of fields
mask_correction (str): Type of mask correction to apply ("Fast" or "Full")
nd (int): Number of Jackknife regions
mode (str): Type of statistic to compute ("Cls" or "PseudoCls")
dir (str): Directory for caching intermediate ALMs.
progress (Progress): Progress reporter.
returns:
cls (dict): Dictionary of data Cls
"""
if nd < 0 or nd > 2:
raise ValueError("number of deletions must be 0, 1, or 2")

if progress is None:
progress = NoProgress()

cls = {}
jkmap = jk_maps[list(jk_maps.keys())[0]]
njk = len(np.unique(jkmap)[np.unique(jkmap) != 0])

data_alms_regions = {}
vis_alms_regions = {}
for k in range(1, njk + 1):
print(f" - Computing ALMs for region {k}", end="\r", flush=True)
data_alms_regions[k] = transform(
fields, _get_region_maps(data_maps, jk_maps, k)
)
vis_alms_regions[k] = transform(fields, _get_region_maps(vis_maps, jk_maps, k))

mls0 = angular_power_spectra(_sum_alms_except(vis_alms_regions, ()))

for regions in combinations(range(1, njk + 1), nd):
print(f" - Computing Cls for regions {regions}", end="\r", flush=True)
alms_jk = _sum_alms_except(data_alms_regions, regions)
_cls = angular_power_spectra(alms_jk)
_cls = correct_bias(_cls, jk_maps, fields, *regions)
if mask_correction == "Full":
vis_alms_jk = _sum_alms_except(vis_alms_regions, regions)
_cls_mm = angular_power_spectra(vis_alms_jk)
_cls = correct_footprint_naturalspice(
_cls, _cls_mm, mls0, fields, unmixed=unmixed
)
elif mask_correction == "Fast":
_cls = correct_footprint_fsky(
_cls, jk_maps, fields, *regions, unmixed=unmixed
)
else:
raise ValueError("mask_correction must be 'Fast' or 'Full'")
cls[regions] = _cls
os.makedirs(dir, exist_ok=True)

all_regions = list(combinations(range(1, njk + 1), nd))
total = (njk + 1) + len(all_regions)
current = 0
progress.update(current, total)

# Compute ALMs
for k in range(0, njk + 1):
data_path = os.path.join(dir, f"data_alms_{k}.fits")
vis_path = os.path.join(dir, f"vis_alms_{k}.fits")
with progress.task(f"ALMs {k}"):
if not (os.path.exists(data_path) and os.path.exists(vis_path)):
if k == 0:
data_alms_k = transform(fields, data_maps)
vis_alms_k = transform(fields, vis_maps)
else:
data_alms_k = transform(
fields, _get_region_maps(data_maps, jk_maps, k)
)
vis_alms_k = transform(
fields, _get_region_maps(vis_maps, jk_maps, k)
)
write_alms(data_path, data_alms_k, clobber=True)
write_alms(vis_path, vis_alms_k, clobber=True)
current += 1
progress.update(current, total)

data_alms_full = read_alms(os.path.join(dir, "data_alms_0.fits"))
vis_alms_full = read_alms(os.path.join(dir, "vis_alms_0.fits"))
mls0 = angular_power_spectra(vis_alms_full)

# Compute Cls
for regions in all_regions:
regions_tag = "_".join(map(str, regions))
cls_path = os.path.join(dir, f"cls_{regions_tag}_unmixed_{unmixed}.fits")
with progress.task(f"Cls {regions}"):
if os.path.exists(cls_path):
cls[regions] = read(cls_path)
else:
alms_jk = _subtract_alms(
data_alms_full,
_accumulate_alms(
os.path.join(dir, f"data_alms_{r}.fits") for r in regions
),
)
_cls = angular_power_spectra(alms_jk)
_cls = correct_bias(_cls, jk_maps, fields, *regions)
if mask_correction == "Full":
vis_alms_jk = _subtract_alms(
vis_alms_full,
_accumulate_alms(
os.path.join(dir, f"vis_alms_{r}.fits") for r in regions
),
)
_cls_mm = angular_power_spectra(vis_alms_jk)
_cls = correct_footprint_naturalspice(
_cls, _cls_mm, mls0, fields, unmixed=unmixed
)
elif mask_correction == "Fast":
_cls = correct_footprint_fsky(
_cls, jk_maps, fields, *regions, unmixed=unmixed
)
else:
raise ValueError("mask_correction must be 'Fast' or 'Full'")
write(cls_path, _cls, clobber=True)
cls[regions] = _cls
current += 1
progress.update(current, total)
return cls


Expand Down Expand Up @@ -125,6 +180,29 @@ def _sum_alms_except(alms_regions, exclude=()):
return result


def _accumulate_alms(paths):
"""Reads ALMs from each path and returns their sum, loading one file at a time."""
result = None
for path in paths:
alms = read_alms(path)
if result is None:
result = {key: arr.copy() for key, arr in alms.items()}
else:
for key in result:
result[key] += alms[key]
return result


def _subtract_alms(full_alms, region_sum):
"""Returns full_alms minus region_sum, or a copy of full_alms if region_sum is None."""
result = {}
for key in full_alms:
result[key] = full_alms[key].copy()
if region_sum is not None:
result[key] -= region_sum[key]
return result


def bias(cls):
"""
Internal method to compute the bias.
Expand Down
7 changes: 4 additions & 3 deletions heracles/unmixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,17 @@ def _naturalspice(wd, wm, fields, theta_max=None):
xvals, _ = _cached_gauss_legendre(lmax_mask)
theta = np.arccos(xvals) * 180 / np.pi
i_theta_max = np.abs(theta - theta_max).argmin()
x0 = np.log10(abs(first_wm[i_theta_max]))
else:
x0 = -5

corr_wds = {}
for key in wd.keys():
a, b, i, j = key
m_key = (masks[a], masks[b], i, j)
_wm = get_cl(m_key, wm).array
_wd = wd[key].array
if theta_max is not None:
x0 = np.log10(abs(_wm[i_theta_max]))
_wm *= logistic(np.log10(abs(_wm)), x0=x0)
_wm *= logistic(np.log10(abs(_wm)), x0=x0)
corr_wds[key] = replace(wd[key], array=(_wd / _wm))

return corr_wds
22 changes: 18 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,31 @@ def mls0(fields, vis_maps):


@pytest.fixture(scope="session")
def cls1(fields, data_maps, vis_maps, jk_maps):
def cls1(fields, data_maps, vis_maps, jk_maps, tmp_path_factory):
from heracles.dices.jackknife import jackknife_cls

return jackknife_cls(data_maps, vis_maps, jk_maps, fields, nd=1)
return jackknife_cls(
data_maps,
vis_maps,
jk_maps,
fields,
nd=1,
dir=str(tmp_path_factory.mktemp("cls1")),
)


@pytest.fixture(scope="session")
def cls2(fields, data_maps, vis_maps, jk_maps):
def cls2(fields, data_maps, vis_maps, jk_maps, tmp_path_factory):
from heracles.dices.jackknife import jackknife_cls

return jackknife_cls(data_maps, vis_maps, jk_maps, fields, nd=2)
return jackknife_cls(
data_maps,
vis_maps,
jk_maps,
fields,
nd=2,
dir=str(tmp_path_factory.mktemp("cls2")),
)


@pytest.fixture(scope="session")
Expand Down
6 changes: 4 additions & 2 deletions tests/test_dices.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ def test_region_alm_cls(fields, data_maps, jk_maps, njk):
)


def test_cls(nside, cls0, fields, data_maps, vis_maps, jk_maps):
_cls0 = dices.jackknife_cls(data_maps, vis_maps, jk_maps, fields, nd=0)[()]
def test_cls(nside, cls0, fields, data_maps, vis_maps, jk_maps, tmp_path):
_cls0 = dices.jackknife_cls(
data_maps, vis_maps, jk_maps, fields, nd=0, dir=str(tmp_path)
)[()]
for key in list(_cls0.keys()):
_cl = _cls0[key]
*_, nells = _cl.shape
Expand Down
Loading