diff --git a/dpsynth/local_mode/initialization.py b/dpsynth/local_mode/initialization.py index ab3685a..139a14c 100644 --- a/dpsynth/local_mode/initialization.py +++ b/dpsynth/local_mode/initialization.py @@ -26,7 +26,6 @@ import mbi import numpy as np - _M = TypeVar('_M') @@ -118,50 +117,77 @@ def __call__( Returns: A ColumnMeasurement with bin edges and optionally a heuristic measurement. """ - # Dedup: concentrated data can make quantiles return duplicate edges. raw_edges = _validate_mechanism(self.mechanism)(rng, data).quantiles - raw_edges = np.asarray(raw_edges, dtype=float) - if self.attribute.dtype == 'int': - # Snap edges to the integer lattice. Bins are right-closed (left, - # right] and discretize uses searchsorted with side='left', so - # floor preserves the partition: edge 4.7 → floor 4 gives the - # same integer split {≤4} | {≥5} via (…, 4] | (4, …]. - raw_edges = np.floor(raw_edges) - bin_edges, edge_counts = np.unique(raw_edges, return_counts=True) - # For integer data with upper=max_value+1, edges can land at max_value - # after floor. Remove such edges and absorb their count into the last - # bin's weight so categorical_attribute_from_edges doesn't create a - # degenerate (max_value, max_value] tail bin. - # At most one edge can equal max_value: DPQuantiles clamps outputs to - # [lower, upper), so after floor + unique only the last edge can hit it. - max_val = self.attribute.max_value - if len(bin_edges) > 0 and bin_edges[-1] >= max_val: - tail_count = edge_counts[-1] - bin_edges = bin_edges[:-1] - edge_counts = edge_counts[:-1] - bin_weights = np.append(edge_counts, tail_count + 1) - else: - bin_weights = np.append(edge_counts, 1) - cat_attr = vtx.categorical_attribute_from_edges(bin_edges, self.attribute) - - measurement = None - if estimated_total is not None: - rho = self._zcdp_rho - if not self.attribute.clip_to_range: - # Prepend zero weight for the OUT_OF_DOMAIN slot at index 0. - bin_weights = np.r_[0, bin_weights] - # Query is the normalized histogram (probabilities); the noise scale - # absorbs the 1/estimated_total factor from dividing counts by n. - normalized = bin_weights / bin_weights.sum() - stddev = 1.0 / (np.sqrt(rho) * estimated_total) - measurement = mbi.LinearMeasurement( - normalized, - (self.name,), - stddev=stddev, - query=lambda f: f.normalize(1.0).datavector(), - ) - - return ColumnMeasurement(cat_attr, bin_edges, measurement=measurement) + return edges_to_column_measurement( + raw_edges=raw_edges, + attribute=self.attribute, + name=self.name, + zcdp_rho=self._zcdp_rho, + estimated_total=estimated_total, + ) + + +def edges_to_column_measurement( + raw_edges, + attribute, + name, + zcdp_rho, + estimated_total=None, +): + """Converts raw quantile edges into a ColumnMeasurement. + + Handles integer snapping, edge deduplication, degenerate-bin removal, and + categorical attribute construction. Shared between the data-based + ``NumericalInitializer`` and the histogram-based + ``HistogramNumericalInitializer``. + + Args: + raw_edges: Quantile edge values (unsorted duplicates are fine). + attribute: The ``NumericalAttribute`` defining the data domain. + name: Attribute name used as the clique key in any measurement. + zcdp_rho: Total zCDP rho consumed by the quantile mechanism. + estimated_total: If provided, a heuristic one-way measurement is included. + + Returns: + A ``ColumnMeasurement`` with bin edges and optionally a measurement. + """ + raw_edges = np.asarray(raw_edges, dtype=float) + if attribute.dtype == 'int': + # Snap edges to the integer lattice. Bins are right-closed (left, + # right] and discretize uses searchsorted with side='left', so + # floor preserves the partition: edge 4.7 → floor 4 gives the + # same integer split {≤4} | {≥5} via (…, 4] | (4, …]. + raw_edges = np.floor(raw_edges) + bin_edges, edge_counts = np.unique(raw_edges, return_counts=True) + # For integer data with upper=max_value+1, edges can land at max_value + # after floor. Remove such edges and absorb their count into the last + # bin's weight so categorical_attribute_from_edges doesn't create a + # degenerate (max_value, max_value] tail bin. + max_val = attribute.max_value + if len(bin_edges) > 0 and bin_edges[-1] >= max_val: + tail_count = edge_counts[-1] + bin_edges = bin_edges[:-1] + edge_counts = edge_counts[:-1] + bin_weights = np.append(edge_counts, tail_count + 1) + else: + bin_weights = np.append(edge_counts, 1) + cat_attr = vtx.categorical_attribute_from_edges(bin_edges, attribute) + + measurement = None + if estimated_total is not None: + if not attribute.clip_to_range: + # Prepend zero weight for the OUT_OF_DOMAIN slot at index 0. + bin_weights = np.r_[0, bin_weights] + normalized = bin_weights / bin_weights.sum() + stddev = 1.0 / (np.sqrt(zcdp_rho) * estimated_total) + measurement = mbi.LinearMeasurement( + normalized, + (name,), + stddev=stddev, + query=lambda f: f.normalize(1.0).datavector(), + ) + + return ColumnMeasurement(cat_attr, bin_edges, measurement=measurement) @dataclasses.dataclass diff --git a/dpsynth/local_mode/sufficient_statistics.py b/dpsynth/local_mode/sufficient_statistics.py new file mode 100644 index 0000000..37af679 --- /dev/null +++ b/dpsynth/local_mode/sufficient_statistics.py @@ -0,0 +1,238 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Histogram-based numerical initialization from sufficient statistics. + +This module enables numerical attribute initialization from pre-aggregated +dense histograms, removing the need for raw-data access. The primary use +case is a two-pass pipeline: a first pass (e.g., in Apache Beam) computes a +dense histogram over a fine-grained grid, then this module computes DP +quantiles from that histogram to discretize the numerical domain — exactly +as ``NumericalInitializer`` does from raw data, but without ever touching +individual records after aggregation. + +Public API: + - ``quantiles_from_histogram``: DP quantiles via recursive median splits. + - ``HistogramNumericalInitializer``: ``DPMechanism`` that produces a + ``ColumnMeasurement`` from a dense histogram. +""" + +from __future__ import annotations + +import dataclasses + +import dp_accounting +from dpsynth import domain +from dpsynth.local_mode import initialization +from dpsynth.local_mode import primitives +import numpy as np +import scipy.special + + +def _median_from_histogram( + rng: np.random.Generator, + counts: np.ndarray, + epsilon: float, +) -> int: + """Returns the index of a DP median within a dense histogram. + + Args: + rng: A numpy random number generator. + counts: Dense 1D histogram counts. + epsilon: Exponential mechanism privacy parameter for this level. + + Returns: + The index of the selected median grid point within ``counts``. + """ + total_points = len(counts) + n = counts.sum() + target = n / 2.0 + cumsum = np.cumsum(counts) + + # Infinite budget = exact median, useful for testing. + if epsilon == np.inf: + return int(np.searchsorted(cumsum, target)) + + # Score u(v) = -dist(target, [L_v, R_v]), sensitivity 1/2. + left_ranks = np.r_[0, cumsum[:-1]] + scores = -np.maximum(0, np.maximum(left_ranks - target, target - cumsum)) + + probs = scipy.special.softmax(epsilon * scores) + return int(rng.choice(total_points, p=probs)) + + +def quantiles_from_histogram( + rng: np.random.Generator, + counts: np.ndarray, + lower: float, + upper: float, + epsilon_levels: np.ndarray, + grid_size: int = 10_000_000, +) -> list[float]: + """Computes DP quantiles from a dense histogram via recursive median splits. + + Uses the discrete exponential mechanism to recursively find medians, + splitting the histogram at each level to produce ``num_buckets - 1`` + quantile edges. The number of buckets is ``2 ** len(epsilon_levels)``. + + Args: + rng: A numpy random number generator. + counts: Dense 1D histogram of shape ``(grid_size,)``. + lower: Lower bound of the data domain. + upper: Upper bound of the data domain (exclusive). + epsilon_levels: Per-level exponential mechanism epsilons, ordered from the + deepest (finest) level to the shallowest (coarsest). + grid_size: Number of uniformly spaced grid points spanning ``[lower, + upper]``. + + Returns: + A sorted list of ``2 ** len(epsilon_levels) - 1`` quantile edge values. + """ + levels = len(epsilon_levels) + if levels == 0: + return [] + + # Uniform grid: counts[i] corresponds to value lower + i * delta. + delta = (upper - lower) / (grid_size - 1) + + def _rec(lo_idx, hi_idx, depth): + if depth == 0: + return [] + sub_counts = counts[lo_idx:hi_idx] + median_local = _median_from_histogram( + rng, sub_counts, epsilon_levels[depth - 1] + ) + median_global_idx = lo_idx + median_local + median_value = lower + median_global_idx * delta + left = _rec(lo_idx, median_global_idx, depth - 1) + right = _rec(median_global_idx, hi_idx, depth - 1) + return left + [median_value] + right + + return _rec(0, len(counts), levels) + + +@dataclasses.dataclass +class HistogramNumericalInitializer(primitives.DPMechanism): + """Initializes a numerical attribute from a pre-aggregated dense histogram. + + This mechanism mirrors ``NumericalInitializer`` but operates on a dense + histogram rather than raw data. It is a composition of exponential + mechanisms (one per recursion level), producing quantile edges that + discretize the numerical domain. + + Usage follows the standard three-phase ``DPMechanism`` pattern:: + + initializer = HistogramNumericalInitializer( + name='age', attribute=attr, num_buckets=4, grid_size=10001, + ).calibrate(zcdp_rho=1.0) + result = initializer(rng, counts) + + Attributes: + name: Attribute name used as the clique key in the measurement. + attribute: The ``NumericalAttribute`` defining the data domain. + num_buckets: Number of quantile buckets (must be a power of 2). + grid_size: Number of uniformly spaced grid points spanning the attribute's + ``[min_value, exclusive_max_value]`` range. + """ + + name: str + attribute: domain.NumericalAttribute + num_buckets: int = 32 + grid_size: int = 10_000_000 + _epsilon_levels: tuple[float, ...] | None = dataclasses.field( + default=None, repr=False + ) + + @property + def _num_levels(self) -> int: + result = int(np.log2(self.num_buckets)) + if 2**result != self.num_buckets: + raise ValueError(f'{self.num_buckets=} must be a power of 2.') + return result + + def calibrate( + self, *, zcdp_rho: float, epsilon_ratio: float = 2.0 + ) -> HistogramNumericalInitializer: + """Returns a copy calibrated to the given zCDP budget. + + Args: + zcdp_rho: The zCDP privacy budget (rho). + epsilon_ratio: Factor by which epsilon grows at each deeper level. + + Returns: + A calibrated ``HistogramNumericalInitializer``. + """ + if zcdp_rho <= 0: + raise ValueError(f'zcdp_rho must be positive, got {zcdp_rho}.') + levels = self._num_levels + if levels == 0: + return dataclasses.replace(self, _epsilon_levels=()) + rho_ratio = epsilon_ratio**2 + budget_weights = rho_ratio ** np.arange(levels)[::-1] + rho_levels = zcdp_rho * budget_weights / budget_weights.sum() + eps = np.sqrt(8.0 * rho_levels) + return dataclasses.replace(self, _epsilon_levels=tuple(eps.tolist())) + + @property + def dp_event(self) -> dp_accounting.DpEvent: + """Returns the composed privacy event for the quantile computation.""" + if self._epsilon_levels is None: + raise ValueError('Must call calibrate() before accessing dp_event.') + return dp_accounting.ComposedDpEvent([ + dp_accounting.ExponentialMechanismDpEvent(epsilon=float(eps)) + for eps in self._epsilon_levels + ]) + + def __call__( + self, + rng: np.random.Generator, + counts: np.ndarray, + *, + estimated_total: float | None = None, + out_of_domain_count: int | None = None, + ) -> initialization.ColumnMeasurement: + """Computes DP quantiles from a dense histogram and returns a ColumnMeasurement. + + Args: + rng: A numpy random number generator. + counts: Dense 1D histogram of shape ``(grid_size,)``. + estimated_total: If provided, a heuristic one-way measurement is included + assuming a uniform distribution over the bins. + out_of_domain_count: Count of records outside the domain range. May only + be provided when ``attribute.clip_to_range`` is False. Currently + unused; reserved for future OOD-aware measurement construction. + + Returns: + A ``ColumnMeasurement`` with bin edges and optionally a measurement. + """ + if self._epsilon_levels is None: + raise ValueError('Must call calibrate() before calling.') + del out_of_domain_count # Reserved for future use. + + raw_edges = quantiles_from_histogram( + rng, + counts, + self.attribute.min_value, + self.attribute.exclusive_max_value, + epsilon_levels=np.asarray(self._epsilon_levels), + grid_size=self.grid_size, + ) + rho = sum(e**2 / 8.0 for e in self._epsilon_levels) + return initialization.edges_to_column_measurement( + raw_edges=raw_edges, + attribute=self.attribute, + name=self.name, + zcdp_rho=rho, + estimated_total=estimated_total, + ) diff --git a/tests/local_mode/sufficient_statistics_test.py b/tests/local_mode/sufficient_statistics_test.py new file mode 100644 index 0000000..4c78510 --- /dev/null +++ b/tests/local_mode/sufficient_statistics_test.py @@ -0,0 +1,167 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +import dp_accounting +from dpsynth import domain +from dpsynth.local_mode import initialization +from dpsynth.local_mode import sufficient_statistics +import numpy as np + + +def _dense_uniform_histogram(lower, upper, n, grid_size=1000): + """Creates a dense histogram from uniformly spaced data.""" + data = np.linspace(lower, upper, n) + indices = np.round((data - lower) / (upper - lower) * (grid_size - 1)).astype( + np.int64 + ) + counts = np.zeros(grid_size, dtype=int) + for idx in indices: + counts[idx] += 1 + return counts + + +class QuantilesFromHistogramTest(parameterized.TestCase): + + def test_no_levels_returns_empty(self): + rng = np.random.default_rng(0) + counts = np.array([10]) + edges = sufficient_statistics.quantiles_from_histogram( + rng, + counts, + 0.0, + 10.0, + epsilon_levels=np.array([]), + ) + self.assertEmpty(edges) + + @parameterized.parameters(1, 2, 3, 4) + def test_edge_count_matches_levels(self, levels): + rng = np.random.default_rng(0) + counts = _dense_uniform_histogram(0.0, 10.0, 200, grid_size=10001) + edges = sufficient_statistics.quantiles_from_histogram( + rng, + counts, + 0.0, + 10.0, + epsilon_levels=np.ones(levels), + grid_size=10001, + ) + self.assertLen(edges, 2**levels - 1) + + +class HistogramNumericalInitializerTest(absltest.TestCase): + + def test_calibrate_sets_dp_event(self): + attr = domain.NumericalAttribute(min_value=0, max_value=100) + init = sufficient_statistics.HistogramNumericalInitializer( + name='age', + attribute=attr, + num_buckets=4, + grid_size=10001, + ).calibrate(zcdp_rho=1.0) + event = init.dp_event + self.assertIsInstance(event, dp_accounting.ComposedDpEvent) + # 4 buckets = 2 levels. + self.assertLen(event.events, 2) + + def test_uncalibrated_raises(self): + attr = domain.NumericalAttribute(min_value=0, max_value=100) + init = sufficient_statistics.HistogramNumericalInitializer( + name='age', + attribute=attr, + ) + with self.assertRaises(ValueError): + init(np.random.default_rng(0), np.zeros(100)) + + def test_integer_attribute_snaps_edges(self): + rng = np.random.default_rng(42) + attr = domain.NumericalAttribute(min_value=0, max_value=10, dtype='int') + counts = _dense_uniform_histogram(0.0, 11.0, 200, grid_size=10001) + init = sufficient_statistics.HistogramNumericalInitializer( + name='count', + attribute=attr, + num_buckets=4, + grid_size=10001, + ).calibrate(zcdp_rho=1.0) + cm = init(rng, counts) + for edge in cm.bin_edges: + self.assertEqual(edge, int(edge)) + + +class ParityTest(absltest.TestCase): + """Verifies NumericalInitializer and HistogramNumericalInitializer agree.""" + + def test_noisy_edge_distributions_match(self): + """At finite rho, edge distributions should be statistically similar.""" + lower, upper = 0.0, 100.0 + grid_size = 10001 + num_buckets = 4 + rho = 10.0 + num_trials = 1000 + attr = domain.NumericalAttribute( + min_value=int(lower), + max_value=int(upper), + ) + + # Snap raw data to the grid so both mechanisms operate on identical + # discrete data, eliminating the continuous-vs-discrete confound. + delta = (upper - lower) / (grid_size - 1) + raw = np.linspace(lower, upper, 1000) + indices = np.round((raw - lower) / delta).astype(np.int64) + data = lower + indices * delta + counts = _dense_uniform_histogram(lower, upper, 1000, grid_size=grid_size) + + # Collect edge samples from both initializers. Seeds are fixed + # (deterministic), so there is no flakiness risk and we can use a tight + # p-value threshold. With 1000 samples per group, the two-sample KS test + # detects CDF shifts >~0.06 at alpha=0.01 — on a [0, 100] domain, any + # bug that shifts the median distribution by more than ~6 units is caught. + num_edges = num_buckets - 1 + raw_edges = np.zeros((num_trials, num_edges)) + hist_edges = np.zeros((num_trials, num_edges)) + for i in range(num_trials): + rng = np.random.default_rng(i) + raw_cm = initialization.NumericalInitializer( + name='x', + num_partitions=num_buckets, + attribute=attr, + ).calibrate(zcdp_rho=rho)(rng, data) + raw_edges[i, : len(raw_cm.bin_edges)] = raw_cm.bin_edges + + rng = np.random.default_rng(i + num_trials) + hist_cm = sufficient_statistics.HistogramNumericalInitializer( + name='x', + attribute=attr, + num_buckets=num_buckets, + grid_size=grid_size, + ).calibrate(zcdp_rho=rho)(rng, counts) + hist_edges[i, : len(hist_cm.bin_edges)] = hist_cm.bin_edges + + # We do NOT use a KS test here because the continuous mechanism uses + # log(interval_length) weighting while the discrete mechanism is uniform + # over grid points — their CDFs differ by design even on identical data. + # Instead we check practical equivalence: matching means and stds. + for j in range(num_edges): + raw_mean = raw_edges[:, j].mean() + hist_mean = hist_edges[:, j].mean() + raw_std = raw_edges[:, j].std() + hist_std = hist_edges[:, j].std() + np.testing.assert_allclose(hist_mean, raw_mean, atol=0.1) + np.testing.assert_allclose(hist_std, raw_std, rtol=0.5) + + +if __name__ == '__main__': + absltest.main()