Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions jeiss_convert/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@
HEADER_LENGTH = _misc["data_offset"]
MAGIC_NUMBER = _misc["magic_number"]
DATE_FORMAT = _misc["date_format"]

# Metadata keys under which the raw header / footer bytes of a .dat file are
# stored alongside the parsed metadata (hex string or uint8 array).
HEADER_KEY = "_header"
FOOTER_KEY = "_footer"
55 changes: 35 additions & 20 deletions jeiss_convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@

import numpy as np

from .misc import DEFAULT_AXIS_ORDER, DEFAULT_BYTE_ORDER, HEADER_LENGTH, SPEC_DIR
from .misc import (
DEFAULT_AXIS_ORDER,
DEFAULT_BYTE_ORDER,
FOOTER_KEY,
HEADER_KEY,
HEADER_LENGTH,
SPEC_DIR,
)
from .version import version

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -205,42 +212,36 @@ def split_channels(
if json_metadata:
meta = metadata_to_jso(all_data.meta)
if all_data.header is not None:
meta["_header"] = all_data.header.hex()
meta[HEADER_KEY] = all_data.header.hex()
if all_data.footer is not None:
meta["_footer"] = all_data.footer.hex()
meta[FOOTER_KEY] = all_data.footer.hex()
else:
meta = all_data.meta
if all_data.header is not None:
meta["_header"] = np.frombuffer(all_data.header, "uint8")
meta[HEADER_KEY] = np.frombuffer(all_data.header, "uint8")
if all_data.footer is not None:
meta["_footer"] = np.frombuffer(all_data.footer, "uint8")
meta[FOOTER_KEY] = np.frombuffer(all_data.footer, "uint8")

meta["_dat2hdf5_version"] = version
return meta, channel_names, all_data.data


def get_bytes(d: dict[str, tp.Any], key: str):
    """Look up ``key`` in ``d`` and decode the stored value into bytes.

    A missing key (or a stored ``None``) yields ``b""``; a ``str`` is decoded
    as hexadecimal; a numpy array is serialised with ``ndarray.tobytes()``.

    Raises:
        ValueError: if the stored value has any other type.
    """
    raw = d.get(key)
    if raw is None:
        return b""
    if isinstance(raw, str):
        return bytes.fromhex(raw)
    if isinstance(raw, np.ndarray):
        return raw.tobytes()
    raise ValueError(
        "Expected str (hex-encoded) or uint8 numpy array "
        f"to convert into bytes, got {type(raw)}"
    )


def into_bytes(val) -> bytes:
    """Coerce ``val`` into a ``bytes`` object.

    ``None`` becomes ``b""``, ``bytes`` is returned unchanged, a ``str`` is
    decoded as hexadecimal, and a numpy array is serialised with
    ``ndarray.tobytes()``.

    Raises:
        ValueError: if ``val`` has any other type.
    """
    if val is None:
        return b""
    if isinstance(val, bytes):
        return val
    if isinstance(val, str):
        return bytes.fromhex(val)
    if isinstance(val, np.ndarray):
        return val.tobytes()  # todo: this is C by default, might want A
    raise ValueError(
        "Expected str (hex-encoded) or uint8 numpy array "
        f"to convert into bytes, got {type(val)}"
    )


def md5sum(b):
    """Return the hexadecimal MD5 digest of the bytes-like object ``b``."""
    return hashlib.md5(b).hexdigest()


def group_to_bytes(g, json_metadata=False, check_header=True):
if json_metadata:
meta = metadata_to_numpy(g.attrs)
Expand All @@ -249,13 +250,13 @@ def group_to_bytes(g, json_metadata=False, check_header=True):

header = write_header(meta)
if check_header:
stored_header = get_bytes(meta, "_header")
if stored_header and md5sum(stored_header) != md5sum(header):
stored_header = into_bytes(g.attrs.get(HEADER_KEY))
if stored_header and stored_header != header:
raise RuntimeError(
f"Stored header (length {len(stored_header)}) is different to "
f"calculated header (length {len(header)})"
)
footer = get_bytes(meta, "_footer")
footer = into_bytes(g.attrs.get(FOOTER_KEY))

to_stack = []
for input_id in range(1, 5):
Expand All @@ -268,3 +269,17 @@ def group_to_bytes(g, json_metadata=False, check_header=True):
dtype = stacked.dtype.newbyteorder(DEFAULT_BYTE_ORDER)
b = np.asarray(stacked, dtype, order="F").tobytes(order="F")
return header + b + footer


def hashsum(
    stream: tp.Union[bytes, bytearray, memoryview, tp.BinaryIO],
    hash_cls=hashlib.blake2b,
    chunk_size: int = 4096,
):
    """Return the hexadecimal digest of ``stream``.

    Args:
        stream: Either an in-memory bytes-like object (``bytes``,
            ``bytearray`` or ``memoryview``), which is hashed in one go, or a
            binary file-like object, which is consumed via ``read`` in chunks
            so the whole payload never has to be held in memory.
        hash_cls: Zero-argument callable returning a hashlib-style hasher
            (an object with ``update`` and ``hexdigest``).
        chunk_size: Number of bytes requested per ``read`` call when
            ``stream`` is file-like.

    Returns:
        The digest as a hexadecimal string.
    """
    hasher = hash_cls()
    # Any in-memory buffer can be fed to the hasher directly; previously only
    # ``bytes`` was recognised and other buffers failed on ``.read``.
    if isinstance(stream, (bytes, bytearray, memoryview)):
        hasher.update(stream)
    else:
        # An empty read signals end-of-stream.
        while chunk := stream.read(chunk_size):
            hasher.update(chunk)
    return hasher.hexdigest()
48 changes: 21 additions & 27 deletions jeiss_convert/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,20 @@
"""
import sys
from argparse import ArgumentParser
from contextlib import contextmanager
from pathlib import Path

from .hdf5 import hdf5_to_bytes
from .utils import md5sum
from .utils import hashsum
from .version import version


def noop(arg):
    """Identity function: hand back ``arg`` untouched."""
    unchanged = arg
    return unchanged


def warn(*args, **kwargs):
    """Behave like the built-in ``print``, defaulting ``file`` to stderr."""
    if "file" not in kwargs:
        kwargs["file"] = sys.stderr
    print(*args, **kwargs)


def write_dat(fpath: Path, hdf5_path, group=None):
def write_dat(fpath: Path, dat_bytes: bytes):
if fpath.exists():
warn("Exiting due to existing file at " + str(fpath))
return 2
Expand All @@ -40,23 +37,24 @@ def write_dat(fpath: Path, hdf5_path, group=None):
if response.lower() in ["", "n", "no"]:
warn("Not writing or validating anything")
return 0
warn(
"Interpreting response non-'yes' response "
f"'{response}' as negative, exiting"
)
warn("Interpreting non-'yes' response " f"'{response}' as negative, exiting")
return 2

b = hdf5_to_bytes(hdf5_path, group)
if str(fpath) == "-":
sys.stdout.buffer.write(b)
sys.stdout.buffer.write(dat_bytes)
else:
fpath.write_bytes(b)
fpath.write_bytes(dat_bytes)

return 0

def read_bytes(fpath: Path):
    """Return the whole content of ``fpath``, or of stdin when it is ``-``."""
    from_stdin = str(fpath) == "-"
    return sys.stdin.buffer.read() if from_stdin else fpath.read_bytes()


@contextmanager
def open_bytes(fpath: Path):
    """Context manager yielding a binary stream for ``fpath``.

    A path of ``-`` yields ``sys.stdin.buffer`` (which is left open on exit);
    any other path is opened in ``"rb"`` mode and closed again on exit.
    """
    if str(fpath) != "-":
        with open(fpath, "rb") as handle:
            yield handle
        return
    yield sys.stdin.buffer


def main(args=None):
Expand All @@ -74,12 +72,6 @@ def main(args=None):
action="store_true",
help="Delete the .dat file if the check succeeds",
)
parser.add_argument(
"-s",
"--strict",
action="store_true",
help="Check for identity of bytes rather than hash (slow and unnecessary)",
)
parser.add_argument(
"--write-dat",
action="store_true",
Expand All @@ -95,15 +87,17 @@ def main(args=None):
version=version,
)
parsed = parser.parse_args(args)
reconverted_bytes = hdf5_to_bytes(parsed.hdf5, parsed.group)

if parsed.write_dat:
return write_dat(parsed.dat, parsed.hdf5, parsed.group)
return write_dat(parsed.dat, reconverted_bytes)

fn = noop if parsed.strict else md5sum
with open_bytes(parsed.dat) as f:
dat_sum = hashsum(f)

dat = fn(read_bytes(parsed.dat))
hdf5 = fn(hdf5_to_bytes(parsed.hdf5, parsed.group))
reconverted_sum = hashsum(reconverted_bytes)

if dat != hdf5:
if dat_sum != reconverted_sum:
return 1

if parsed.delete_dat:
Expand Down
78 changes: 78 additions & 0 deletions jeiss_convert/zr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import typing as tp
from pathlib import Path

import zarr

from .utils import group_to_bytes, split_channels

# Factory called with the container path; its result is handed to
# ``zarr.open`` as the store.
StoreFactory = tp.Callable[[Path], zarr.storage.BaseStore]


def _dat_to_zarr(
    store_factory: StoreFactory,
    dat_path: Path,
    container_path: Path,
    group_name: tp.Optional[str] = None,
    ds_kwargs: tp.Optional[dict[str, tp.Any]] = None,
):
    """Write the channels and metadata of a .dat file into a zarr container.

    Args:
        store_factory: Called with ``container_path`` to build the store that
            is passed to ``zarr.open`` (opened in append mode).
        dat_path: Source .dat file, read via ``split_channels``.
        container_path: Location of the container to create or extend.
        group_name: Group to create inside the container; a falsy value
            means the container root.
        ds_kwargs: Extra keyword arguments forwarded to every
            ``create_dataset`` call.
    """
    # NOTE(review): the positional ``True`` is presumably ``json_metadata``.
    meta, channel_names, data = split_channels(dat_path, True)

    options = {} if ds_kwargs is None else ds_kwargs
    target_name = group_name or "/"

    root: zarr.Group = zarr.open(store_factory(container_path), mode="a")
    target = root if target_name == "/" else root.create_group(target_name)

    target.attrs.update(meta)

    # One dataset per channel, taken from the matching index of ``data``.
    for idx, name in enumerate(channel_names):
        target.create_dataset(name, data=data[idx], **options)


def _zarr_to_bytes(
    store_factory: StoreFactory,
    container_path: Path,
    group_name: tp.Optional[str] = None,
):
    """Rebuild .dat bytes from a group stored in a zarr container.

    Args:
        store_factory: Called with ``container_path`` to build the store that
            is passed to ``zarr.open`` (opened read-only).
        container_path: Location of the container to read.
        group_name: Group inside the container; a falsy value means the root.

    Returns:
        Whatever ``group_to_bytes`` produces for the group, called with
        ``json_metadata=True``.
    """
    target_name = group_name or "/"
    root: zarr.Group = zarr.open(store_factory(container_path), "r")
    return group_to_bytes(root[target_name], True)


def dat_to_zarr(
    dat_path: Path,
    container_path: Path,
    group_name: tp.Optional[str] = None,
    ds_kwargs: tp.Optional[dict[str, tp.Any]] = None,
):
    """Convert a .dat file into a zarr container on a ``NestedDirectoryStore``.

    All arguments are forwarded unchanged to ``_dat_to_zarr``.
    """
    return _dat_to_zarr(
        store_factory=zarr.NestedDirectoryStore,
        dat_path=dat_path,
        container_path=container_path,
        group_name=group_name,
        ds_kwargs=ds_kwargs,
    )


def zarr_to_bytes(container_path: Path, group_name: tp.Optional[str] = None):
    """Rebuild .dat bytes from a zarr container.

    The container path itself is handed to ``zarr.open`` (no explicit store
    object is constructed).
    """

    def _passthrough(path):
        # Let ``zarr.open`` resolve the path on its own.
        return path

    return _zarr_to_bytes(_passthrough, container_path, group_name)


def dat_to_n5(
    dat_path: Path,
    container_path: Path,
    group_name: tp.Optional[str] = None,
    ds_kwargs: tp.Optional[dict[str, tp.Any]] = None,
):
    """Convert a .dat file into an N5 container (backed by ``zarr.N5Store``).

    All arguments are forwarded unchanged to ``_dat_to_zarr``.
    """
    return _dat_to_zarr(
        store_factory=zarr.N5Store,
        dat_path=dat_path,
        container_path=container_path,
        group_name=group_name,
        ds_kwargs=ds_kwargs,
    )


def n5_to_bytes(container_path: Path, group_name: tp.Optional[str] = None):
    """Rebuild .dat bytes from an N5 container opened through ``zarr.N5Store``."""
    return _zarr_to_bytes(
        store_factory=zarr.N5Store,
        container_path=container_path,
        group_name=group_name,
    )
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
numpy
h5py
tomli;python_version<"3.11"
zarr

# dev

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"numpy",
"h5py",
"tomli; python_version < '3.11'",
"zarr",
],
python_requires=">=3.9, <4.0",
classifiers=[
Expand Down
45 changes: 44 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import csv
import shutil
from pathlib import Path
from typing import Callable, NamedTuple

import pooch
import pytest

from jeiss_convert.hdf5 import dat_to_hdf5, hdf5_to_bytes
from jeiss_convert.zr import dat_to_n5, dat_to_zarr, n5_to_bytes, zarr_to_bytes

project_dir = Path(__file__).resolve().parent.parent
spec_dir = project_dir / "jeiss_convert" / "jeiss-specs"
versions = [
Expand All @@ -22,11 +27,49 @@ def sample_dats():
return out


@pytest.fixture(params=versions)
@pytest.fixture(params=versions, scope="session")
def dat_path(request, sample_dats):
    """Return the local path of the sample .dat for the parametrized version.

    The file is fetched with ``pooch.retrieve`` and checked against its known
    MD5 hash. The test is skipped when ``sample_dats`` has no entry for the
    requested version.
    """
    requested = request.param
    if requested not in sample_dats:
        pytest.skip(f"No sample file for version {requested}")

    expected_md5, url = sample_dats[requested]
    local_file = pooch.retrieve(url=url, known_hash="md5:" + expected_md5)
    return Path(local_file)


class Mode(NamedTuple):
    """One container format exercised by the round-trip tests."""

    # Short identifier of the format; also used as the container file suffix.
    name: str
    # Called as ``dat_to_container(dat_path, container_path)``.
    dat_to_container: Callable
    # Called as ``container_to_dat_bytes(container_path)``.
    container_to_dat_bytes: Callable
    # Whether this format stores its metadata in JSON form.
    json_metadata: bool


@pytest.fixture(params=["hdf5", "n5", "zarr"], scope="session")
def mode(request):
    """Return the ``Mode`` describing the parametrized container format.

    Raises ``ValueError`` for a parameter that is not a known format name.
    """
    known_modes = {
        "hdf5": Mode("hdf5", dat_to_hdf5, hdf5_to_bytes, False),
        "n5": Mode("n5", dat_to_n5, n5_to_bytes, True),
        "zarr": Mode("zarr", dat_to_zarr, zarr_to_bytes, True),
    }
    selected = known_modes.get(request.param)
    if selected is None:
        raise ValueError("Unknown mode name")
    return selected


class RoundtripResult(NamedTuple):
    """Outcome of converting a .dat into a container and back to bytes."""

    # The original .dat file that was converted.
    dat_path: Path
    # Where the intermediate container was written.
    container_path: Path
    # Bytes reconstructed from the container.
    written_bytes: bytes
    # Whether the container format stores its metadata in JSON form.
    json_metadata: bool


@pytest.fixture(scope="session")
def roundtripped(dat_path, mode, tmp_path_factory):
    """Convert a sample .dat into a container and back, once per session.

    Yields a ``RoundtripResult`` holding the source path, the container path,
    the bytes reconstructed from the container, and whether the format uses
    JSON metadata. The container is removed afterwards on a best-effort basis.
    """
    container_path = tmp_path_factory.mktemp("containers") / f"data.{mode.name}"
    mode.dat_to_container(dat_path, container_path)
    written_bytes = mode.container_to_dat_bytes(container_path)
    yield RoundtripResult(dat_path, container_path, written_bytes, mode.json_metadata)
    # Best-effort cleanup. A container may be a directory tree or a single
    # file; ``shutil.rmtree`` alone fails on a plain file, so that case would
    # be left behind.
    try:
        if container_path.is_dir():
            shutil.rmtree(container_path)
        else:
            container_path.unlink(missing_ok=True)
    except OSError:
        # Cleanup failure must not fail the test session.
        pass
Loading