Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpsynth/discrete_mechanisms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dpsynth.discrete_mechanisms.aim import AIMMechanism
from dpsynth.discrete_mechanisms.aim_gdp import AIMGDPMechanism
from dpsynth.discrete_mechanisms.common import DiscreteMechanismResult
from dpsynth.discrete_mechanisms.common import MechanismDiagnostics
from dpsynth.discrete_mechanisms.direct import DirectMechanism
from dpsynth.discrete_mechanisms.independent import IndependentMechanism
from dpsynth.discrete_mechanisms.mst import MSTMechanism
Expand Down
83 changes: 45 additions & 38 deletions dpsynth/discrete_mechanisms/aim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from collections.abc import Iterable, Mapping
import dataclasses
import time
from typing import TypeAlias

from absl import logging
Expand Down Expand Up @@ -170,6 +169,7 @@ def __call__(
raise ValueError('Must call calibrate() before using the mechanism.')

logging.info('[AIM]: Starting Mechanism.')
phase_times = {}

zcdp_rho = self.zcdp_rho

Expand Down Expand Up @@ -222,23 +222,27 @@ def __call__(
########################################################################
# Select a marginal query worst approximated by the current model. #
########################################################################
t0 = time.time()
rho_remaining -= rho_per_round
fraction = self.select_budget_fraction
sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round)
epsilon = accounting.zcdp_exponential_eps(fraction * rho_per_round)
size_limit = self.max_model_size * (zcdp_rho - rho_remaining) / zcdp_rho
small_candidates = _filter_candidates(candidates, model, size_limit)

estimates = mbi.marginal_oracles.bulk_variable_elimination(
model.potentials, list(small_candidates), total=model.total
)
marginal_query = _worst_approximated(
rng, small_candidates, answers, estimates, epsilon, sigma, data.domain
)
with common.timed(phase_times, 'selection'):
rho_remaining -= rho_per_round
fraction = self.select_budget_fraction
sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round)
epsilon = accounting.zcdp_exponential_eps(fraction * rho_per_round)
size_limit = self.max_model_size * (zcdp_rho - rho_remaining) / zcdp_rho
small_candidates = _filter_candidates(candidates, model, size_limit)

estimates = mbi.marginal_oracles.bulk_variable_elimination(
model.potentials, list(small_candidates), total=model.total
)
marginal_query = _worst_approximated(
rng,
small_candidates,
answers,
estimates,
epsilon,
sigma,
data.domain,
)

