From 4beec5e4a90e75705c4ce21998ec7d5d9b6f4411 Mon Sep 17 00:00:00 2001
From: Ryan McKenna
Date: Thu, 25 Jun 2026 09:38:15 -0700
Subject: [PATCH] Add compute_independence_errors to common and use in MST, AIM_GDP, and SWIFT.

Uses numpy to avoid expensive Jax compilations of many small programs.

PiperOrigin-RevId: 938032588
---
 dpsynth/discrete_mechanisms/aim_gdp.py | 17 ++++++++++------
 dpsynth/discrete_mechanisms/common.py  | 27 +++++++++++++++++++++++++-
 dpsynth/discrete_mechanisms/mst.py     | 17 +++++-----------
 dpsynth/discrete_mechanisms/swift.py   | 26 +++++++++----------------
 4 files changed, 51 insertions(+), 36 deletions(-)

diff --git a/dpsynth/discrete_mechanisms/aim_gdp.py b/dpsynth/discrete_mechanisms/aim_gdp.py
index 749a342..2adb14c 100644
--- a/dpsynth/discrete_mechanisms/aim_gdp.py
+++ b/dpsynth/discrete_mechanisms/aim_gdp.py
@@ -17,7 +17,7 @@
 from collections.abc import Iterable, Mapping
 import dataclasses
 import time
-from typing import TypeAlias
+import typing
 
 from absl import logging
 import dp_accounting
@@ -29,7 +29,7 @@
 import mbi.junction_tree
 import numpy as np
 
