Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 139 additions & 56 deletions connectomics/data/processing/lsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@

from __future__ import annotations

from typing import Iterable, Optional, Sequence, Union
from typing import Any, Iterable, Optional, Sequence, Union, cast

import numpy as np
from numpy.lib.stride_tricks import as_strided
from scipy.ndimage import convolve, gaussian_filter
from scipy.ndimage import convolve, find_objects, gaussian_filter

__all__ = ["LsdExtractor", "seg_to_lsd"]

TRUNCATE = 3.0


def seg_to_lsd(
label: np.ndarray,
Expand Down Expand Up @@ -67,12 +69,10 @@ def seg_to_lsd(
def _coerce_sigma(sigma: Union[float, Sequence[float]], ndim: int) -> tuple:
"""Broadcast a scalar sigma into per-axis tuple matching ``ndim``."""
if np.isscalar(sigma):
return tuple(float(sigma) for _ in range(ndim))
sigma_tuple = tuple(float(v) for v in sigma)
return tuple(float(cast(Any, sigma)) for _ in range(ndim))
sigma_tuple = tuple(float(v) for v in cast(Sequence[float], sigma))
if len(sigma_tuple) != ndim:
raise ValueError(
f"sigma length {len(sigma_tuple)} does not match label dim {ndim}"
)
raise ValueError(f"sigma length {len(sigma_tuple)} does not match label dim {ndim}")
return sigma_tuple


Expand Down Expand Up @@ -113,11 +113,13 @@ def get_descriptors(
# Trim to the 2D sigma if a 3D one was supplied.
self.sigma = self.sigma[:2]

voxel_size_t = tuple(1 for _ in range(dims)) if voxel_size is None else tuple(int(v) for v in voxel_size)
voxel_size_t = (
tuple(1 for _ in range(dims))
if voxel_size is None
else tuple(int(v) for v in voxel_size)
)
if len(voxel_size_t) != dims:
raise ValueError(
f"voxel_size length {len(voxel_size_t)} != label dim {dims}"
)
raise ValueError(f"voxel_size length {len(voxel_size_t)} != label dim {dims}")

if labels is None:
labels_arr = np.unique(segmentation)
Expand All @@ -137,10 +139,61 @@ def get_descriptors(
f"segmentation shape {segmentation.shape} is not divisible by "
f"downsample factor {df}"
)
sub_shape = tuple(s // df for s in segmentation.shape)
sub_voxel_size = tuple(v * df for v in voxel_size_t)
sub_sigma_voxel = tuple(s / v for s, v in zip(self.sigma, sub_voxel_size))

if df == 1:
self._accumulate_bbox(
descriptors,
segmentation,
labels_arr,
sub_sigma_voxel,
sub_voxel_size,
components,
dims,
)
else:
self._accumulate_full(
descriptors,
segmentation,
labels_arr,
sub_sigma_voxel,
sub_voxel_size,
components,
df,
dims,
)

# Normalize to [0, 1]: mean offsets and Pearson coefficients have signed
# ranges that we shift into [0, 1] for prediction.
if self.mode == "gaussian":
# Farthest weighted voxel ≈ sigma (3-sigma cap is rarely reached).
max_distance = np.asarray(self.sigma, dtype=np.float32)
else: # sphere
max_distance = np.asarray([0.5 * s for s in self.sigma], dtype=np.float32)

seg_mask = (segmentation != 0).astype(np.float32)

if dims == 3:
self._normalize_3d(descriptors, max_distance, seg_mask, components)
else:
self._normalize_2d(descriptors, max_distance, seg_mask, components)

np.clip(descriptors, 0.0, 1.0, out=descriptors)
return descriptors

def _accumulate_full(
self,
descriptors: np.ndarray,
segmentation: np.ndarray,
labels_arr: np.ndarray,
sub_sigma_voxel: tuple,
sub_voxel_size: tuple,
components: Optional[str],
df: int,
dims: int,
) -> None:
sub_shape = tuple(s // df for s in segmentation.shape)
coords = self._get_or_build_coords(sub_shape, sub_voxel_size)

for raw_label in labels_arr:
Expand All @@ -162,23 +215,69 @@ def get_descriptors(
descriptor = self._upsample(sub_descriptor, df)
descriptors += descriptor * mask

# Normalize to [0, 1]: mean offsets and Pearson coefficients have signed
# ranges that we shift into [0, 1] for prediction.
if self.mode == "gaussian":
# Farthest weighted voxel ≈ sigma (3-sigma cap is rarely reached).
max_distance = np.asarray(self.sigma, dtype=np.float32)
else: # sphere
max_distance = np.asarray([0.5 * s for s in self.sigma], dtype=np.float32)

seg_mask = (segmentation != 0).astype(np.float32)

if dims == 3:
self._normalize_3d(descriptors, max_distance, seg_mask, components)
else:
self._normalize_2d(descriptors, max_distance, seg_mask, components)
def _accumulate_bbox(
self,
descriptors: np.ndarray,
segmentation: np.ndarray,
labels_arr: np.ndarray,
sub_sigma_voxel: tuple,
sub_voxel_size: tuple,
components: Optional[str],
dims: int,
) -> None:
present = [int(raw_label) for raw_label in labels_arr if int(raw_label) != 0]
if not present:
return

np.clip(descriptors, 0.0, 1.0, out=descriptors)
return descriptors
radius = tuple(int(np.ceil(TRUNCATE * sigma)) for sigma in sub_sigma_voxel)
max_label = int(segmentation.max())
use_find_objects = (
np.issubdtype(segmentation.dtype, np.integer)
and max_label >= 1
and max_label <= max(64, 8 * len(present))
)
objects = find_objects(segmentation) if use_find_objects else None

for label in present:
bbox = None
if objects is not None:
if 1 <= label <= len(objects):
bbox = objects[label - 1]
if bbox is None:
continue
else:
eq = segmentation == label
if not np.any(eq):
continue
slices: list[slice] = []
for axis in range(dims):
other_axes = tuple(d for d in range(dims) if d != axis)
occupied = np.where(eq.any(axis=other_axes))[0]
if occupied.size == 0:
slices = []
break
slices.append(slice(int(occupied[0]), int(occupied[-1]) + 1))
if not slices:
continue
bbox = tuple(slices)

crop = tuple(
slice(
max(0, bbox[d].start - radius[d]),
min(segmentation.shape[d], bbox[d].stop + radius[d]),
)
for d in range(dims)
)
sub = segmentation[crop]
mask = (sub == label).astype(np.float32)
coords_local = self._get_or_build_coords(mask.shape, sub_voxel_size)
offset = np.asarray(
[crop[d].start * sub_voxel_size[d] for d in range(dims)],
dtype=np.float32,
).reshape((dims,) + (1,) * dims)
coords_local = coords_local + offset
desc = np.concatenate(self._get_stats(coords_local, mask, sub_sigma_voxel, components))
descriptors[(slice(None),) + crop] += desc * mask[None]

def _get_or_build_coords(self, sub_shape: tuple, sub_voxel_size: tuple) -> np.ndarray:
key = (sub_shape, sub_voxel_size)
Expand Down Expand Up @@ -209,14 +308,10 @@ def _get_stats(
count = np.where(count == 0, 1.0, count)

# Mean (center-of-mass per voxel) along each axis.
mean = np.stack(
[self._aggregate(masked_coords[d], sigma_voxel) for d in range(count_len)]
)
mean = np.stack([self._aggregate(masked_coords[d], sigma_voxel) for d in range(count_len)])
mean = mean / count

need_mean_offset = components is None or any(
str(c) in components for c in range(count_len)
)
need_mean_offset = components is None or any(str(c) in components for c in range(count_len))
need_cov = components is None or any(
str(c) in components for c in range(count_len, 4 * count_len - 3)
)
Expand All @@ -229,9 +324,7 @@ def _get_stats(
if need_cov:
coords_outer = self._outer_product(masked_coords)
entries = [0, 4, 8, 1, 2, 5] if count_len == 3 else [0, 3, 1]
covariance = np.stack(
[self._aggregate(coords_outer[d], sigma_voxel) for d in entries]
)
covariance = np.stack([self._aggregate(coords_outer[d], sigma_voxel) for d in entries])
covariance = covariance / count
covariance -= self._outer_product(mean)[entries]

Expand Down Expand Up @@ -275,9 +368,7 @@ def _get_stats(
elif i == 9:
ret.append(count[None, :])
else:
raise ValueError(
f"3D LSD components must be in 0..9, got {i}"
)
raise ValueError(f"3D LSD components must be in 0..9, got {i}")
else: # 2D
if 0 <= i < 2:
ret.append(mean_offset[[i]])
Expand All @@ -288,16 +379,12 @@ def _get_stats(
elif i == 5:
ret.append(count[None, :])
else:
raise ValueError(
f"2D LSD components must be in 0..5, got {i}"
)
raise ValueError(f"2D LSD components must be in 0..5, got {i}")
return tuple(ret)

def _aggregate(self, array: np.ndarray, sigma: tuple) -> np.ndarray:
if self.mode == "gaussian":
return gaussian_filter(
array, sigma=sigma, mode="constant", cval=0.0, truncate=3.0
)
return gaussian_filter(array, sigma=sigma, mode="constant", cval=0.0, truncate=TRUNCATE)
radius = sigma[0]
if any(s != radius for s in sigma):
raise ValueError("mode='sphere' requires isotropic sigma")
Expand All @@ -306,7 +393,7 @@ def _aggregate(self, array: np.ndarray, sigma: tuple) -> np.ndarray:

@staticmethod
def _make_sphere(radius: int) -> np.ndarray:
r2 = np.arange(-radius, radius) ** 2
r2: np.ndarray = np.arange(-radius, radius) ** 2
dist2 = r2[:, None, None] + r2[:, None] + r2
return (dist2 <= radius**2).astype(np.float32)

Expand All @@ -323,6 +410,8 @@ def _upsample(array: np.ndarray, factor: int) -> np.ndarray:
return array
shape = array.shape
stride = array.strides
sh: tuple[int, ...]
st: tuple[int, ...]
if array.ndim == 4:
sh = (shape[0], shape[1], factor, shape[2], factor, shape[3], factor)
st = (stride[0], stride[1], 0, stride[2], 0, stride[3], 0)
Expand Down Expand Up @@ -350,9 +439,7 @@ def _normalize_3d(
for slot, token in enumerate(components):
c = int(token)
if 0 <= c < 3:
descriptors[slot] = (
descriptors[slot] / max_distance[c] * 0.5 + 0.5
) * seg_mask
descriptors[slot] = (descriptors[slot] / max_distance[c] * 0.5 + 0.5) * seg_mask
elif 6 <= c < 9:
descriptors[slot] = (descriptors[slot] * 0.5 + 0.5) * seg_mask

Expand All @@ -364,17 +451,13 @@ def _normalize_2d(
components: Optional[str],
) -> None:
if components is None:
descriptors[[0, 1]] = (
descriptors[[0, 1]] / max_distance[:, None, None] * 0.5 + 0.5
)
descriptors[[0, 1]] = descriptors[[0, 1]] / max_distance[:, None, None] * 0.5 + 0.5
descriptors[[4]] = descriptors[[4]] * 0.5 + 0.5
descriptors[[0, 1, 4]] *= seg_mask
return
for slot, token in enumerate(components):
c = int(token)
if 0 <= c < 2:
descriptors[slot] = (
descriptors[slot] / max_distance[c] * 0.5 + 0.5
) * seg_mask
descriptors[slot] = (descriptors[slot] / max_distance[c] * 0.5 + 0.5) * seg_mask
elif c == 4:
descriptors[slot] = (descriptors[slot] * 0.5 + 0.5) * seg_mask
Loading
Loading