diff --git a/connectomics/data/processing/lsd.py b/connectomics/data/processing/lsd.py index f2c173c5..241db099 100644 --- a/connectomics/data/processing/lsd.py +++ b/connectomics/data/processing/lsd.py @@ -21,14 +21,16 @@ from __future__ import annotations -from typing import Iterable, Optional, Sequence, Union +from typing import Any, Iterable, Optional, Sequence, Union, cast import numpy as np from numpy.lib.stride_tricks import as_strided -from scipy.ndimage import convolve, gaussian_filter +from scipy.ndimage import convolve, find_objects, gaussian_filter __all__ = ["LsdExtractor", "seg_to_lsd"] +TRUNCATE = 3.0 + def seg_to_lsd( label: np.ndarray, @@ -67,12 +69,10 @@ def seg_to_lsd( def _coerce_sigma(sigma: Union[float, Sequence[float]], ndim: int) -> tuple: """Broadcast a scalar sigma into per-axis tuple matching ``ndim``.""" if np.isscalar(sigma): - return tuple(float(sigma) for _ in range(ndim)) - sigma_tuple = tuple(float(v) for v in sigma) + return tuple(float(cast(Any, sigma)) for _ in range(ndim)) + sigma_tuple = tuple(float(v) for v in cast(Sequence[float], sigma)) if len(sigma_tuple) != ndim: - raise ValueError( - f"sigma length {len(sigma_tuple)} does not match label dim {ndim}" - ) + raise ValueError(f"sigma length {len(sigma_tuple)} does not match label dim {ndim}") return sigma_tuple @@ -113,11 +113,13 @@ def get_descriptors( # Trim to the 2D sigma if a 3D one was supplied. self.sigma = self.sigma[:2] - voxel_size_t = tuple(1 for _ in range(dims)) if voxel_size is None else tuple(int(v) for v in voxel_size) + voxel_size_t = ( + tuple(1 for _ in range(dims)) + if voxel_size is None + else tuple(int(v) for v in voxel_size) + ) if len(voxel_size_t) != dims: - raise ValueError( - f"voxel_size length {len(voxel_size_t)} != label dim {dims}" - ) + raise ValueError(f"voxel_size length {len(voxel_size_t)} != label dim {dims}") if labels is None: labels_arr = np.unique(segmentation) @@ -137,10 +139,61 @@ def get_descriptors( f"segmentation shape {segmentation.shape} is not divisible by " f"downsample factor {df}" ) - sub_shape = tuple(s // df for s in segmentation.shape) sub_voxel_size = tuple(v * df for v in voxel_size_t) sub_sigma_voxel = tuple(s / v for s, v in zip(self.sigma, sub_voxel_size)) + if df == 1: + self._accumulate_bbox( + descriptors, + segmentation, + labels_arr, + sub_sigma_voxel, + sub_voxel_size, + components, + dims, + ) + else: + self._accumulate_full( + descriptors, + segmentation, + labels_arr, + sub_sigma_voxel, + sub_voxel_size, + components, + df, + dims, + ) + + # Normalize to [0, 1]: mean offsets and Pearson coefficients have signed + # ranges that we shift into [0, 1] for prediction. + if self.mode == "gaussian": + # Farthest weighted voxel ≈ sigma (3-sigma cap is rarely reached). + max_distance = np.asarray(self.sigma, dtype=np.float32) + else: # sphere + max_distance = np.asarray([0.5 * s for s in self.sigma], dtype=np.float32) + + seg_mask = (segmentation != 0).astype(np.float32) + + if dims == 3: + self._normalize_3d(descriptors, max_distance, seg_mask, components) + else: + self._normalize_2d(descriptors, max_distance, seg_mask, components) + + np.clip(descriptors, 0.0, 1.0, out=descriptors) + return descriptors + + def _accumulate_full( + self, + descriptors: np.ndarray, + segmentation: np.ndarray, + labels_arr: np.ndarray, + sub_sigma_voxel: tuple, + sub_voxel_size: tuple, + components: Optional[str], + df: int, + dims: int, + ) -> None: + sub_shape = tuple(s // df for s in segmentation.shape) coords = self._get_or_build_coords(sub_shape, sub_voxel_size) for raw_label in labels_arr: @@ -162,23 +215,69 @@ def get_descriptors( descriptor = self._upsample(sub_descriptor, df) descriptors += descriptor * mask - # Normalize to [0, 1]: mean offsets and Pearson coefficients have signed - # ranges that we shift into [0, 1] for prediction. - if self.mode == "gaussian": - # Farthest weighted voxel ≈ sigma (3-sigma cap is rarely reached). - max_distance = np.asarray(self.sigma, dtype=np.float32) - else: # sphere - max_distance = np.asarray([0.5 * s for s in self.sigma], dtype=np.float32) - - seg_mask = (segmentation != 0).astype(np.float32) - - if dims == 3: - self._normalize_3d(descriptors, max_distance, seg_mask, components) - else: - self._normalize_2d(descriptors, max_distance, seg_mask, components) + def _accumulate_bbox( + self, + descriptors: np.ndarray, + segmentation: np.ndarray, + labels_arr: np.ndarray, + sub_sigma_voxel: tuple, + sub_voxel_size: tuple, + components: Optional[str], + dims: int, + ) -> None: + present = [int(raw_label) for raw_label in labels_arr if int(raw_label) != 0] + if not present: + return - np.clip(descriptors, 0.0, 1.0, out=descriptors) - return descriptors + radius = tuple(int(np.ceil(TRUNCATE * sigma)) for sigma in sub_sigma_voxel) + max_label = int(segmentation.max()) + use_find_objects = ( + np.issubdtype(segmentation.dtype, np.integer) + and max_label >= 1 + and max_label <= max(64, 8 * len(present)) + ) + objects = find_objects(segmentation) if use_find_objects else None + + for label in present: + bbox = None + if objects is not None: + if 1 <= label <= len(objects): + bbox = objects[label - 1] + if bbox is None: + continue + else: + eq = segmentation == label + if not np.any(eq): + continue + slices: list[slice] = [] + for axis in range(dims): + other_axes = tuple(d for d in range(dims) if d != axis) + occupied = np.where(eq.any(axis=other_axes))[0] + if occupied.size == 0: + slices = [] + break + slices.append(slice(int(occupied[0]), int(occupied[-1]) + 1)) + if not slices: + continue + bbox = tuple(slices) + + crop = tuple( + slice( + max(0, bbox[d].start - radius[d]), + min(segmentation.shape[d], bbox[d].stop + radius[d]), + ) + for d in range(dims) + ) + sub = segmentation[crop] + mask = (sub == label).astype(np.float32) + coords_local = self._get_or_build_coords(mask.shape, sub_voxel_size) + offset = np.asarray( + [crop[d].start * sub_voxel_size[d] for d in range(dims)], + dtype=np.float32, + ).reshape((dims,) + (1,) * dims) + coords_local = coords_local + offset + desc = np.concatenate(self._get_stats(coords_local, mask, sub_sigma_voxel, components)) + descriptors[(slice(None),) + crop] += desc * mask[None] def _get_or_build_coords(self, sub_shape: tuple, sub_voxel_size: tuple) -> np.ndarray: key = (sub_shape, sub_voxel_size) @@ -209,14 +308,10 @@ def _get_stats( count = np.where(count == 0, 1.0, count) # Mean (center-of-mass per voxel) along each axis. - mean = np.stack( - [self._aggregate(masked_coords[d], sigma_voxel) for d in range(count_len)] - ) + mean = np.stack([self._aggregate(masked_coords[d], sigma_voxel) for d in range(count_len)]) mean = mean / count - need_mean_offset = components is None or any( - str(c) in components for c in range(count_len) - ) + need_mean_offset = components is None or any(str(c) in components for c in range(count_len)) need_cov = components is None or any( str(c) in components for c in range(count_len, 4 * count_len - 3) ) @@ -229,9 +324,7 @@ def _get_stats( if need_cov: coords_outer = self._outer_product(masked_coords) entries = [0, 4, 8, 1, 2, 5] if count_len == 3 else [0, 3, 1] - covariance = np.stack( - [self._aggregate(coords_outer[d], sigma_voxel) for d in entries] - ) + covariance = np.stack([self._aggregate(coords_outer[d], sigma_voxel) for d in entries]) covariance = covariance / count covariance -= self._outer_product(mean)[entries] @@ -275,9 +368,7 @@ def _get_stats( elif i == 9: ret.append(count[None, :]) else: - raise ValueError( - f"3D LSD components must be in 0..9, got {i}" - ) + raise ValueError(f"3D LSD components must be in 0..9, got {i}") else: # 2D if 0 <= i < 2: ret.append(mean_offset[[i]]) @@ -288,16 +379,12 @@ def _get_stats( elif i == 5: ret.append(count[None, :]) else: - raise ValueError( - f"2D LSD components must be in 0..5, got {i}" - ) + raise ValueError(f"2D LSD components must be in 0..5, got {i}") return tuple(ret) def _aggregate(self, array: np.ndarray, sigma: tuple) -> np.ndarray: if self.mode == "gaussian": - return gaussian_filter( - array, sigma=sigma, mode="constant", cval=0.0, truncate=3.0 - ) + return gaussian_filter(array, sigma=sigma, mode="constant", cval=0.0, truncate=TRUNCATE) radius = sigma[0] if any(s != radius for s in sigma): raise ValueError("mode='sphere' requires isotropic sigma") @@ -306,7 +393,7 @@ def _aggregate(self, array: np.ndarray, sigma: tuple) -> np.ndarray: @staticmethod def _make_sphere(radius: int) -> np.ndarray: - r2 = np.arange(-radius, radius) ** 2 + r2: np.ndarray = np.arange(-radius, radius) ** 2 dist2 = r2[:, None, None] + r2[:, None] + r2 return (dist2 <= radius**2).astype(np.float32) @@ -323,6 +410,8 @@ def _upsample(array: np.ndarray, factor: int) -> np.ndarray: return array shape = array.shape stride = array.strides + sh: tuple[int, ...] + st: tuple[int, ...] if array.ndim == 4: sh = (shape[0], shape[1], factor, shape[2], factor, shape[3], factor) st = (stride[0], stride[1], 0, stride[2], 0, stride[3], 0) @@ -350,9 +439,7 @@ def _normalize_3d( for slot, token in enumerate(components): c = int(token) if 0 <= c < 3: - descriptors[slot] = ( - descriptors[slot] / max_distance[c] * 0.5 + 0.5 - ) * seg_mask + descriptors[slot] = (descriptors[slot] / max_distance[c] * 0.5 + 0.5) * seg_mask elif 6 <= c < 9: descriptors[slot] = (descriptors[slot] * 0.5 + 0.5) * seg_mask @@ -364,17 +451,13 @@ def _normalize_2d( components: Optional[str], ) -> None: if components is None: - descriptors[[0, 1]] = ( - descriptors[[0, 1]] / max_distance[:, None, None] * 0.5 + 0.5 - ) + descriptors[[0, 1]] = descriptors[[0, 1]] / max_distance[:, None, None] * 0.5 + 0.5 descriptors[[4]] = descriptors[[4]] * 0.5 + 0.5 descriptors[[0, 1, 4]] *= seg_mask return for slot, token in enumerate(components): c = int(token) if 0 <= c < 2: - descriptors[slot] = ( - descriptors[slot] / max_distance[c] * 0.5 + 0.5 - ) * seg_mask + descriptors[slot] = (descriptors[slot] / max_distance[c] * 0.5 + 0.5) * seg_mask elif c == 4: descriptors[slot] = (descriptors[slot] * 0.5 + 0.5) * seg_mask diff --git a/tests/unit/test_lsd.py b/tests/unit/test_lsd.py new file mode 100644 index 00000000..1e16d8c3 --- /dev/null +++ b/tests/unit/test_lsd.py @@ -0,0 +1,227 @@ +import os +import time + +import numpy as np +import pytest + +from connectomics.data.processing.lsd import LsdExtractor + +ATOL = 1e-5 +RTOL = 1e-4 + + +def _blob_labels(shape, count, seed, radius_range=(2, 5)): + rng = np.random.default_rng(seed) + segmentation = np.zeros(shape, dtype=np.int32) + grids = np.ogrid[tuple(slice(0, size) for size in shape)] + + for label in range(1, count + 1): + radii = rng.integers(radius_range[0], radius_range[1] + 1, size=len(shape)) + center = tuple(int(rng.integers(0, size)) for size in shape) + dist = np.zeros(shape, dtype=np.float32) + for grid, c, r in zip(grids, center, radii): + dist += ((grid - c) / float(r)) ** 2 + segmentation[dist <= 1.0] = label + + return segmentation + + +def _full_descriptors( + segmentation, + sigma, + *, + components=None, + voxel_size=None, + labels=None, + mode="gaussian", +): + extractor = LsdExtractor(sigma, mode=mode, downsample=1) + dims = segmentation.ndim + if dims == 2 and len(extractor.sigma) == 3: + extractor.sigma = extractor.sigma[:2] + + voxel_size_t = ( + tuple(1 for _ in range(dims)) if voxel_size is None else tuple(int(v) for v in voxel_size) + ) + labels_arr = np.unique(segmentation) if labels is None else np.asarray(list(labels)) + num_channels = (10 if dims == 3 else 6) if components is None else len(components) + descriptors = np.zeros((num_channels,) + segmentation.shape, dtype=np.float32) + + df = extractor.downsample + if any(s % df != 0 for s in segmentation.shape): + raise ValueError( + f"segmentation shape {segmentation.shape} is not divisible by " + f"downsample factor {df}" + ) + sub_voxel_size = tuple(v * df for v in voxel_size_t) + sub_sigma_voxel = tuple(s / v for s, v in zip(extractor.sigma, sub_voxel_size)) + + extractor._accumulate_full( + descriptors, + segmentation, + labels_arr, + sub_sigma_voxel, + sub_voxel_size, + components, + df, + dims, + ) + + if extractor.mode == "gaussian": + max_distance = np.asarray(extractor.sigma, dtype=np.float32) + else: + max_distance = np.asarray([0.5 * s for s in extractor.sigma], dtype=np.float32) + + seg_mask = (segmentation != 0).astype(np.float32) + if dims == 3: + extractor._normalize_3d(descriptors, max_distance, seg_mask, components) + else: + extractor._normalize_2d(descriptors, max_distance, seg_mask, components) + + np.clip(descriptors, 0.0, 1.0, out=descriptors) + return descriptors + + +def _bbox_descriptors( + segmentation, + sigma, + *, + components=None, + voxel_size=None, + labels=None, + mode="gaussian", +): + return LsdExtractor(sigma, mode=mode, downsample=1).get_descriptors( + segmentation, components=components, voxel_size=voxel_size, labels=labels + ) + + +def _assert_matches_full( + segmentation, + sigma, + *, + components=None, + voxel_size=None, + labels=None, + mode="gaussian", +): + actual = _bbox_descriptors( + segmentation, + sigma, + components=components, + voxel_size=voxel_size, + labels=labels, + mode=mode, + ) + expected = _full_descriptors( + segmentation, + sigma, + components=components, + voxel_size=voxel_size, + labels=labels, + mode=mode, + ) + max_diff = float(np.max(np.abs(actual - expected))) if actual.size else 0.0 + assert np.allclose(actual, expected, atol=ATOL, rtol=RTOL), max_diff + + +def test_bbox_matches_full_for_3d_gaussian_components_none(): + segmentation = _blob_labels((48, 48, 48), count=12, seed=11) + + _assert_matches_full(segmentation, sigma=(4.0, 4.0, 4.0)) + + +def test_bbox_matches_full_for_3d_subset_and_non_unit_voxel_size(): + segmentation = _blob_labels((64, 64, 64), count=16, seed=23) + + _assert_matches_full( + segmentation, + sigma=(6.0, 4.0, 4.0), + components="0129", + voxel_size=(2, 1, 1), + ) + + +def test_bbox_matches_full_for_3d_full_components_non_unit_voxel_size(): + segmentation = _blob_labels((32, 36, 40), count=6, seed=31) + + _assert_matches_full( + segmentation, + sigma=(6.0, 4.0, 4.0), + voxel_size=(2, 1, 1), + ) + + +def test_bbox_matches_full_for_2d_gaussian(): + segmentation = _blob_labels((64, 72), count=10, seed=41) + + _assert_matches_full(segmentation, sigma=(5.0, 4.0)) + + +def test_bbox_handles_empty_volume_and_absent_explicit_label(): + segmentation = np.zeros((24, 24, 24), dtype=np.int32) + + _assert_matches_full(segmentation, sigma=(3.0, 3.0, 3.0), labels=[99]) + actual = _bbox_descriptors(segmentation, sigma=(3.0, 3.0, 3.0)) + assert np.count_nonzero(actual) == 0 + + +def test_bbox_matches_full_for_single_label_filling_volume(): + segmentation = np.ones((24, 24, 24), dtype=np.int32) + + _assert_matches_full(segmentation, sigma=(3.0, 3.0, 3.0)) + + +def test_bbox_matches_full_for_border_touching_label(): + segmentation = np.zeros((32, 32, 32), dtype=np.int32) + segmentation[:6, :8, :10] = 1 + segmentation[14:22, 15:24, 16:25] = 2 + + _assert_matches_full(segmentation, sigma=(4.0, 4.0, 4.0)) + + +def test_bbox_matches_full_for_sparse_large_label_ids(): + segmentation = np.zeros((30, 32, 34), dtype=np.int32) + segmentation[3:8, 4:10, 5:12] = 1000 + + _assert_matches_full( + segmentation, + sigma=(3.0, 3.0, 3.0), + labels=[1000, 2000], + ) + + +def test_bbox_matches_full_for_float32_label_array(): + segmentation = _blob_labels((32, 32, 32), count=6, seed=47).astype(np.float32) + + _assert_matches_full(segmentation, sigma=(4.0, 4.0, 4.0)) + + +def test_bbox_matches_full_for_3d_sphere(): + segmentation = _blob_labels((32, 32, 32), count=5, seed=53) + + _assert_matches_full(segmentation, sigma=(3.0, 3.0, 3.0), mode="sphere") + + +@pytest.mark.skipif( + os.environ.get("LSD_BENCH") != "1", + reason="set LSD_BENCH=1 to run the LSD bbox timing benchmark", +) +def test_bbox_timing_benchmark(): + segmentation = _blob_labels((128, 128, 128), count=30, seed=67) + sigma = (8.0, 8.0, 8.0) + + start = time.perf_counter() + expected = _full_descriptors(segmentation, sigma) + full_time = time.perf_counter() - start + + start = time.perf_counter() + actual = _bbox_descriptors(segmentation, sigma) + bbox_time = time.perf_counter() - start + + max_diff = float(np.max(np.abs(actual - expected))) if actual.size else 0.0 + assert np.allclose(actual, expected, atol=ATOL, rtol=RTOL), max_diff + + speedup = full_time / bbox_time + print("LSD_BENCH " f"full={full_time:.4f}s bbox={bbox_time:.4f}s speedup={speedup:.2f}x") + assert speedup > 1.0