-MarginalQuery: TypeAlias = tuple[str, ...]
+MarginalQuery: typing.TypeAlias = tuple[str, ...]
 
 
 def _filter_candidates(
@@ -265,11 +265,16 @@ def __call__(
     assert isinstance(model, mbi.MarkovRandomField)
     logging.info('[AIM] Estimated initial model.')
 
+    # The initial model is fitted from 1-way measurements only, so it IS the
+    # independence model. compute_independence_errors is much faster than
+    # bulk_variable_elimination for this case (pure numpy, no XLA compilation).
     budget_remaining -= 0.5 * budget_per_round
-    estimates = mbi.marginal_oracles.bulk_variable_elimination(
-        model.potentials, list(candidates), model.total
+    per_candidate_sigma = accounting.gdp_gaussian_sigma(
+        0.5 * budget_per_round / len(candidates)
     )
-    errors = _compute_dp_errors(rng, answers, estimates, 0.5 * budget_per_round)
+    errors = common.compute_independence_errors(data, model, list(candidates))
+    for cl in errors:
+      errors[cl] += rng.normal(loc=0.0, scale=per_candidate_sigma)
     logging.info('[AIM] Computed initial errors.')
 
     t = 0
@@ -338,7 +343,7 @@ def __call__(
           iters=self.pgm_iters,
           callback_fn=callback_fn,
       )
-      assert isinstance(model, mbi.MarkovRandomField)
+      model = typing.cast(mbi.MarkovRandomField, model)
       t3 = time.time()
       logging.info('[AIM] Mirror descent took %.2fs', t3 - t2)
 
diff --git a/dpsynth/discrete_mechanisms/common.py b/dpsynth/discrete_mechanisms/common.py
index c4f0038..e8bb5fd 100644
--- a/dpsynth/discrete_mechanisms/common.py
+++ b/dpsynth/discrete_mechanisms/common.py
@@ -14,7 +14,7 @@
 
 """Common utility functions for synthetic data mechanisms."""
 
-from collections.abc import Iterable, Mapping
+from collections.abc import Iterable, Mapping, Sequence
 import dataclasses
 import functools
 import itertools
@@ -26,6 +26,7 @@
 import numpy as np
 import scipy
 import scipy.special
+import tqdm
 
 
 @dataclasses.dataclass
@@ -261,3 +262,27 @@ def score(cl):
       for cl in downward_closure(workload.keys())
       if domain.size(cl) <= max_marginal_size
   }
+
+
+def compute_independence_errors(
+    data: mbi.Projectable,
+    model: mbi.MarkovRandomField,
+    cliques: Sequence[mbi.Clique],
+) -> dict[mbi.Clique, float]:
+  """Computes L1 errors between actual marginals and the independence model."""
+  total = float(model.total)
+
+  # Pure numpy to avoid XLA recompilation: each distinct clique shape triggers
+  # a separate compilation, which dominates runtime for large candidate sets.
+  oneway = {
+      a: np.asarray(model.project((a,)).datavector()) / total
+      for a in data.domain
+  }
+
+  errors = {}
+  for cl in tqdm.tqdm(cliques, desc='Computing independence errors'):
+    estimate = functools.reduce(np.multiply.outer, (oneway[a] for a in cl))
+    actual = np.asarray(data.project(cl).datavector(flatten=False))
+    error = np.abs(total * estimate - actual).sum()
+    errors[cl] = error
+  return errors
diff --git a/dpsynth/discrete_mechanisms/mst.py b/dpsynth/discrete_mechanisms/mst.py
index 148d40f..a8a58b0 100644
--- a/dpsynth/discrete_mechanisms/mst.py
+++ b/dpsynth/discrete_mechanisms/mst.py
@@ -19,6 +19,7 @@
 from collections.abc import Sequence
 import dataclasses
 import itertools
+import typing
 
 from absl import logging
 import dp_accounting
@@ -29,7 +30,6 @@
 import networkx as nx
 import numpy as np
 from scipy.cluster.hierarchy import DisjointSet  # pylint: disable=g-importing-member
-import tqdm
 
 
 def dp_maximum_spanning_tree(
@@ -127,26 +127,19 @@ def _select_two_way_marginal_queries(
   independent_model = mbi.estimation.MirrorDescent().estimate(
       data.domain, one_way_measurements, iters=2500
   )
-
-  oneway_marginals = {
-      attr: np.array(independent_model.project(attr).datavector())
-      for attr in data.domain.attributes
-  }
+  independent_model = typing.cast(mbi.MarkovRandomField, independent_model)
 
   # Construct a complete graph where nodes=attributes and weight of edge
   # (a, b) is a sensitivity 1 measure of correlation between a and b.
-  weights = {}
   candidates = [
       cl
       for cl in itertools.combinations(data.domain.attributes, 2)
       if data.domain.size(cl) <= maximum_marginal_size
   ]
   logging.info('[MST]: Computing Quality Scores')
-  for a, b in tqdm.tqdm(candidates):
-    # For efficiency, we compute the outer product of one-way marginals.
-    xhat = np.outer(oneway_marginals[a], oneway_marginals[b]).flatten()
-    x = data.project((a, b)).datavector()
-    weights[a, b] = np.linalg.norm(x - xhat, 1)
+  weights = common.compute_independence_errors(
+      data, independent_model, candidates
+  )
 
   return dp_maximum_spanning_tree(
       rng,
diff --git a/dpsynth/discrete_mechanisms/swift.py b/dpsynth/discrete_mechanisms/swift.py
index 14f7532..1962dd6 100644
--- a/dpsynth/discrete_mechanisms/swift.py
+++ b/dpsynth/discrete_mechanisms/swift.py
@@ -29,6 +29,7 @@
 import dataclasses
 import functools
 import math
+import typing
 
 from absl import logging
 import dp_accounting
@@ -37,11 +38,9 @@
 from dpsynth.discrete_mechanisms import common
 from dpsynth.discrete_mechanisms import swift_utils
 from dpsynth.local_mode import primitives
-import jax
 import mbi
 import networkx as nx
 import numpy as np
-import tqdm
 
 
 @dataclasses.dataclass
@@ -151,7 +150,7 @@ def __call__(
         iters=self.pgm_iters,
         potentials=potentials,
     )
-    assert isinstance(model, mbi.MarkovRandomField)
+    model = typing.cast(mbi.MarkovRandomField, model)
     logging.info('[SWIFT] Estimated initial model.')
 
     ###########################################
@@ -293,21 +292,14 @@ def __call__(
     data: mbi.Projectable,
     model: mbi.MarkovRandomField,
     cliques: Sequence[mbi.Clique],
-    gdp_mu: float,
+    gdp_budget: float,
 ) -> dict[mbi.Clique, float]:
-  """Computes the initial errors for the SWIFT mechanism."""
-  gdp_per_clique = gdp_mu / len(cliques)
-  sigma_per_clique = accounting.gdp_gaussian_sigma(gdp_per_clique)
-  errors = {}
-  total = float(model.total)
-  oneway = {a: model.project((a,)) / total for a in data.domain}
-  oneway = jax.tree.map(np.asarray, oneway)
-  for cl in tqdm.tqdm(cliques, desc='Computing initial errors'):
-    estimate = functools.reduce(mbi.Factor.__mul__, (oneway[a] for a in cl))
-    actual = data.project(cl)
-    diff = (total * estimate - actual).datavector()
-    error = np.abs(diff).sum()
-    errors[cl] = error + rng.normal(loc=0.0, scale=sigma_per_clique)
+  """Computes DP initial errors for the SWIFT mechanism."""
+  budget_per_clique = gdp_budget / len(cliques)
+  sigma_per_clique = accounting.gdp_gaussian_sigma(budget_per_clique)
+  errors = common.compute_independence_errors(data, model, cliques)
+  for cl in errors:
+    errors[cl] += rng.normal(loc=0.0, scale=sigma_per_clique)
   return errors
 
 