t1 = time.time()
logging.info('[AIM] Found worst-approximated candidate in %.2fs', t1 - t0)
logging.info(
'[AIM] Round %d, Budget used: %.4f, Measuring: %s, Candidates: %d',
t,
Expand All @@ -250,31 +254,30 @@ def __call__(
######################################################################
# Measure the marginal query privately using the Gaussian mechanism. #
######################################################################
measurement = common.measure_marginals_with_noise(
rng, data, [marginal_query], sigma
)[0]
measurements.append(measurement)
old_estimate = model.project(marginal_query).datavector()
with common.timed(phase_times, 'measurement'):
measurement = common.measure_marginals_with_noise(
rng, data, [marginal_query], sigma
)[0]
measurements.append(measurement)
old_estimate = model.project(marginal_query).datavector()

#####################################################
# Estimate the data distribution using Private-PGM. #
#####################################################
t2 = time.time()
callback_fn = mbi.callbacks.default(measurements, domain=data.domain)
measured_cliques = list(set(m.clique for m in measurements))
warm_start = model.potentials.expand(measured_cliques)
model = mbi.estimation.MirrorDescent(
marginal_oracle=self.marginal_oracle,
).estimate(
data.domain,
measurements,
potentials=warm_start,
iters=self.pgm_iters,
callback_fn=callback_fn,
)
assert isinstance(model, mbi.MarkovRandomField)
t3 = time.time()
logging.info('[AIM] Mirror descent took %.2fs', t3 - t2)
with common.timed(phase_times, 'estimation'):
callback_fn = mbi.callbacks.default(measurements, domain=data.domain)
measured_cliques = list(set(m.clique for m in measurements))
warm_start = model.potentials.expand(measured_cliques)
model = mbi.estimation.MirrorDescent(
marginal_oracle=self.marginal_oracle,
).estimate(
data.domain,
measurements,
potentials=warm_start,
iters=self.pgm_iters,
callback_fn=callback_fn,
)
assert isinstance(model, mbi.MarkovRandomField)

new_estimate = model.project(marginal_query).datavector()

Expand All @@ -289,8 +292,12 @@ def __call__(
sigma = accounting.zcdp_gaussian_sigma((1 - fraction) * rho_per_round)
logging.info('[AIM] Reducing sigma: %.1f', sigma)

diagnostics = common.clique_stats(model)
diagnostics.phase_times = phase_times
diagnostics.num_rounds = t
return common.DiscreteMechanismResult(
model=model,
synthetic_data=model.synthetic_data(),
measurements=measurements,
diagnostics=diagnostics,
)
87 changes: 44 additions & 43 deletions dpsynth/discrete_mechanisms/aim_gdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from collections.abc import Iterable, Mapping
import dataclasses
import time
import typing

from absl import logging
Expand Down Expand Up @@ -216,6 +215,7 @@ def __call__(
raise ValueError('Must call calibrate() before using the mechanism.')

logging.info('[AIM] Starting Mechanism.')
phase_times = {}

# Convert end-to-end GDP sigma to budget for internal allocation.
gdp_budget = 1.0 / self.gdp_sigma**2
Expand Down Expand Up @@ -288,28 +288,26 @@ def __call__(
########################################################################
# Select a marginal query worst approximated by the current model. #
########################################################################
t0 = time.time()
budget_remaining -= budget_per_round
measure_budget = budget_per_round * (1 - self.select_budget_fraction)
select_budget = budget_per_round * self.select_budget_fraction
measure_sigma = accounting.gdp_gaussian_sigma(measure_budget)
percent_used = (gdp_budget - budget_remaining) / gdp_budget
size_limit = self.max_model_size * percent_used
small_candidates = _filter_candidates(candidates, model, size_limit)

marginal_query = _worst_approximated(
rng,
candidates=small_candidates,
errors=errors,
answers=answers,
model=model,
select_budget=select_budget,
measure_sigma=measure_sigma,
max_new_evals=self.max_candidates_per_round,
)
with common.timed(phase_times, 'selection'):
budget_remaining -= budget_per_round
measure_budget = budget_per_round * (1 - self.select_budget_fraction)
select_budget = budget_per_round * self.select_budget_fraction
measure_sigma = accounting.gdp_gaussian_sigma(measure_budget)
percent_used = (gdp_budget - budget_remaining) / gdp_budget
size_limit = self.max_model_size * percent_used
small_candidates = _filter_candidates(candidates, model, size_limit)

marginal_query = _worst_approximated(
rng,
candidates=small_candidates,
errors=errors,
answers=answers,
model=model,
select_budget=select_budget,
measure_sigma=measure_sigma,
max_new_evals=self.max_candidates_per_round,
)

t1 = time.time()
logging.info('[AIM] Found worst candidate in %.2fs', t1 - t0)
logging.info(
'[AIM] Round %d, Budget used: %.4f, Measuring: %s, Candidates: %d',
t,
Expand All @@ -321,31 +319,30 @@ def __call__(
######################################################################
# Measure the marginal query privately using the Gaussian mechanism. #
######################################################################
measurement = common.measure_marginals_with_noise(
rng, data, [marginal_query], measure_sigma
)[0]
measurements.append(measurement)
old_estimate = model.project(marginal_query).datavector()
with common.timed(phase_times, 'measurement'):
measurement = common.measure_marginals_with_noise(
rng, data, [marginal_query], measure_sigma
)[0]
measurements.append(measurement)
old_estimate = model.project(marginal_query).datavector()

#####################################################
# Estimate the data distribution using Private-PGM. #
#####################################################
t2 = time.time()
callback_fn = mbi.callbacks.default(measurements, domain=domain)
measured_cliques = list(set(m.clique for m in measurements))
warm_start = model.potentials.expand(measured_cliques)
model = mbi.estimation.MirrorDescent(
marginal_oracle=self.marginal_oracle,
).estimate(
domain,
measurements,
potentials=warm_start,
iters=self.pgm_iters,
callback_fn=callback_fn,
)
model = typing.cast(mbi.MarkovRandomField, model)
t3 = time.time()
logging.info('[AIM] Mirror descent took %.2fs', t3 - t2)
with common.timed(phase_times, 'estimation'):
callback_fn = mbi.callbacks.default(measurements, domain=domain)
measured_cliques = list(set(m.clique for m in measurements))
warm_start = model.potentials.expand(measured_cliques)
model = mbi.estimation.MirrorDescent(
marginal_oracle=self.marginal_oracle,
).estimate(
domain,
measurements,
potentials=warm_start,
iters=self.pgm_iters,
callback_fn=callback_fn,
)
model = typing.cast(mbi.MarkovRandomField, model)

new_estimate = model.project(marginal_query).datavector()

Expand All @@ -364,8 +361,12 @@ def __call__(
'[AIM] Increasing budget per round: %.5f', budget_per_round
)

diagnostics = common.clique_stats(model)
diagnostics.phase_times = phase_times
diagnostics.num_rounds = t
return common.DiscreteMechanismResult(
model=model,
synthetic_data=model.synthetic_data(),
measurements=measurements,
diagnostics=diagnostics,
)
83 changes: 81 additions & 2 deletions dpsynth/discrete_mechanisms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,99 @@
"""Common utility functions for synthetic data mechanisms."""

from collections.abc import Iterable, Mapping, Sequence
import contextlib
import dataclasses
import functools
import itertools
from typing import Any, TypeAlias
import time
from typing import TypeAlias

from absl import logging
from dpsynth import transformations
import mbi
import mbi.junction_tree
import more_itertools
import numpy as np
import scipy
import scipy.special
import tqdm


@dataclasses.dataclass
class MechanismDiagnostics:
  """Timing and structural summary of a single discrete mechanism run.

  Every field has a default, so an instance can be created empty and filled
  in incrementally by the mechanism that produces it.

  Attributes:
    phase_times: Maps a phase name to its wall-clock duration in seconds.
    num_rounds: How many select-measure rounds were executed; only meaningful
      for iterative mechanisms.
    num_cliques: How many cliques the fitted model contains.
    max_clique_size: Size of the single largest clique.
    total_clique_size: Clique sizes summed over all cliques.
    max_jtree_node_size: Size of the single largest junction tree node.
    total_jtree_size: Node sizes summed over all junction tree nodes.
  """

  # Timing information; each instance gets its own fresh dict.
  phase_times: dict[str, float] = dataclasses.field(default_factory=dict)
  num_rounds: int = 0

  # Statistics over the cliques of the fitted model.
  num_cliques: int = 0
  max_clique_size: int = 0
  total_clique_size: int = 0

  # Statistics over the nodes of the junction tree.
  max_jtree_node_size: int = 0
  total_jtree_size: int = 0


@contextlib.contextmanager
def timed(phase_times: dict[str, float], name: str):
  """Context manager that logs and records wall-clock time for a phase.

  The elapsed time of the `with` body is added to any value already stored
  under `name`, so entering the same phase repeatedly (e.g. once per round)
  accumulates a running total. The duration of each individual entry is also
  logged at INFO level.

  Args:
    phase_times: Mutable mapping from phase name to accumulated seconds. It is
      updated in place.
    name: Key under which the elapsed time is accumulated; also used as the
      log prefix.

  Yields:
    None.
  """
  start = time.monotonic()
  try:
    yield
  finally:
    # Record the time even when the timed body raises, so a failing phase
    # still shows up in the diagnostics; the exception then propagates.
    elapsed = time.monotonic() - start
    phase_times[name] = phase_times.get(name, 0.0) + elapsed
    logging.info('[%s] %.2fs', name, elapsed)


def clique_stats(model: mbi.Model) -> MechanismDiagnostics:
  """Summarize the clique and junction tree structure of a fitted model.

  The summary is logged at INFO level and returned. Only the structural
  fields of the result are populated; timing and round fields keep their
  defaults.

  Args:
    model: The fitted graphical model. It must be a MarkovRandomField.

  Returns:
    A MechanismDiagnostics holding the clique and junction tree statistics.

  Raises:
    TypeError: If model is not a MarkovRandomField.
  """
  if not isinstance(model, mbi.MarkovRandomField):
    raise TypeError(f'Expected MarkovRandomField, got {type(model).__name__}')

  potentials = model.potentials
  domain, cliques = potentials.domain, potentials.cliques
  clique_sizes = [domain.size(clique) for clique in cliques]

  # The junction tree is built from the model's cliques; only the tree itself
  # (first element of the returned pair) is needed here.
  tree, _ = mbi.junction_tree.make_junction_tree(domain, cliques)
  nodes = list(tree.nodes)
  node_sizes = [domain.size(node) for node in nodes]

  # `default=0` keeps max() well-defined when there are no cliques or nodes.
  stats = MechanismDiagnostics(
      num_cliques=len(cliques),
      max_clique_size=max(clique_sizes, default=0),
      total_clique_size=sum(clique_sizes),
      max_jtree_node_size=max(node_sizes, default=0),
      total_jtree_size=sum(node_sizes),
  )

  logging.info(
      'Cliques: %d, max_size: %d, total_size: %d',
      stats.num_cliques,
      stats.max_clique_size,
      stats.total_clique_size,
  )
  logging.info(
      'Junction tree: %d nodes, max_size: %d, total_size: %d',
      len(nodes),
      stats.max_jtree_node_size,
      stats.total_jtree_size,
  )
  return stats


@dataclasses.dataclass
class DiscreteMechanismResult:
"""Result of running a discrete mechanism.
Expand All @@ -45,7 +124,7 @@ class DiscreteMechanismResult:
measurements: list[mbi.LinearMeasurement] = dataclasses.field(
default_factory=list
)
diagnostics: Any | None = None
diagnostics: MechanismDiagnostics | None = None


def exponential_mechanism(
Expand Down
Loading
Loading