diff --git a/dpsynth/discrete_mechanisms/swift.py b/dpsynth/discrete_mechanisms/swift.py index 1962dd6..fc644f0 100644 --- a/dpsynth/discrete_mechanisms/swift.py +++ b/dpsynth/discrete_mechanisms/swift.py @@ -29,6 +29,7 @@ import dataclasses import functools import math +import time import typing from absl import logging @@ -169,37 +170,71 @@ def __call__( errors, candidates, domain, self.max_clique_size, budget_remaining ) + ######################################################## + # Precompile MirrorDescent + synth while measuring. # + ######################################################## + closed_oracle = functools.partial( + mbi.marginal_oracles.message_passing_stable, jtree=jtree + ) + estimator = mbi.estimation.MirrorDescent(marginal_oracle=closed_oracle) + rows = int(mbi.estimation.minimum_variance_unbiased_total(measurements)) + + pgm_future, synth_future = None, None + try: + pgm_future = estimator.precompile( + domain, measurements, extra_cliques=list(selected) + ) + synth_future = mbi.extensions.precompile(domain, list(jtree.nodes), rows) + logging.info('[SWIFT] Started precompilation of MirrorDescent + synth.') + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning('[SWIFT] Precompile failed (non-fatal): %s', e) + ########################################## # Measure the selected marginal queries. # ########################################## + logging.info('[SWIFT] Starting measurements.') new_measurements, _ = _measure_selected_marginals( rng, answers, selected, budget_remaining ) measurements.extend(new_measurements) + logging.info('[SWIFT] Finished measurements.') ######################################################## # Estimate the model using all measurements # ######################################################## - - closed_oracle = functools.partial( - mbi.marginal_oracles.message_passing_stable, jtree=jtree - ) + if pgm_future is not None: + t0 = time.time() + try: + pgm_future.result() + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning('[SWIFT] PGM precompile failed (non-fatal): %s', e) + logging.info('[SWIFT] PGM precompile wait: %.2fs', time.time() - t0) callback_fn = mbi.callbacks.default(measurements, domain=domain) - model = mbi.estimation.MirrorDescent( - marginal_oracle=closed_oracle, - ).estimate( + final_model = estimator.estimate( domain, measurements, iters=self.pgm_iters, potentials=potentials, callback_fn=callback_fn, ) + assert isinstance(final_model, mbi.MarkovRandomField) logging.info('[SWIFT] Estimated final model.') + if synth_future is not None: + t0 = time.time() + try: + synth_future.result() + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning('[SWIFT] Synth precompile failed (non-fatal): %s', e) + logging.info('[SWIFT] Synth precompile wait: %.2fs', time.time() - t0) + + syn = mbi.extensions.synthetic_data(final_model, rows) + logging.info('[SWIFT] Generated %d synthetic records.', rows) + return common.DiscreteMechanismResult( - model=model, - synthetic_data=model.synthetic_data(), + model=final_model, + synthetic_data=syn, measurements=measurements, )