Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sdv/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,16 @@ def warn_load_deprecated():
)


def warn_set_constraints_deprecated():
"""Warn that the `set_constraints` method on synthesizer is deprecated."""
deprecation_msg = (
'Warning: The `set_constraints` method is deprecated. '
'Please use the `load_constraints` utility function to load constraints from a file '
'and add them to the synthesizer with the `add_constraints` method.'
)
warnings.warn(deprecation_msg, FutureWarning)


def _validate_correct_synthesizer_loading(synthesizer, cls):
"""Validate that the loaded synthesizer is of the correct type."""
synthesizer_name = synthesizer.__class__.__name__
Expand Down
28 changes: 28 additions & 0 deletions sdv/cag/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import importlib
import json
import re
import traceback
import warnings

import numpy as np
Expand Down Expand Up @@ -269,3 +271,29 @@ def load_constraint_from_dict(constraint_dict):
raise ValueError(f"Unknown `constraint_class` '{class_name}'.")

return constraint_class.load_constraint_from_dict(parameters=parameters)


def _load_constraints_from_file(filepath):
"""Load constraints from a file (JSON).

Args:
filepath (str):
The string path to the file containing the constraints to load.

Returns:
list[BaseConstraint]:
A list of constraint objects.
"""
with open(filepath, 'r') as f:
constraints_json = json.load(f)

constraint_list = []
for constraint_dict in constraints_json:
try:
constraint_list.append(load_constraint_from_dict(constraint_dict))
except Exception as e:
warnings.warn(
f'Could not load constraint ({constraint_dict}):\n'
f' {traceback.format_exception_only(type(e), e)[0]}'
)
return constraint_list
17 changes: 4 additions & 13 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
check_synthesizer_version,
generate_synthesizer_id,
warn_load_deprecated,
warn_set_constraints_deprecated,
)
from sdv.cag._errors import ConstraintNotMetError
from sdv.cag._utils import (
_convert_to_snake_case,
_get_invalid_rows,
_load_constraints_from_file,
_validate_constraints,
load_constraint_from_dict,
)
from sdv.cag.programmable_constraint import ProgrammableConstraint, ProgrammableConstraintHarness
from sdv.errors import (
Expand Down Expand Up @@ -328,18 +329,8 @@ def set_constraints(self, filepath):
'Cannot `set_constraints` since constraints have already been applied.'
)

with open(filepath, 'r') as f:
constraints_json = json.load(f)

constraint_list = []
for constraint_dict in constraints_json:
try:
constraint_list.append(load_constraint_from_dict(constraint_dict))
except Exception as e:
warnings.warn(
f'Could not load constraint ({constraint_dict}):\n'
f' {traceback.format_exception_only(type(e), e)[0]}'
)
warn_set_constraints_deprecated()
constraint_list = _load_constraints_from_file(filepath)

for constraint in constraint_list:
try:
Expand Down
17 changes: 4 additions & 13 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@
check_synthesizer_version,
generate_synthesizer_id,
warn_load_deprecated,
warn_set_constraints_deprecated,
)
from sdv.cag._errors import ConstraintNotMetError
from sdv.cag._utils import (
_convert_to_snake_case,
_get_invalid_rows,
_load_constraints_from_file,
_validate_constraints_single_table,
load_constraint_from_dict,
)
from sdv.cag.programmable_constraint import ProgrammableConstraint, ProgrammableConstraintHarness
from sdv.data_processing.data_processor import DataProcessor
Expand Down Expand Up @@ -552,18 +553,8 @@ def set_constraints(self, filepath):
'Cannot `set_constraints` since constraints have already been applied.'
)

with open(filepath, 'r') as f:
constraints_json = json.load(f)

constraint_list = []
for constraint_dict in constraints_json:
try:
constraint_list.append(load_constraint_from_dict(constraint_dict))
except Exception as e:
warnings.warn(
f'Could not load constraint ({constraint_dict}):\n'
f' {traceback.format_exception_only(type(e), e)[0]}'
)
warn_set_constraints_deprecated()
constraint_list = _load_constraints_from_file(filepath)

for constraint in constraint_list:
try:
Expand Down
14 changes: 12 additions & 2 deletions sdv/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
"""Utils module."""

from sdv.utils.utils import drop_unknown_references, get_random_sequence_subset, load_synthesizer
from sdv.utils.utils import (
drop_unknown_references,
get_random_sequence_subset,
load_synthesizer,
load_constraints,
)

__all__ = ('drop_unknown_references', 'get_random_sequence_subset', 'load_synthesizer')
__all__ = (
'drop_unknown_references',
'get_random_sequence_subset',
'load_synthesizer',
'load_constraints',
)
15 changes: 15 additions & 0 deletions sdv/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
check_synthesizer_version,
generate_synthesizer_id,
)
from sdv.cag._utils import _load_constraints_from_file
from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError
from sdv.logging import get_sdv_logger
from sdv.metadata.metadata import Metadata
Expand Down Expand Up @@ -198,3 +199,17 @@ def load_synthesizer(filepath):
})

