Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions dpsynth/discrete_mechanisms/aim_gdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections.abc import Iterable, Mapping
import dataclasses
import time
from typing import TypeAlias
import typing

from absl import logging
import dp_accounting
Expand All @@ -29,7 +29,7 @@
import mbi.junction_tree
import numpy as np

MarginalQuery: TypeAlias = tuple[str, ...]
MarginalQuery: typing.TypeAlias = tuple[str, ...]


def _filter_candidates(
Expand Down Expand Up @@ -265,11 +265,16 @@ def __call__(
assert isinstance(model, mbi.MarkovRandomField)
logging.info('[AIM] Estimated initial model.')

# The initial model is fitted from 1-way measurements only, so it IS the
# independence model. compute_independence_errors is much faster than
# bulk_variable_elimination for this case (pure numpy, no XLA compilation).
budget_remaining -= 0.5 * budget_per_round
estimates = mbi.marginal_oracles.bulk_variable_elimination(
model.potentials, list(candidates), model.total
per_candidate_sigma = accounting.gdp_gaussian_sigma(
0.5 * budget_per_round / len(candidates)
)
errors = _compute_dp_errors(rng, answers, estimates, 0.5 * budget_per_round)
errors = common.compute_independence_errors(data, model, list(candidates))
for cl in errors:
errors[cl] += rng.normal(loc=0.0, scale=per_candidate_sigma)
logging.info('[AIM] Computed initial errors.')

t = 0
Expand Down Expand Up @@ -338,7 +343,7 @@ def __call__(
iters=self.pgm_iters,
callback_fn=callback_fn,
)
assert isinstance(model, mbi.MarkovRandomField)
model = typing.cast(mbi.MarkovRandomField, model)
t3 = time.time()
logging.info('[AIM] Mirror descent took %.2fs', t3 - t2)

Expand Down
27 changes: 26 additions & 1 deletion dpsynth/discrete_mechanisms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Common utility functions for synthetic data mechanisms."""

from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, Sequence
import dataclasses
import functools
import itertools
Expand All @@ -26,6 +26,7 @@
import numpy as np
import scipy
import scipy.special
import tqdm


@dataclasses.dataclass
Expand Down Expand Up @@ -261,3 +262,27 @@ def score(cl):
for cl in downward_closure(workload.keys())
if domain.size(cl) <= max_marginal_size
}


def compute_independence_errors(
    data: mbi.Projectable,
    model: mbi.MarkovRandomField,
    cliques: Sequence[mbi.Clique],
) -> dict[mbi.Clique, float]:
  """Scores each clique by how far the data is from an independence estimate.

  For every clique, the estimate is the outer product of the model's
  normalized one-way marginals, rescaled by `model.total`. The score is the
  sum of absolute differences (L1 distance) between that estimate and the
  marginal obtained from `data.project(clique)`.

  Args:
    data: Source of the actual marginals; also supplies the attribute list via
      `data.domain`.
    model: Model whose one-way marginals and `total` define the estimate.
    cliques: Attribute tuples to score.

  Returns:
    A dict mapping each clique to its L1 error. No noise is added here.
  """
  total = float(model.total)

  # Everything below stays in numpy on purpose: each distinct clique shape
  # would otherwise trigger its own XLA compilation, which dominates runtime
  # for large candidate sets.
  marginals = {}
  for attr in data.domain:
    counts = np.asarray(model.project((attr,)).datavector())
    marginals[attr] = counts / total

  def _l1_error(clique):
    # Outer product of the per-attribute distributions, one axis per attribute
    # in the order given by `clique`.
    factors = (marginals[attr] for attr in clique)
    expected = functools.reduce(np.multiply.outer, factors)
    observed = np.asarray(data.project(clique).datavector(flatten=False))
    return np.abs(total * expected - observed).sum()

  progress = tqdm.tqdm(cliques, desc='Computing independence errors')
  return {clique: _l1_error(clique) for clique in progress}
17 changes: 5 additions & 12 deletions dpsynth/discrete_mechanisms/mst.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Sequence
import dataclasses
import itertools
import typing

from absl import logging
import dp_accounting
Expand All @@ -29,7 +30,6 @@
import networkx as nx
import numpy as np
from scipy.cluster.hierarchy import DisjointSet # pylint: disable=g-importing-member
import tqdm


def dp_maximum_spanning_tree(
Expand Down Expand Up @@ -127,26 +127,19 @@ def _select_two_way_marginal_queries(
independent_model = mbi.estimation.MirrorDescent().estimate(
data.domain, one_way_measurements, iters=2500
)

oneway_marginals = {
attr: np.array(independent_model.project(attr).datavector())
for attr in data.domain.attributes
}
independent_model = typing.cast(mbi.MarkovRandomField, independent_model)

# Construct a complete graph where nodes=attributes and weight of edge
# (a, b) is a sensitivity 1 measure of correlation between a and b.
weights = {}
candidates = [
cl
for cl in itertools.combinations(data.domain.attributes, 2)
if data.domain.size(cl) <= maximum_marginal_size
]
logging.info('[MST]: Computing Quality Scores')
for a, b in tqdm.tqdm(candidates):
# For efficiency, we compute the outer product of one-way marginals.
xhat = np.outer(oneway_marginals[a], oneway_marginals[b]).flatten()
x = data.project((a, b)).datavector()
weights[a, b] = np.linalg.norm(x - xhat, 1)
weights = common.compute_independence_errors(
data, independent_model, candidates
)

return dp_maximum_spanning_tree(
rng,
Expand Down
26 changes: 9 additions & 17 deletions dpsynth/discrete_mechanisms/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import dataclasses
import functools
import math
import typing

from absl import logging
import dp_accounting
Expand All @@ -37,11 +38,9 @@
from dpsynth.discrete_mechanisms import common
from dpsynth.discrete_mechanisms import swift_utils
from dpsynth.local_mode import primitives
import jax
import mbi
import networkx as nx
import numpy as np
import tqdm


@dataclasses.dataclass
Expand Down Expand Up @@ -151,7 +150,7 @@ def __call__(
iters=self.pgm_iters,
potentials=potentials,
)
assert isinstance(model, mbi.MarkovRandomField)
model = typing.cast(mbi.MarkovRandomField, model)
logging.info('[SWIFT] Estimated initial model.')

###########################################
Expand Down Expand Up @@ -293,21 +292,14 @@ def _compute_initial_errors(
data: mbi.Projectable,
model: mbi.MarkovRandomField,
cliques: Sequence[mbi.Clique],
gdp_mu: float,
gdp_budget: float,
) -> dict[mbi.Clique, float]:
"""Computes the initial errors for the SWIFT mechanism."""
gdp_per_clique = gdp_mu / len(cliques)
sigma_per_clique = accounting.gdp_gaussian_sigma(gdp_per_clique)
errors = {}
total = float(model.total)
oneway = {a: model.project((a,)) / total for a in data.domain}
oneway = jax.tree.map(np.asarray, oneway)
for cl in tqdm.tqdm(cliques, desc='Computing initial errors'):
estimate = functools.reduce(mbi.Factor.__mul__, (oneway[a] for a in cl))
actual = data.project(cl)
diff = (total * estimate - actual).datavector()
error = np.abs(diff).sum()
errors[cl] = error + rng.normal(loc=0.0, scale=sigma_per_clique)
"""Computes DP initial errors for the SWIFT mechanism."""
budget_per_clique = gdp_budget / len(cliques)
sigma_per_clique = accounting.gdp_gaussian_sigma(budget_per_clique)
errors = common.compute_independence_errors(data, model, cliques)
for cl in errors:
errors[cl] += rng.normal(loc=0.0, scale=sigma_per_clique)
return errors


Expand Down
Loading