diff --git a/dpsynth/data_generation_v3.py b/dpsynth/data_generation_v3.py index d4b0d81..b5a86da 100644 --- a/dpsynth/data_generation_v3.py +++ b/dpsynth/data_generation_v3.py @@ -355,10 +355,13 @@ def __call__( # Phase 2: Encode data to discrete domain. discrete_domains = {} + discrete_labels = {} discrete_data = {} one_way_measurements = [] for col, result in results.items(): - discrete_domains[col] = result.categorical_attribute.size + cat_attr = result.categorical_attribute + discrete_domains[col] = cat_attr.size + discrete_labels[col] = tuple(cat_attr.possible_values) if result.bin_edges is not None: discrete_data[col] = vtx.discretize( data[col].values, result.bin_edges, self.domains[col] @@ -370,7 +373,12 @@ def __call__( if result.measurement is not None: one_way_measurements.append(result.measurement) - discrete = mbi.Dataset(discrete_data, mbi.Domain.fromdict(discrete_domains)) + mbi_domain = mbi.Domain( + attributes=tuple(discrete_domains.keys()), + shape=tuple(discrete_domains.values()), + labels=tuple(discrete_labels[c] for c in discrete_domains), + ) + discrete = mbi.Dataset(discrete_data, mbi_domain) logging.info('[DPSynth]: Finished encoding data.') # Phase 3: Run the discrete mechanism.