diff --git a/jeiss_convert/misc.py b/jeiss_convert/misc.py index fd7a073..a95445e 100644 --- a/jeiss_convert/misc.py +++ b/jeiss_convert/misc.py @@ -17,3 +17,6 @@ HEADER_LENGTH = _misc["data_offset"] MAGIC_NUMBER = _misc["magic_number"] DATE_FORMAT = _misc["date_format"] + +HEADER_KEY = "_header" +FOOTER_KEY = "_footer" diff --git a/jeiss_convert/utils.py b/jeiss_convert/utils.py index c784202..7dffa9a 100644 --- a/jeiss_convert/utils.py +++ b/jeiss_convert/utils.py @@ -6,7 +6,14 @@ import numpy as np -from .misc import DEFAULT_AXIS_ORDER, DEFAULT_BYTE_ORDER, HEADER_LENGTH, SPEC_DIR +from .misc import ( + DEFAULT_AXIS_ORDER, + DEFAULT_BYTE_ORDER, + FOOTER_KEY, + HEADER_KEY, + HEADER_LENGTH, + SPEC_DIR, +) from .version import version logger = logging.getLogger(__name__) @@ -205,29 +212,29 @@ def split_channels( if json_metadata: meta = metadata_to_jso(all_data.meta) if all_data.header is not None: - meta["_header"] = all_data.header.hex() + meta[HEADER_KEY] = all_data.header.hex() if all_data.footer is not None: - meta["_footer"] = all_data.footer.hex() + meta[FOOTER_KEY] = all_data.footer.hex() else: meta = all_data.meta if all_data.header is not None: - meta["_header"] = np.frombuffer(all_data.header, "uint8") + meta[HEADER_KEY] = np.frombuffer(all_data.header, "uint8") if all_data.footer is not None: - meta["_footer"] = np.frombuffer(all_data.footer, "uint8") + meta[FOOTER_KEY] = np.frombuffer(all_data.footer, "uint8") meta["_dat2hdf5_version"] = version return meta, channel_names, all_data.data -def get_bytes(d: dict[str, tp.Any], key: str): - val = d.get(key) - if val is None: +def into_bytes(val) -> bytes: + if isinstance(val, bytes): + return val + elif val is None: return b"" - - if isinstance(val, str): + elif isinstance(val, str): return bytes.fromhex(val) elif isinstance(val, np.ndarray): - return val.tobytes() + return val.tobytes() # todo: this is C by default, might want A else: raise ValueError( "Expected str (hex-encoded) or uint8 numpy array " @@ -235,12 +242,6 @@ def get_bytes(d: dict[str, tp.Any], key: str): ) -def md5sum(b): - md5 = hashlib.md5() - md5.update(b) - return md5.hexdigest() - - def group_to_bytes(g, json_metadata=False, check_header=True): if json_metadata: meta = metadata_to_numpy(g.attrs) @@ -249,13 +250,13 @@ def group_to_bytes(g, json_metadata=False, check_header=True): header = write_header(meta) if check_header: - stored_header = get_bytes(meta, "_header") - if stored_header and md5sum(stored_header) != md5sum(header): + stored_header = into_bytes(g.attrs.get(HEADER_KEY)) + if stored_header and stored_header != header: raise RuntimeError( f"Stored header (length {len(stored_header)}) is different to " f"calculated header (length {len(header)})" ) - footer = get_bytes(meta, "_footer") + footer = into_bytes(g.attrs.get(FOOTER_KEY)) to_stack = [] for input_id in range(1, 5): @@ -268,3 +269,17 @@ def group_to_bytes(g, json_metadata=False, check_header=True): dtype = stacked.dtype.newbyteorder(DEFAULT_BYTE_ORDER) b = np.asarray(stacked, dtype, order="F").tobytes(order="F") return header + b + footer + + +def hashsum( + stream: tp.Union[bytes, tp.BinaryIO], + hash_cls=hashlib.blake2b, + chunk_size: int = 4096, +): + hasher = hash_cls() + if isinstance(stream, bytes): + hasher.update(stream) + else: + while chunk := stream.read(chunk_size): + hasher.update(chunk) + return hasher.hexdigest() diff --git a/jeiss_convert/verify.py b/jeiss_convert/verify.py index d0e3e8e..00c3d49 100644 --- a/jeiss_convert/verify.py +++ b/jeiss_convert/verify.py @@ -6,23 +6,20 @@ """ import sys from argparse import ArgumentParser +from contextlib import contextmanager from pathlib import Path from .hdf5 import hdf5_to_bytes -from .utils import md5sum +from .utils import hashsum from .version import version -def noop(arg): - return arg - - def warn(*args, **kwargs): kwargs.setdefault("file", sys.stderr) print(*args, **kwargs) -def write_dat(fpath: Path, hdf5_path, group=None): +def write_dat(fpath: Path, dat_bytes: bytes): if fpath.exists(): warn("Exiting due to existing file at " + str(fpath)) return 2 @@ -40,23 +37,24 @@ def write_dat(fpath: Path, hdf5_path, group=None): if response.lower() in ["", "n", "no"]: warn("Not writing or validating anything") return 0 - warn( - "Interpreting response non-'yes' response " - f"'{response}' as negative, exiting" - ) + warn("Interpreting non-'yes' response " f"'{response}' as negative, exiting") return 2 - b = hdf5_to_bytes(hdf5_path, group) if str(fpath) == "-": - sys.stdout.buffer.write(b) + sys.stdout.buffer.write(dat_bytes) else: - fpath.write_bytes(b) + fpath.write_bytes(dat_bytes) + return 0 -def read_bytes(fpath: Path): + +@contextmanager +def open_bytes(fpath: Path): if str(fpath) == "-": - return sys.stdin.buffer.read() - return fpath.read_bytes() + yield sys.stdin.buffer + else: + with open(fpath, "rb") as f: + yield f def main(args=None): @@ -74,12 +72,6 @@ def main(args=None): action="store_true", help="Delete the .dat file if the check succeeds", ) - parser.add_argument( - "-s", - "--strict", - action="store_true", - help="Check for identity of bytes rather than hash (slow and unnecessary)", - ) parser.add_argument( "--write-dat", action="store_true", @@ -95,15 +87,17 @@ def main(args=None): version=version, ) parsed = parser.parse_args(args) + reconverted_bytes = hdf5_to_bytes(parsed.hdf5, parsed.group) + if parsed.write_dat: - return write_dat(parsed.dat, parsed.hdf5, parsed.group) + return write_dat(parsed.dat, reconverted_bytes) - fn = noop if parsed.strict else md5sum + with open_bytes(parsed.dat) as f: + dat_sum = hashsum(f) - dat = fn(read_bytes(parsed.dat)) - hdf5 = fn(hdf5_to_bytes(parsed.hdf5, parsed.group)) + reconverted_sum = hashsum(reconverted_bytes) - if dat != hdf5: + if dat_sum != reconverted_sum: return 1 if parsed.delete_dat: diff --git a/jeiss_convert/zr.py b/jeiss_convert/zr.py new file mode 100644 index 0000000..a735c27 --- /dev/null +++ b/jeiss_convert/zr.py @@ -0,0 +1,78 @@ +import typing as tp +from pathlib import Path + +import zarr + +from .utils import group_to_bytes, split_channels + +StoreFactory = tp.Callable[[Path], zarr.storage.BaseStore] + + +def _dat_to_zarr( + store_factory: StoreFactory, + dat_path: Path, + container_path: Path, + group_name: tp.Optional[str] = None, + ds_kwargs: tp.Optional[dict[str, tp.Any]] = None, +): + meta, channel_names, data = split_channels(dat_path, True) + + if ds_kwargs is None: + ds_kwargs = dict() + + if not group_name: + group_name = "/" + + store = store_factory(container_path) + container: zarr.Group = zarr.open(store, mode="a") + if group_name == "/": + group = container + else: + group = container.create_group(group_name) + + group.attrs.update(meta) + + for idx, ds in enumerate(channel_names): + group.create_dataset(ds, data=data[idx], **ds_kwargs) + + +def _zarr_to_bytes( + store_factory: StoreFactory, + container_path: Path, + group_name: tp.Optional[str] = None, +): + if not group_name: + group_name = "/" + + store = store_factory(container_path) + container: zarr.Group = zarr.open(store, "r") + g = container[group_name] + return group_to_bytes(g, True) + + +def dat_to_zarr( + dat_path: Path, + container_path: Path, + group_name: tp.Optional[str] = None, + ds_kwargs: tp.Optional[dict[str, tp.Any]] = None, +): + return _dat_to_zarr( + zarr.NestedDirectoryStore, dat_path, container_path, group_name, ds_kwargs + ) + + +def zarr_to_bytes(container_path: Path, group_name: tp.Optional[str] = None): + return _zarr_to_bytes(lambda x: x, container_path, group_name) + + +def dat_to_n5( + dat_path: Path, + container_path: Path, + group_name: tp.Optional[str] = None, + ds_kwargs: tp.Optional[dict[str, tp.Any]] = None, +): + return _dat_to_zarr(zarr.N5Store, dat_path, container_path, group_name, ds_kwargs) + + +def n5_to_bytes(container_path: Path, group_name: tp.Optional[str] = None): + return _zarr_to_bytes(zarr.N5Store, container_path, group_name) diff --git a/requirements.txt b/requirements.txt index 0e7f15a..d8c8d1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ numpy h5py tomli;python_version<"3.11" +zarr # dev diff --git a/setup.py b/setup.py index 52817c0..db64007 100755 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ "numpy", "h5py", "tomli; python_version < '3.11'", + "zarr", ], python_requires=">=3.9, <4.0", classifiers=[ diff --git a/tests/conftest.py b/tests/conftest.py index 81f3f7b..e80c1cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,14 @@ import csv +import shutil from pathlib import Path +from typing import Callable, NamedTuple import pooch import pytest +from jeiss_convert.hdf5 import dat_to_hdf5, hdf5_to_bytes +from jeiss_convert.zr import dat_to_n5, dat_to_zarr, n5_to_bytes, zarr_to_bytes + project_dir = Path(__file__).resolve().parent.parent spec_dir = project_dir / "jeiss_convert" / "jeiss-specs" versions = [ @@ -22,7 +27,7 @@ def sample_dats(): return out -@pytest.fixture(params=versions) +@pytest.fixture(params=versions, scope="session") def dat_path(request, sample_dats): version = request.param if version not in sample_dats: @@ -30,3 +35,41 @@ def dat_path(request, sample_dats): md5sum, url = sample_dats[version] return Path(pooch.retrieve(url=url, known_hash="md5:" + md5sum)) + + +class Mode(NamedTuple): + name: str + dat_to_container: Callable + container_to_dat_bytes: Callable + json_metadata: bool + + +@pytest.fixture(params=["hdf5", "n5", "zarr"], scope="session") +def mode(request): + if request.param == "hdf5": + return Mode("hdf5", dat_to_hdf5, hdf5_to_bytes, False) + elif request.param == "n5": + return Mode("n5", dat_to_n5, n5_to_bytes, True) + elif request.param == "zarr": + return Mode("zarr", dat_to_zarr, zarr_to_bytes, True) + else: + raise ValueError("Unknown mode name") + + +class RoundtripResult(NamedTuple): + dat_path: Path + container_path: Path + written_bytes: bytes + json_metadata: bool + + +@pytest.fixture(scope="session") +def roundtripped(dat_path, mode, tmp_path_factory): + container_path = tmp_path_factory.mktemp("containers") / f"data.{mode.name}" + mode.dat_to_container(dat_path, container_path) + written_bytes = mode.container_to_dat_bytes(container_path) + yield RoundtripResult(dat_path, container_path, written_bytes, mode.json_metadata) + try: + shutil.rmtree(container_path) + except OSError: + pass diff --git a/tests/test_main.py b/tests/test_main.py index 173abc2..0b996b2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,22 @@ +from contextlib import contextmanager from pathlib import Path -from jeiss_convert import convert, verify +import h5py +import numpy as np +import zarr + +from jeiss_convert.misc import FOOTER_KEY, HEADER_LENGTH +from jeiss_convert.utils import ( + hashsum, + into_bytes, + metadata_to_jso, + metadata_to_numpy, + parse_file, + split_channels, + write_header, +) + +from .conftest import Mode, RoundtripResult def test_importable(): @@ -9,10 +25,158 @@ def test_importable(): assert jeiss_convert.__version__ -def test_convert_verify(dat_path, tmpdir): - hdf5_path = Path(tmpdir / "data.hdf5") - conv_status = convert.main([str(dat_path), str(hdf5_path)]) - assert conv_status == 0 - assert hdf5_path.is_file() - verif_status = verify.main([str(dat_path), str(hdf5_path)]) - assert verif_status == 0 +def test_can_write(dat_path, mode, tmp_path): + container_path = tmp_path / f"data.{mode.name}" + mode.dat_to_container(dat_path, container_path) + assert container_path.exists() + + +def test_can_reconvert_written(dat_path, mode, tmp_path): + container_path = tmp_path / f"data.{mode.name}" + mode.dat_to_container(dat_path, container_path) + written_bytes = mode.container_to_dat_bytes(container_path) + assert len(written_bytes) + + +# def test_cli_convert_verify(dat_path, tmp_path): +# hdf5_path = Path(tmp_path / "data.hdf5") +# conv_status = convert.main([str(dat_path), str(hdf5_path)]) +# assert conv_status == 0 +# assert hdf5_path.is_file() +# verif_status = verify.main([str(dat_path), str(hdf5_path)]) +# assert verif_status == 0 + + +def assert_arrays_equal(test: np.ndarray, ref: np.ndarray): + # assert test.dtype == ref.dtype + assert test.shape == ref.shape + assert test.ravel()[0] == ref.ravel()[0] + assert np.allclose(test, ref) + + +def assert_channel_arrays_equal( + test: dict[str, np.ndarray], ref: dict[str, np.ndarray] +): + assert set(test) == set(ref) + for key, a_val in test.items(): + assert_arrays_equal(a_val, ref[key]) + + +def channel_dict(names, data) -> dict[str, np.ndarray]: + d = {cn: data[idx] for idx, cn in enumerate(names)} + return d + + +def test_array_roundtrip(roundtripped: RoundtripResult): + dat_path, container_path, _, _ = roundtripped + + _, channel_names, data = split_channels(dat_path) + dat_arrays = channel_dict(channel_names, data) + + with open_root(container_path) as f: + for cn in channel_names: + dat_arr = dat_arrays[cn] + written_arr = f[cn][:] + + assert_arrays_equal(written_arr, dat_arr) + # assert_bytes_equal(into_bytes(written_arr), into_bytes(dat_arr)) + + +def assert_bytes_equal(test: bytes, ref: bytes, ends=64, test_full=False): + assert len(test) == len(ref) + assert test[:ends] == ref[:ends] + assert test[ends:] == ref[ends:] + assert hashsum(test) == hashsum(ref) + if test_full: + assert test == ref + + +def assert_metadata_contains_values(test, ref): + """test must be superset of ref. + test must contain vals coercible into dtypes of ref""" + for k, ref_v in ref.items(): + if k.startswith("_"): + continue + assert k in test + test_v = test[k] + assert test_v == ref_v + + +def read_header_bytes(path): + with open(path, "rb") as f: + return f.read(HEADER_LENGTH) + + +def test_header_roundtrip(roundtripped: RoundtripResult): + dat_path, container_path, _, json_metadata = roundtripped + meta = parse_file(dat_path) + ref_jso = metadata_to_jso(meta) + + with open_root(container_path) as f: + test_jso = f.attrs + if not json_metadata: + test_jso = metadata_to_jso(test_jso) + + test_bytes = write_header(f.attrs) + + assert_metadata_contains_values(test_jso, ref_jso) + + ref_bytes = read_header_bytes(dat_path) + assert test_bytes == ref_bytes + + +def test_meta_jso_roundtrip(dat_path): + ref = parse_file(dat_path) + meta_jso = metadata_to_jso(ref) + test = metadata_to_numpy(meta_jso) + + for k, ref_v in ref.items(): + if k.startswith("_"): + continue + test_v = test[k] + if isinstance(ref_v, bytes): + assert test_v == ref_v + else: + assert np.allclose(test_v, ref_v) + + assert write_header(test) == read_header_bytes(dat_path) + + +@contextmanager +def open_root(container_path: Path): + mode = container_path.suffix + if mode.endswith("hdf5"): + with h5py.File(container_path, "r") as f: + yield f + elif mode.endswith("zarr"): + yield zarr.open(container_path, "r") + elif mode.endswith("n5"): + yield zarr.open(zarr.N5Store(container_path), "r") + else: + raise ValueError(f"Unknown mode '{mode}'") + + +def test_footer_roundtrip(roundtripped: RoundtripResult): + dat_path, container_path, written_bytes, _ = roundtripped + + stored_footer = None + + with open_root(container_path) as g: + stored_footer = into_bytes(g.attrs[FOOTER_KEY]) + + if stored_footer is None: + raise ValueError("Could not read test footer") + + dat_bytes = dat_path.read_bytes() + dat_tail = dat_bytes[-len(stored_footer) :] + assert_bytes_equal(stored_footer, dat_tail) + written_tail = written_bytes[-len(stored_footer) :] + assert_bytes_equal(written_tail, dat_tail) + + +def test_roundtrip(dat_path, tmp_path, mode: Mode): + out_path = Path(tmp_path / f"data.{mode}") + mode.dat_to_container(dat_path, out_path) + written_bytes = mode.container_to_dat_bytes(out_path) + orig_bytes = dat_path.read_bytes() + assert_bytes_equal(written_bytes, orig_bytes)