return synthesizer


def load_constraints(filepath):
"""Load constraints from a file (JSON).

Args:
filepath (str):
The string path to the file containing the constraints to load.

Returns:
list[BaseConstraint]:
A list of constraint objects.
"""
return _load_constraints_from_file(filepath)
60 changes: 7 additions & 53 deletions tests/integration/cag/test_base.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,9 @@
import inspect

import pytest

import sdv.cag

CONSTRAINTS = [
(
'FixedCombinations',
{
'column_names': ['A', 'B', 'C'],
'table_name': 'table',
},
),
(
'FixedIncrements',
{
'column_name': 'column',
'increment_value': 5,
'table_name': 'table',
},
),
(
'Inequality',
{
'low_column_name': 'low',
'high_column_name': 'high',
'strict_boundaries': True,
'table_name': None,
},
),
(
'OneHotEncoding',
{
'column_names': ['col1', 'col2', 'col3', 'col4', 'col5'],
'table_name': 'table',
'learning_strategy': 'one_hot',
},
),
(
'Range',
{
'low_column_name': 'low',
'middle_column_name': 'middle',
'high_column_name': 'high',
'strict_boundaries': False,
'table_name': None,
},
),
]


def test_all_available_constraints_included_in_constraint_test_list():
def test_all_available_constraints_included_in_constraint_test_list(constraints_as_dicts):
"""Test that all available constraints are included in the test list."""
# Setup
skipped_cag_module_classes = ['ProgrammableConstraint', 'ConstraintList']
Expand All @@ -61,25 +14,26 @@ def test_all_available_constraints_included_in_constraint_test_list():
for constraint, cls in available_constraints
if constraint not in skipped_cag_module_classes
}
tested_constraints = {constraint[0] for constraint in CONSTRAINTS}
tested_constraints = {constraint[0] for constraint in constraints_as_dicts}

# Run and Assert
assert available_constraints == tested_constraints


@pytest.mark.parametrize(['constraint', 'constraint_params'], CONSTRAINTS)
def test_get_constraint_dict_and_load_constraint_from_dict(constraint, constraint_params):
def test_get_constraint_dict_and_load_constraint_from_dict(constraint_tuple):
"""Test ``get_constraint_dict`` and ``load_constraint_from_dict for all constraints."""
# Setup
constraint_class = getattr(sdv.cag, constraint)
constraint_class_name, constraint_params = constraint_tuple

constraint_class = getattr(sdv.cag, constraint_class_name)
constraint_instance = constraint_class(**constraint_params)

# Run
constraint_dict = constraint_instance.get_constraint_dict()
loaded_constraint = constraint_class.load_constraint_from_dict(constraint_dict['parameters'])

# Assert
assert constraint_dict == {'class_name': constraint, 'parameters': constraint_params}
assert constraint_dict == {'class_name': constraint_class_name, 'parameters': constraint_params}
for param, param_value in constraint_params.items():
instanced_param = getattr(
constraint_instance, param, getattr(constraint_instance, f'_{param}', None)
Expand Down
73 changes: 73 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from copy import deepcopy

import pytest

import sdv.cag

CONSTRAINTS = [
(
'FixedCombinations',
{
'column_names': ['A', 'B', 'C'],
'table_name': 'table',
},
),
(
'FixedIncrements',
{
'column_name': 'column',
'increment_value': 5,
'table_name': 'table',
},
),
(
'Inequality',
{
'low_column_name': 'low',
'high_column_name': 'high',
'strict_boundaries': True,
'table_name': None,
},
),
(
'OneHotEncoding',
{
'column_names': ['col1', 'col2', 'col3', 'col4', 'col5'],
'table_name': 'table',
'learning_strategy': 'one_hot',
},
),
(
'Range',
{
'low_column_name': 'low',
'middle_column_name': 'middle',
'high_column_name': 'high',
'strict_boundaries': False,
'table_name': None,
},
),
]


@pytest.fixture()
def constraints_as_dicts():
"""Return a list of dictionaries (each dictionary is a constraint definition)."""
return CONSTRAINTS


@pytest.fixture(params=CONSTRAINTS, ids=[c_name for c_name, _ in CONSTRAINTS])
def constraint_tuple(request):
"""Return a constraint class name and constraint parameters as a tuple."""
constraint_class_name, constraint_params = request.param

return constraint_class_name, deepcopy(constraint_params)


@pytest.fixture()
def constraint_object(constraint_tuple):
"""Return a constraint object."""
constraint_class_name, constraint_params = constraint_tuple
constraint_class = getattr(sdv.cag, constraint_class_name)

return constraint_class(**constraint_params)
Loading
Loading