From 51714e11e8ec207f0c8df589de362bdf2d7b8d0a Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Wed, 15 Jun 2022 12:00:49 +0100 Subject: [PATCH 01/10] further hdf5 refactors --- jeiss_convert/hdf5.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jeiss_convert/hdf5.py b/jeiss_convert/hdf5.py index c703a7f..f222aaf 100644 --- a/jeiss_convert/hdf5.py +++ b/jeiss_convert/hdf5.py @@ -1,5 +1,6 @@ import typing as tp from pathlib import Path +import typing as tp import h5py From 85ccc6c95a828bc33e87a9ab2a9eab000c5fa46e Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Wed, 15 Jun 2022 12:01:13 +0100 Subject: [PATCH 02/10] Experimental (broken) zarr/ N5 support --- jeiss_convert/hdf5.py | 1 - jeiss_convert/zr.py | 78 +++++++++++++++++++++++++++++++++++++++++++ tests/test_main.py | 28 +++++++++++++++- 3 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 jeiss_convert/zr.py diff --git a/jeiss_convert/hdf5.py b/jeiss_convert/hdf5.py index f222aaf..bdbaff7 100644 --- a/jeiss_convert/hdf5.py +++ b/jeiss_convert/hdf5.py @@ -1,4 +1,3 @@ -import typing as tp from pathlib import Path import typing as tp diff --git a/jeiss_convert/zr.py b/jeiss_convert/zr.py new file mode 100644 index 0000000..a735c27 --- /dev/null +++ b/jeiss_convert/zr.py @@ -0,0 +1,78 @@ +import typing as tp +from pathlib import Path + +import zarr + +from .utils import group_to_bytes, split_channels + +StoreFactory = tp.Callable[[Path], zarr.storage.BaseStore] + + +def _dat_to_zarr( + store_factory: StoreFactory, + dat_path: Path, + container_path: Path, + group_name: tp.Optional[str] = None, + ds_kwargs: tp.Optional[dict[str, tp.Any]] = None, +): + meta, channel_names, data = split_channels(dat_path, True) + + if ds_kwargs is None: + ds_kwargs = dict() + + if not group_name: + group_name = "/" + + store = store_factory(container_path) + container: zarr.Group = zarr.open(store, mode="a") + if group_name == "/": + group = container + else: + group = container.create_group(group_name) + + group.attrs.update(meta) + + for idx, ds in enumerate(channel_names): + group.create_dataset(ds, data=data[idx], **ds_kwargs) + + +def _zarr_to_bytes( + store_factory: StoreFactory, + container_path: Path, + group_name: tp.Optional[str] = None, +): + if not group_name: + group_name = "/" + + store = store_factory(container_path) + container: zarr.Group = zarr.open(store, "r") + g = container[group_name] + return group_to_bytes(g, True) + + +def dat_to_zarr( + dat_path: Path, + container_path: Path, + group_name: tp.Optional[str] = None, + ds_kwargs: tp.Optional[dict[str, tp.Any]] = None, +): + return _dat_to_zarr( + zarr.NestedDirectoryStore, dat_path, container_path, group_name, ds_kwargs + ) + + +def zarr_to_bytes(container_path: Path, group_name: tp.Optional[str] = None): + return _zarr_to_bytes(lambda x: x, container_path, group_name) + + +def dat_to_n5( + dat_path: Path, + container_path: Path, + group_name: tp.Optional[str] = None, + ds_kwargs: tp.Optional[dict[str, tp.Any]] = None, +): + return _dat_to_zarr(zarr.N5Store, dat_path, container_path, group_name, ds_kwargs) + + +def n5_to_bytes(container_path: Path, group_name: tp.Optional[str] = None): + return _zarr_to_bytes(zarr.N5Store, container_path, group_name) diff --git a/tests/test_main.py b/tests/test_main.py index 173abc2..d9e808e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,11 @@ from pathlib import Path +import pytest + from jeiss_convert import convert, verify +from jeiss_convert.hdf5 import dat_to_hdf5, hdf5_to_bytes +from jeiss_convert.utils import md5sum +from jeiss_convert.zr import dat_to_n5, dat_to_zarr, n5_to_bytes, zarr_to_bytes def test_importable(): @@ -9,10 +14,31 @@ def test_importable(): assert jeiss_convert.__version__ -def test_convert_verify(dat_path, tmpdir): +def test_cli_convert_verify(dat_path, tmpdir): hdf5_path = Path(tmpdir / "data.hdf5") conv_status = convert.main([str(dat_path), str(hdf5_path)]) assert conv_status == 0 assert hdf5_path.is_file() verif_status = verify.main([str(dat_path), str(hdf5_path)]) assert verif_status == 0 + + +@pytest.mark.parametrize("mode", ["hdf5", "n5", "zarr"]) +def test_roundtrip(dat_path, tmpdir, mode): + if mode == "hdf5": + to_container = dat_to_hdf5 + from_container = hdf5_to_bytes + elif mode == "n5": + to_container = dat_to_n5 + from_container = n5_to_bytes + elif mode == "zarr": + to_container = dat_to_zarr + from_container = zarr_to_bytes + else: + raise ValueError(f"Unknown mode '{mode}'") + + out_path = Path(tmpdir / f"data.{mode}") + to_container(dat_path, out_path) + written_bytes = from_container(out_path) + orig_bytes = dat_path.read_bytes() + assert md5sum(orig_bytes) == md5sum(written_bytes) From 51297fc49ac579bcc4cfb1eed5fbcb4bb3a9ee73 Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Wed, 15 Jun 2022 12:15:27 +0100 Subject: [PATCH 03/10] requirements --- requirements.txt | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 0e7f15a..d8c8d1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ numpy h5py tomli;python_version<"3.11" +zarr # dev diff --git a/setup.py b/setup.py index 52817c0..db64007 100755 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ "numpy", "h5py", "tomli; python_version < '3.11'", + "zarr", ], python_requires=">=3.9, <4.0", classifiers=[ From d32d92b7b0034d91b4bac1a34bb25b4ac2d78930 Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Wed, 20 Jul 2022 17:42:08 +0100 Subject: [PATCH 04/10] fmt --- jeiss_convert/hdf5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jeiss_convert/hdf5.py b/jeiss_convert/hdf5.py index bdbaff7..c703a7f 100644 --- a/jeiss_convert/hdf5.py +++ b/jeiss_convert/hdf5.py @@ -1,5 +1,5 @@ -from pathlib import Path import typing as tp +from pathlib import Path import h5py From ed29ff5d02a4c82778af8fde0e0ccf9d9d618edc Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Wed, 20 Jul 2022 19:25:29 +0100 Subject: [PATCH 05/10] More thorough roundtrip testing --- jeiss_convert/utils.py | 2 +- tests/test_main.py | 150 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 149 insertions(+), 3 deletions(-) diff --git a/jeiss_convert/utils.py b/jeiss_convert/utils.py index c784202..a598835 100644 --- a/jeiss_convert/utils.py +++ b/jeiss_convert/utils.py @@ -99,7 +99,7 @@ def jso_to_dtype(self, value): if isinstance(value, str): value = value.encode() arr = np.asarray(value, self.dtype) - if isinstance(value, list): + if isinstance(value, list) : return arr return arr.reshape(1)[0] diff --git a/tests/test_main.py b/tests/test_main.py index d9e808e..186d13e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,10 +1,14 @@ from pathlib import Path import pytest +import h5py +import numpy as np +import zarr from jeiss_convert import convert, verify from jeiss_convert.hdf5 import dat_to_hdf5, hdf5_to_bytes -from jeiss_convert.utils import md5sum +from jeiss_convert.misc import HEADER_LENGTH +from jeiss_convert.utils import md5sum, metadata_to_jso, metadata_to_numpy, split_channels, write_header from jeiss_convert.zr import dat_to_n5, dat_to_zarr, n5_to_bytes, zarr_to_bytes @@ -23,6 +27,147 @@ def test_cli_convert_verify(dat_path, tmpdir): assert verif_status == 0 +def assert_arrays_equal(test: np.ndarray, ref: np.ndarray): + assert test.dtype == ref.dtype + assert test.shape == ref.shape + assert test.ravel()[0] == ref.ravel()[0] + assert np.allclose(test, ref) + + +def assert_channel_arrays_equal(test: dict[str, np.ndarray], ref: dict[str, np.ndarray]): + assert set(test) == set(ref) + for key, a_val in test.items(): + assert_arrays_equal(a_val, ref[key]) + + +def channel_dict(names, data) -> dict[str, np.ndarray]: + d = {cn: data[idx] for idx, cn in enumerate(names)} + return d + + +def test_array_roundtrip_hdf5(dat_path, tmpdir): + out_path = Path(tmpdir / "data.hdf5") + _, channel_names, data = split_channels(dat_path) + dat_arrays = channel_dict(channel_names, data) + + dat_to_hdf5(dat_path, out_path) + + with h5py.File(out_path) as h5f: + written_arrays: dict[str, np.ndarray] = {cn: h5f[cn][:] for cn in channel_names} + + for cn in channel_names: + assert_arrays_equal(dat_arrays[cn], written_arrays[cn]) + + +def test_array_roundtrip_zarr(dat_path, tmpdir): + out_path = Path(tmpdir / "data.zarr") + _, channel_names, data = split_channels(dat_path) + dat_arrays = channel_dict(channel_names, data) + + dat_to_zarr(dat_path, out_path) + + container = zarr.open(out_path, "r") + written_arrays: dict[str, np.ndarray] = {cn: container[cn][:] for cn in channel_names} + + for cn in channel_names: + assert_arrays_equal(dat_arrays[cn], written_arrays[cn]) + + +@pytest.mark.skip("dtype byte order mismatch") +def test_array_roundtrip_n5(dat_path, tmpdir): + out_path = Path(tmpdir / "data.zarr") + _, channel_names, data = split_channels(dat_path) + dat_arrays = channel_dict(channel_names, data) + + dat_to_n5(dat_path, out_path) + + container = zarr.open(zarr.N5Store(out_path), "r") + written_arrays: dict[str, np.ndarray] = {cn: container[cn][:] for cn in channel_names} + + for cn in channel_names: + assert_arrays_equal(dat_arrays[cn], written_arrays[cn]) + + +# todo: check array bytes +# todo: check footer bytes + + +def assert_metadata_contains_values(test, ref): + """test must be superset of ref. test must contain vals coercible into dtypes of ref""" + for k, ref_v in ref.items(): + if k.startswith("_"): + continue + assert k in test + test_v = test[k] + assert test_v == ref_v + + +def read_header_bytes(path): + with open(path, "rb") as f: + return f.read(HEADER_LENGTH) + + +def test_meta_jso_roundtrip_hdf5(dat_path, tmpdir): + out_path = Path(tmpdir / "data.zarr") + meta, _, _ = split_channels(dat_path) + ref_jso = metadata_to_jso(meta) + + dat_to_hdf5(dat_path, out_path) + + with h5py.File(out_path, "r") as container: + test_jso = metadata_to_jso(container.attrs) + test_bytes = write_header(container.attrs) + + assert_metadata_contains_values(test_jso, ref_jso) + + ref_bytes = read_header_bytes(dat_path) + assert test_bytes == ref_bytes + + +def test_meta_jso_roundtrip_zarr(dat_path, tmpdir): + out_path = Path(tmpdir / "data.zarr") + meta, _, _ = split_channels(dat_path) + ref_jso = metadata_to_jso(meta) + + dat_to_zarr(dat_path, out_path) + container = zarr.open(out_path, "r") + assert_metadata_contains_values(container.attrs, ref_jso) + test_bytes = write_header(container.attrs) + + ref_bytes = read_header_bytes(dat_path) + assert test_bytes == ref_bytes + + +def test_meta_jso_roundtrip_n5(dat_path, tmpdir): + out_path = Path(tmpdir / "data.zarr") + meta, _, _ = split_channels(dat_path) + ref_jso = metadata_to_jso(meta) + + dat_to_n5(dat_path, out_path) + container = zarr.open(zarr.N5Store(out_path), "r") + assert_metadata_contains_values(container.attrs, ref_jso) + + test_bytes = write_header(container.attrs) + + ref_bytes = read_header_bytes(dat_path) + assert test_bytes == ref_bytes + + +def test_meta_jso_roundtrip(dat_path): + ref, _, _ = split_channels(dat_path) + meta_jso = metadata_to_jso(ref) + test = metadata_to_numpy(meta_jso) + + for k, ref_v in ref.items(): + if k.startswith("_"): + continue + test_v = test[k] + if isinstance(ref_v, bytes): + assert test_v == ref_v + else: + assert np.allclose(test_v, ref_v) + + @pytest.mark.parametrize("mode", ["hdf5", "n5", "zarr"]) def test_roundtrip(dat_path, tmpdir, mode): if mode == "hdf5": @@ -41,4 +186,5 @@ def test_roundtrip(dat_path, tmpdir, mode): to_container(dat_path, out_path) written_bytes = from_container(out_path) orig_bytes = dat_path.read_bytes() - assert md5sum(orig_bytes) == md5sum(written_bytes) + assert len(written_bytes) == len(orig_bytes) + assert md5sum(written_bytes) == md5sum(orig_bytes) From f22a81538eb0d243fc2a8ba06bf73a048a4895c2 Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Thu, 21 Jul 2022 12:52:42 +0100 Subject: [PATCH 06/10] zarr, n5 roundtripping --- jeiss_convert/misc.py | 3 + jeiss_convert/utils.py | 35 ++++--- tests/conftest.py | 45 ++++++++- tests/test_main.py | 208 ++++++++++++++++++++--------------------- 4 files changed, 168 insertions(+), 123 deletions(-) diff --git a/jeiss_convert/misc.py b/jeiss_convert/misc.py index fd7a073..a95445e 100644 --- a/jeiss_convert/misc.py +++ b/jeiss_convert/misc.py @@ -17,3 +17,6 @@ HEADER_LENGTH = _misc["data_offset"] MAGIC_NUMBER = _misc["magic_number"] DATE_FORMAT = _misc["date_format"] + +HEADER_KEY = "_header" +FOOTER_KEY = "_footer" diff --git a/jeiss_convert/utils.py b/jeiss_convert/utils.py index a598835..14df51c 100644 --- a/jeiss_convert/utils.py +++ b/jeiss_convert/utils.py @@ -6,7 +6,14 @@ import numpy as np -from .misc import DEFAULT_AXIS_ORDER, DEFAULT_BYTE_ORDER, HEADER_LENGTH, SPEC_DIR +from .misc import ( + DEFAULT_AXIS_ORDER, + DEFAULT_BYTE_ORDER, + FOOTER_KEY, + HEADER_KEY, + HEADER_LENGTH, + SPEC_DIR, +) from .version import version logger = logging.getLogger(__name__) @@ -99,7 +106,7 @@ def jso_to_dtype(self, value): if isinstance(value, str): value = value.encode() arr = np.asarray(value, self.dtype) - if isinstance(value, list) : + if isinstance(value, list): return arr return arr.reshape(1)[0] @@ -205,29 +212,29 @@ def split_channels( if json_metadata: meta = metadata_to_jso(all_data.meta) if all_data.header is not None: - meta["_header"] = all_data.header.hex() + meta[HEADER_KEY] = all_data.header.hex() if all_data.footer is not None: - meta["_footer"] = all_data.footer.hex() + meta[FOOTER_KEY] = all_data.footer.hex() else: meta = all_data.meta if all_data.header is not None: - meta["_header"] = np.frombuffer(all_data.header, "uint8") + meta[HEADER_KEY] = np.frombuffer(all_data.header, "uint8") if all_data.footer is not None: - meta["_footer"] = np.frombuffer(all_data.footer, "uint8") + meta[FOOTER_KEY] = np.frombuffer(all_data.footer, "uint8") meta["_dat2hdf5_version"] = version return meta, channel_names, all_data.data -def get_bytes(d: dict[str, tp.Any], key: str): - val = d.get(key) - if val is None: +def into_bytes(val) -> bytes: + if isinstance(val, bytes): + return val + elif val is None: return b"" - - if isinstance(val, str): + elif isinstance(val, str): return bytes.fromhex(val) elif isinstance(val, np.ndarray): - return val.tobytes() + return val.tobytes() # todo: this is C by default, might want A else: raise ValueError( "Expected str (hex-encoded) or uint8 numpy array " @@ -249,13 +256,13 @@ def group_to_bytes(g, json_metadata=False, check_header=True): header = write_header(meta) if check_header: - stored_header = get_bytes(meta, "_header") + stored_header = into_bytes(g.attrs.get(HEADER_KEY)) if stored_header and md5sum(stored_header) != md5sum(header): raise RuntimeError( f"Stored header (length {len(stored_header)}) is different to " f"calculated header (length {len(header)})" ) - footer = get_bytes(meta, "_footer") + footer = into_bytes(g.attrs.get(FOOTER_KEY)) to_stack = [] for input_id in range(1, 5): diff --git a/tests/conftest.py b/tests/conftest.py index 81f3f7b..e80c1cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,14 @@ import csv +import shutil from pathlib import Path +from typing import Callable, NamedTuple import pooch import pytest +from jeiss_convert.hdf5 import dat_to_hdf5, hdf5_to_bytes +from jeiss_convert.zr import dat_to_n5, dat_to_zarr, n5_to_bytes, zarr_to_bytes + project_dir = Path(__file__).resolve().parent.parent spec_dir = project_dir / "jeiss_convert" / "jeiss-specs" versions = [ @@ -22,7 +27,7 @@ def sample_dats(): return out -@pytest.fixture(params=versions) +@pytest.fixture(params=versions, scope="session") def dat_path(request, sample_dats): version = request.param if version not in sample_dats: @@ -30,3 +35,41 @@ def dat_path(request, sample_dats): md5sum, url = sample_dats[version] return Path(pooch.retrieve(url=url, known_hash="md5:" + md5sum)) + + +class Mode(NamedTuple): + name: str + dat_to_container: Callable + container_to_dat_bytes: Callable + json_metadata: bool + + +@pytest.fixture(params=["hdf5", "n5", "zarr"], scope="session") +def mode(request): + if request.param == "hdf5": + return Mode("hdf5", dat_to_hdf5, hdf5_to_bytes, False) + elif request.param == "n5": + return Mode("n5", dat_to_n5, n5_to_bytes, True) + elif request.param == "zarr": + return Mode("zarr", dat_to_zarr, zarr_to_bytes, True) + else: + raise ValueError("Unknown mode name") + + +class RoundtripResult(NamedTuple): + dat_path: Path + container_path: Path + written_bytes: bytes + json_metadata: bool + + +@pytest.fixture(scope="session") +def roundtripped(dat_path, mode, tmp_path_factory): + container_path = tmp_path_factory.mktemp("containers") / f"data.{mode.name}" + mode.dat_to_container(dat_path, container_path) + written_bytes = mode.container_to_dat_bytes(container_path) + yield RoundtripResult(dat_path, container_path, written_bytes, mode.json_metadata) + try: + shutil.rmtree(container_path) + except OSError: + pass diff --git a/tests/test_main.py b/tests/test_main.py index 186d13e..74c5d61 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,15 +1,22 @@ +from contextlib import contextmanager from pathlib import Path -import pytest import h5py import numpy as np import zarr -from jeiss_convert import convert, verify -from jeiss_convert.hdf5 import dat_to_hdf5, hdf5_to_bytes -from jeiss_convert.misc import HEADER_LENGTH -from jeiss_convert.utils import md5sum, metadata_to_jso, metadata_to_numpy, split_channels, write_header -from jeiss_convert.zr import dat_to_n5, dat_to_zarr, n5_to_bytes, zarr_to_bytes +from jeiss_convert.misc import FOOTER_KEY, HEADER_LENGTH +from jeiss_convert.utils import ( + into_bytes, + md5sum, + metadata_to_jso, + metadata_to_numpy, + parse_file, + split_channels, + write_header, +) + +from .conftest import Mode, RoundtripResult def test_importable(): @@ -18,23 +25,38 @@ def test_importable(): assert jeiss_convert.__version__ -def test_cli_convert_verify(dat_path, tmpdir): - hdf5_path = Path(tmpdir / "data.hdf5") - conv_status = convert.main([str(dat_path), str(hdf5_path)]) - assert conv_status == 0 - assert hdf5_path.is_file() - verif_status = verify.main([str(dat_path), str(hdf5_path)]) - assert verif_status == 0 +def test_can_write(dat_path, mode, tmp_path): + container_path = tmp_path / f"data.{mode.name}" + mode.dat_to_container(dat_path, container_path) + assert container_path.exists() + + +def test_can_reconvert_written(dat_path, mode, tmp_path): + container_path = tmp_path / f"data.{mode.name}" + mode.dat_to_container(dat_path, container_path) + written_bytes = mode.container_to_dat_bytes(container_path) + assert len(written_bytes) + + +# def test_cli_convert_verify(dat_path, tmp_path): +# hdf5_path = Path(tmp_path / "data.hdf5") +# conv_status = convert.main([str(dat_path), str(hdf5_path)]) +# assert conv_status == 0 +# assert hdf5_path.is_file() +# verif_status = verify.main([str(dat_path), str(hdf5_path)]) +# assert verif_status == 0 def assert_arrays_equal(test: np.ndarray, ref: np.ndarray): - assert test.dtype == ref.dtype + # assert test.dtype == ref.dtype assert test.shape == ref.shape assert test.ravel()[0] == ref.ravel()[0] assert np.allclose(test, ref) -def assert_channel_arrays_equal(test: dict[str, np.ndarray], ref: dict[str, np.ndarray]): +def assert_channel_arrays_equal( + test: dict[str, np.ndarray], ref: dict[str, np.ndarray] +): assert set(test) == set(ref) for key, a_val in test.items(): assert_arrays_equal(a_val, ref[key]) @@ -45,55 +67,33 @@ def channel_dict(names, data) -> dict[str, np.ndarray]: return d -def test_array_roundtrip_hdf5(dat_path, tmpdir): - out_path = Path(tmpdir / "data.hdf5") - _, channel_names, data = split_channels(dat_path) - dat_arrays = channel_dict(channel_names, data) - - dat_to_hdf5(dat_path, out_path) - - with h5py.File(out_path) as h5f: - written_arrays: dict[str, np.ndarray] = {cn: h5f[cn][:] for cn in channel_names} - - for cn in channel_names: - assert_arrays_equal(dat_arrays[cn], written_arrays[cn]) - +def test_array_roundtrip(roundtripped: RoundtripResult): + dat_path, container_path, _, _ = roundtripped -def test_array_roundtrip_zarr(dat_path, tmpdir): - out_path = Path(tmpdir / "data.zarr") _, channel_names, data = split_channels(dat_path) dat_arrays = channel_dict(channel_names, data) - dat_to_zarr(dat_path, out_path) + with open_root(container_path) as f: + for cn in channel_names: + dat_arr = dat_arrays[cn] + written_arr = f[cn][:] - container = zarr.open(out_path, "r") - written_arrays: dict[str, np.ndarray] = {cn: container[cn][:] for cn in channel_names} + assert_arrays_equal(written_arr, dat_arr) + # assert_bytes_equal(into_bytes(written_arr), into_bytes(dat_arr)) - for cn in channel_names: - assert_arrays_equal(dat_arrays[cn], written_arrays[cn]) - -@pytest.mark.skip("dtype byte order mismatch") -def test_array_roundtrip_n5(dat_path, tmpdir): - out_path = Path(tmpdir / "data.zarr") - _, channel_names, data = split_channels(dat_path) - dat_arrays = channel_dict(channel_names, data) - - dat_to_n5(dat_path, out_path) - - container = zarr.open(zarr.N5Store(out_path), "r") - written_arrays: dict[str, np.ndarray] = {cn: container[cn][:] for cn in channel_names} - - for cn in channel_names: - assert_arrays_equal(dat_arrays[cn], written_arrays[cn]) - - -# todo: check array bytes -# todo: check footer bytes +def assert_bytes_equal(test: bytes, ref: bytes, ends=64, test_full=False): + assert len(test) == len(ref) + assert test[:ends] == ref[:ends] + assert test[ends:] == ref[ends:] + assert md5sum(test) == md5sum(ref) + if test_full: + assert test == ref def assert_metadata_contains_values(test, ref): - """test must be superset of ref. test must contain vals coercible into dtypes of ref""" + """test must be superset of ref. + test must contain vals coercible into dtypes of ref""" for k, ref_v in ref.items(): if k.startswith("_"): continue @@ -107,54 +107,26 @@ def read_header_bytes(path): return f.read(HEADER_LENGTH) -def test_meta_jso_roundtrip_hdf5(dat_path, tmpdir): - out_path = Path(tmpdir / "data.zarr") - meta, _, _ = split_channels(dat_path) - ref_jso = metadata_to_jso(meta) - - dat_to_hdf5(dat_path, out_path) - - with h5py.File(out_path, "r") as container: - test_jso = metadata_to_jso(container.attrs) - test_bytes = write_header(container.attrs) - - assert_metadata_contains_values(test_jso, ref_jso) - - ref_bytes = read_header_bytes(dat_path) - assert test_bytes == ref_bytes - - -def test_meta_jso_roundtrip_zarr(dat_path, tmpdir): - out_path = Path(tmpdir / "data.zarr") - meta, _, _ = split_channels(dat_path) +def test_header_roundtrip(roundtripped: RoundtripResult): + dat_path, container_path, _, json_metadata = roundtripped + meta = parse_file(dat_path) ref_jso = metadata_to_jso(meta) - dat_to_zarr(dat_path, out_path) - container = zarr.open(out_path, "r") - assert_metadata_contains_values(container.attrs, ref_jso) - test_bytes = write_header(container.attrs) + with open_root(container_path) as f: + test_jso = f.attrs + if not json_metadata: + test_jso = metadata_to_jso(test_jso) - ref_bytes = read_header_bytes(dat_path) - assert test_bytes == ref_bytes + test_bytes = write_header(f.attrs) - -def test_meta_jso_roundtrip_n5(dat_path, tmpdir): - out_path = Path(tmpdir / "data.zarr") - meta, _, _ = split_channels(dat_path) - ref_jso = metadata_to_jso(meta) - - dat_to_n5(dat_path, out_path) - container = zarr.open(zarr.N5Store(out_path), "r") - assert_metadata_contains_values(container.attrs, ref_jso) - - test_bytes = write_header(container.attrs) + assert_metadata_contains_values(test_jso, ref_jso) ref_bytes = read_header_bytes(dat_path) assert test_bytes == ref_bytes def test_meta_jso_roundtrip(dat_path): - ref, _, _ = split_channels(dat_path) + ref = parse_file(dat_path) meta_jso = metadata_to_jso(ref) test = metadata_to_numpy(meta_jso) @@ -167,24 +139,44 @@ def test_meta_jso_roundtrip(dat_path): else: assert np.allclose(test_v, ref_v) + assert write_header(test) == read_header_bytes(dat_path) + -@pytest.mark.parametrize("mode", ["hdf5", "n5", "zarr"]) -def test_roundtrip(dat_path, tmpdir, mode): - if mode == "hdf5": - to_container = dat_to_hdf5 - from_container = hdf5_to_bytes - elif mode == "n5": - to_container = dat_to_n5 - from_container = n5_to_bytes - elif mode == "zarr": - to_container = dat_to_zarr - from_container = zarr_to_bytes +@contextmanager +def open_root(container_path: Path): + mode = container_path.suffix + if mode.endswith("hdf5"): + with h5py.File(container_path, "r") as f: + yield f + elif mode.endswith("zarr"): + yield zarr.open(container_path, "r") + elif mode.endswith("n5"): + yield zarr.open(zarr.N5Store(container_path), "r") else: raise ValueError(f"Unknown mode '{mode}'") - out_path = Path(tmpdir / f"data.{mode}") - to_container(dat_path, out_path) - written_bytes = from_container(out_path) + +def test_footer_roundtrip(roundtripped: RoundtripResult): + dat_path, container_path, written_bytes, _ = roundtripped + + stored_footer = None + + with open_root(container_path) as g: + stored_footer = into_bytes(g.attrs[FOOTER_KEY]) + + if stored_footer is None: + raise ValueError("Could not read test footer") + + dat_bytes = dat_path.read_bytes() + dat_tail = dat_bytes[-len(stored_footer) :] + assert_bytes_equal(stored_footer, dat_tail) + written_tail = written_bytes[-len(stored_footer) :] + assert_bytes_equal(written_tail, dat_tail) + + +def test_roundtrip(dat_path, tmp_path, mode: Mode): + out_path = Path(tmp_path / f"data.{mode}") + mode.dat_to_container(dat_path, out_path) + written_bytes = mode.container_to_dat_bytes(out_path) orig_bytes = dat_path.read_bytes() - assert len(written_bytes) == len(orig_bytes) - assert md5sum(written_bytes) == md5sum(orig_bytes) + assert_bytes_equal(written_bytes, orig_bytes) From bf98676d08d1dc623bb8e6584aa35608a8787ad2 Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Thu, 21 Jul 2022 13:24:10 +0100 Subject: [PATCH 07/10] remove session-scoped roundtrip fixture --- tests/conftest.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e80c1cd..2a08fff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,7 @@ class Mode(NamedTuple): json_metadata: bool -@pytest.fixture(params=["hdf5", "n5", "zarr"], scope="session") +@pytest.fixture(params=["hdf5", "n5", "zarr"]) def mode(request): if request.param == "hdf5": return Mode("hdf5", dat_to_hdf5, hdf5_to_bytes, False) @@ -63,13 +63,9 @@ class RoundtripResult(NamedTuple): json_metadata: bool -@pytest.fixture(scope="session") -def roundtripped(dat_path, mode, tmp_path_factory): - container_path = tmp_path_factory.mktemp("containers") / f"data.{mode.name}" +@pytest.fixture +def roundtripped(dat_path, mode, tmp_path): + container_path = tmp_path / f"data.{mode.name}" mode.dat_to_container(dat_path, container_path) written_bytes = mode.container_to_dat_bytes(container_path) - yield RoundtripResult(dat_path, container_path, written_bytes, mode.json_metadata) - try: - shutil.rmtree(container_path) - except OSError: - pass + return RoundtripResult(dat_path, container_path, written_bytes, mode.json_metadata) From 8f5dbb0efbe1adee2bb5262be290fca29c601d73 Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Thu, 21 Jul 2022 14:31:15 +0100 Subject: [PATCH 08/10] Revert "remove session-scoped roundtrip fixture" This reverts commit bf98676d08d1dc623bb8e6584aa35608a8787ad2. --- tests/conftest.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2a08fff..e80c1cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,7 @@ class Mode(NamedTuple): json_metadata: bool -@pytest.fixture(params=["hdf5", "n5", "zarr"]) +@pytest.fixture(params=["hdf5", "n5", "zarr"], scope="session") def mode(request): if request.param == "hdf5": return Mode("hdf5", dat_to_hdf5, hdf5_to_bytes, False) @@ -63,9 +63,13 @@ class RoundtripResult(NamedTuple): json_metadata: bool -@pytest.fixture -def roundtripped(dat_path, mode, tmp_path): - container_path = tmp_path / f"data.{mode.name}" +@pytest.fixture(scope="session") +def roundtripped(dat_path, mode, tmp_path_factory): + container_path = tmp_path_factory.mktemp("containers") / f"data.{mode.name}" mode.dat_to_container(dat_path, container_path) written_bytes = mode.container_to_dat_bytes(container_path) - return RoundtripResult(dat_path, container_path, written_bytes, mode.json_metadata) + yield RoundtripResult(dat_path, container_path, written_bytes, mode.json_metadata) + try: + shutil.rmtree(container_path) + except OSError: + pass From a31ad1d8a635dc403e448b11f2af30de830d2e6c Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Thu, 21 Jul 2022 15:54:54 +0100 Subject: [PATCH 09/10] Improve hashing --- jeiss_convert/utils.py | 22 +++++++++++++------ jeiss_convert/verify.py | 48 ++++++++++++++++++----------------------- tests/test_main.py | 5 +++-- 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/jeiss_convert/utils.py b/jeiss_convert/utils.py index 14df51c..7dffa9a 100644 --- a/jeiss_convert/utils.py +++ b/jeiss_convert/utils.py @@ -242,12 +242,6 @@ def into_bytes(val) -> bytes: ) -def md5sum(b): - md5 = hashlib.md5() - md5.update(b) - return md5.hexdigest() - - def group_to_bytes(g, json_metadata=False, check_header=True): if json_metadata: meta = metadata_to_numpy(g.attrs) @@ -257,7 +251,7 @@ def group_to_bytes(g, json_metadata=False, check_header=True): header = write_header(meta) if check_header: stored_header = into_bytes(g.attrs.get(HEADER_KEY)) - if stored_header and md5sum(stored_header) != md5sum(header): + if stored_header and stored_header != header: raise RuntimeError( f"Stored header (length {len(stored_header)}) is different to " f"calculated header (length {len(header)})" @@ -275,3 +269,17 @@ def group_to_bytes(g, json_metadata=False, check_header=True): dtype = stacked.dtype.newbyteorder(DEFAULT_BYTE_ORDER) b = np.asarray(stacked, dtype, order="F").tobytes(order="F") return header + b + footer + + +def hashsum( + stream: tp.Union[bytes, tp.BinaryIO], + hash_cls=hashlib.blake2b, + chunk_size: int = 4096, +): + hasher = hash_cls() + if isinstance(stream, bytes): + hasher.update(stream) + else: + while chunk := stream.read(chunk_size): + hasher.update(chunk) + return hasher.hexdigest() diff --git a/jeiss_convert/verify.py b/jeiss_convert/verify.py index d0e3e8e..ef15369 100644 --- a/jeiss_convert/verify.py +++ b/jeiss_convert/verify.py @@ -4,17 +4,14 @@ to an existing Jeiss FIBSEM .dat file, so that the .dat can be safely deleted. """ +from contextlib import contextmanager import sys from argparse import ArgumentParser from pathlib import Path from .hdf5 import hdf5_to_bytes -from .utils import md5sum from .version import version - - -def noop(arg): - return arg +from .utils import hashsum def warn(*args, **kwargs): @@ -22,7 +19,7 @@ def warn(*args, **kwargs): print(*args, **kwargs) -def write_dat(fpath: Path, hdf5_path, group=None): +def write_dat(fpath: Path, dat_bytes: bytes): if fpath.exists(): warn("Exiting due to existing file at " + str(fpath)) return 2 @@ -40,23 +37,24 @@ def write_dat(fpath: Path, hdf5_path, group=None): if response.lower() in ["", "n", "no"]: warn("Not writing or validating anything") return 0 - warn( - "Interpreting response non-'yes' response " - f"'{response}' as negative, exiting" - ) + warn("Interpreting non-'yes' response " f"'{response}' as negative, exiting") return 2 - b = hdf5_to_bytes(hdf5_path, group) if str(fpath) == "-": - sys.stdout.buffer.write(b) + sys.stdout.buffer.write(dat_bytes) else: - fpath.write_bytes(b) + fpath.write_bytes(dat_bytes) + return 0 -def read_bytes(fpath: Path): + +@contextmanager +def open_bytes(fpath: Path): if str(fpath) == "-": - return sys.stdin.buffer.read() - return fpath.read_bytes() + yield sys.stdin.buffer + else: + with open(fpath, "rb") as f: + yield f def main(args=None): @@ -74,12 +72,6 @@ def main(args=None): action="store_true", help="Delete the .dat file if the check succeeds", ) - parser.add_argument( - "-s", - "--strict", - action="store_true", - help="Check for identity of bytes rather than hash (slow and unnecessary)", - ) parser.add_argument( "--write-dat", action="store_true", @@ -95,15 +87,17 @@ def main(args=None): version=version, ) parsed = parser.parse_args(args) + reconverted_bytes = hdf5_to_bytes(parsed.hdf5, parsed.group) + if parsed.write_dat: - return write_dat(parsed.dat, parsed.hdf5, parsed.group) + return write_dat(parsed.dat, reconverted_bytes) - fn = noop if parsed.strict else md5sum + with open_bytes(parsed.dat) as f: + dat_sum = hashsum(f) - dat = fn(read_bytes(parsed.dat)) - hdf5 = fn(hdf5_to_bytes(parsed.hdf5, parsed.group)) + reconverted_sum = hashsum(reconverted_bytes) - if dat != hdf5: + if dat_sum != reconverted_sum: return 1 if parsed.delete_dat: diff --git a/tests/test_main.py b/tests/test_main.py index 74c5d61..5e25a11 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,5 +1,6 @@ from contextlib import contextmanager from pathlib import Path +import hashlib import h5py import numpy as np @@ -8,12 +9,12 @@ from jeiss_convert.misc import FOOTER_KEY, HEADER_LENGTH from jeiss_convert.utils import ( into_bytes, - md5sum, metadata_to_jso, metadata_to_numpy, parse_file, split_channels, write_header, + hashsum, ) from .conftest import Mode, RoundtripResult @@ -86,7 +87,7 @@ def assert_bytes_equal(test: bytes, ref: bytes, ends=64, test_full=False): assert len(test) == len(ref) assert test[:ends] == ref[:ends] assert test[ends:] == ref[ends:] - assert md5sum(test) == md5sum(ref) + assert hashsum(test) == hashsum(ref) if test_full: assert test == ref From d314daab9e34549b8616ecb9688a2f5cb2f60e89 Mon Sep 17 00:00:00 2001 From: Chris Barnes Date: Thu, 21 Jul 2022 16:19:09 +0100 Subject: [PATCH 10/10] fmt --- jeiss_convert/verify.py | 4 ++-- tests/test_main.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/jeiss_convert/verify.py b/jeiss_convert/verify.py index ef15369..00c3d49 100644 --- a/jeiss_convert/verify.py +++ b/jeiss_convert/verify.py @@ -4,14 +4,14 @@ to an existing Jeiss FIBSEM .dat file, so that the .dat can be safely deleted. """ -from contextlib import contextmanager import sys from argparse import ArgumentParser +from contextlib import contextmanager from pathlib import Path from .hdf5 import hdf5_to_bytes -from .version import version from .utils import hashsum +from .version import version def warn(*args, **kwargs): diff --git a/tests/test_main.py b/tests/test_main.py index 5e25a11..0b996b2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,5 @@ from contextlib import contextmanager from pathlib import Path -import hashlib import h5py import numpy as np @@ -8,13 +7,13 @@ from jeiss_convert.misc import FOOTER_KEY, HEADER_LENGTH from jeiss_convert.utils import ( + hashsum, into_bytes, metadata_to_jso, metadata_to_numpy, parse_file, split_channels, write_header, - hashsum, ) from .conftest import Mode, RoundtripResult