diff --git a/sdv/_utils.py b/sdv/_utils.py index 4928f5ff4..f6eb49b93 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -507,6 +507,16 @@ def warn_load_deprecated(): ) +def warn_set_constraints_deprecated(): + """Warn that the `set_constraints` method on synthesizer is deprecated.""" + deprecation_msg = ( + 'Warning: The `set_constraints` method is deprecated. ' + 'Please use the `load_constraints` utility function to load constraints from a file ' + 'and add them to the synthesizer with the `add_constraints` method.' + ) + warnings.warn(deprecation_msg, FutureWarning) + + def _validate_correct_synthesizer_loading(synthesizer, cls): """Validate that the loaded synthesizer is of the correct type.""" synthesizer_name = synthesizer.__class__.__name__ diff --git a/sdv/cag/_utils.py b/sdv/cag/_utils.py index b3bdbeb09..46ac769cb 100644 --- a/sdv/cag/_utils.py +++ b/sdv/cag/_utils.py @@ -1,5 +1,7 @@ import importlib +import json import re +import traceback import warnings import numpy as np @@ -269,3 +271,29 @@ def load_constraint_from_dict(constraint_dict): raise ValueError(f"Unknown `constraint_class` '{class_name}'.") return constraint_class.load_constraint_from_dict(parameters=parameters) + + +def _load_constraints_from_file(filepath): + """Load constraints from a file (JSON). + + Args: + filepath (str): + The string path to the file containing the constraints to load. + + Returns: + list[BaseConstraint]: + A list of constraint objects. + """ + with open(filepath, 'r') as f: + constraints_json = json.load(f) + + constraint_list = [] + for constraint_dict in constraints_json: + try: + constraint_list.append(load_constraint_from_dict(constraint_dict)) + except Exception as e: + warnings.warn( + f'Could not load constraint ({constraint_dict}):\n' + f' {traceback.format_exception_only(type(e), e)[0]}' + ) + return constraint_list diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 2441bc23a..f001b935e 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -22,13 +22,14 @@ check_synthesizer_version, generate_synthesizer_id, warn_load_deprecated, + warn_set_constraints_deprecated, ) from sdv.cag._errors import ConstraintNotMetError from sdv.cag._utils import ( _convert_to_snake_case, _get_invalid_rows, + _load_constraints_from_file, _validate_constraints, - load_constraint_from_dict, ) from sdv.cag.programmable_constraint import ProgrammableConstraint, ProgrammableConstraintHarness from sdv.errors import ( @@ -328,18 +329,8 @@ def set_constraints(self, filepath): 'Cannot `set_constraints` since constraints have already been applied.' ) - with open(filepath, 'r') as f: - constraints_json = json.load(f) - - constraint_list = [] - for constraint_dict in constraints_json: - try: - constraint_list.append(load_constraint_from_dict(constraint_dict)) - except Exception as e: - warnings.warn( - f'Could not load constraint ({constraint_dict}):\n' - f' {traceback.format_exception_only(type(e), e)[0]}' - ) + warn_set_constraints_deprecated() + constraint_list = _load_constraints_from_file(filepath) for constraint in constraint_list: try: diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 407f5f61d..f7c0254f7 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -32,13 +32,14 @@ check_synthesizer_version, generate_synthesizer_id, warn_load_deprecated, + warn_set_constraints_deprecated, ) from sdv.cag._errors import ConstraintNotMetError from sdv.cag._utils import ( _convert_to_snake_case, _get_invalid_rows, + _load_constraints_from_file, _validate_constraints_single_table, - load_constraint_from_dict, ) from sdv.cag.programmable_constraint import ProgrammableConstraint, ProgrammableConstraintHarness from sdv.data_processing.data_processor import DataProcessor @@ -552,18 +553,8 @@ def set_constraints(self, filepath): 'Cannot `set_constraints` since constraints have already been applied.' ) - with open(filepath, 'r') as f: - constraints_json = json.load(f) - - constraint_list = [] - for constraint_dict in constraints_json: - try: - constraint_list.append(load_constraint_from_dict(constraint_dict)) - except Exception as e: - warnings.warn( - f'Could not load constraint ({constraint_dict}):\n' - f' {traceback.format_exception_only(type(e), e)[0]}' - ) + warn_set_constraints_deprecated() + constraint_list = _load_constraints_from_file(filepath) for constraint in constraint_list: try: diff --git a/sdv/utils/__init__.py b/sdv/utils/__init__.py index 15fc00e91..2bcf133da 100644 --- a/sdv/utils/__init__.py +++ b/sdv/utils/__init__.py @@ -1,5 +1,15 @@ """Utils module.""" -from sdv.utils.utils import drop_unknown_references, get_random_sequence_subset, load_synthesizer +from sdv.utils.utils import ( + drop_unknown_references, + get_random_sequence_subset, + load_synthesizer, + load_constraints, +) -__all__ = ('drop_unknown_references', 'get_random_sequence_subset', 'load_synthesizer') +__all__ = ( + 'drop_unknown_references', + 'get_random_sequence_subset', + 'load_synthesizer', + 'load_constraints', +) diff --git a/sdv/utils/utils.py b/sdv/utils/utils.py index 56c3ba6b2..8ffb2a239 100644 --- a/sdv/utils/utils.py +++ b/sdv/utils/utils.py @@ -15,6 +15,7 @@ check_synthesizer_version, generate_synthesizer_id, ) +from sdv.cag._utils import _load_constraints_from_file from sdv.errors import InvalidDataError, SamplingError, SynthesizerInputError from sdv.logging import get_sdv_logger from sdv.metadata.metadata import Metadata @@ -198,3 +199,17 @@ def load_synthesizer(filepath): }) return synthesizer + + +def load_constraints(filepath): + """Load constraints from a file (JSON). + + Args: + filepath (str): + The string path to the file containing the constraints to load. + + Returns: + list[BaseConstraint]: + A list of constraint objects. + """ + return _load_constraints_from_file(filepath) diff --git a/tests/integration/cag/test_base.py b/tests/integration/cag/test_base.py index 9d4b8db9b..4b9923b71 100644 --- a/tests/integration/cag/test_base.py +++ b/tests/integration/cag/test_base.py @@ -1,56 +1,9 @@ import inspect -import pytest - import sdv.cag -CONSTRAINTS = [ - ( - 'FixedCombinations', - { - 'column_names': ['A', 'B', 'C'], - 'table_name': 'table', - }, - ), - ( - 'FixedIncrements', - { - 'column_name': 'column', - 'increment_value': 5, - 'table_name': 'table', - }, - ), - ( - 'Inequality', - { - 'low_column_name': 'low', - 'high_column_name': 'high', - 'strict_boundaries': True, - 'table_name': None, - }, - ), - ( - 'OneHotEncoding', - { - 'column_names': ['col1', 'col2', 'col3', 'col4', 'col5'], - 'table_name': 'table', - 'learning_strategy': 'one_hot', - }, - ), - ( - 'Range', - { - 'low_column_name': 'low', - 'middle_column_name': 'middle', - 'high_column_name': 'high', - 'strict_boundaries': False, - 'table_name': None, - }, - ), -] - -def test_all_available_constraints_included_in_constraint_test_list(): +def test_all_available_constraints_included_in_constraint_test_list(constraints_as_dicts): """Test that all available constraints are included in the test list.""" # Setup skipped_cag_module_classes = ['ProgrammableConstraint', 'ConstraintList'] @@ -61,17 +14,18 @@ def test_all_available_constraints_included_in_constraint_test_list(): for constraint, cls in available_constraints if constraint not in skipped_cag_module_classes } - tested_constraints = {constraint[0] for constraint in CONSTRAINTS} + tested_constraints = {constraint[0] for constraint in constraints_as_dicts} # Run and Assert assert available_constraints == tested_constraints -@pytest.mark.parametrize(['constraint', 'constraint_params'], CONSTRAINTS) -def test_get_constraint_dict_and_load_constraint_from_dict(constraint, constraint_params): +def test_get_constraint_dict_and_load_constraint_from_dict(constraint_tuple): """Test ``get_constraint_dict`` and ``load_constraint_from_dict for all constraints.""" # Setup - constraint_class = getattr(sdv.cag, constraint) + constraint_class_name, constraint_params = constraint_tuple + + constraint_class = getattr(sdv.cag, constraint_class_name) constraint_instance = constraint_class(**constraint_params) # Run @@ -79,7 +33,7 @@ def test_get_constraint_dict_and_load_constraint_from_dict(constraint, constrain loaded_constraint = constraint_class.load_constraint_from_dict(constraint_dict['parameters']) # Assert - assert constraint_dict == {'class_name': constraint, 'parameters': constraint_params} + assert constraint_dict == {'class_name': constraint_class_name, 'parameters': constraint_params} for param, param_value in constraint_params.items(): instanced_param = getattr( constraint_instance, param, getattr(constraint_instance, f'_{param}', None) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 000000000..ac9a43d2e --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,73 @@ +from copy import deepcopy + +import pytest + +import sdv.cag + +CONSTRAINTS = [ + ( + 'FixedCombinations', + { + 'column_names': ['A', 'B', 'C'], + 'table_name': 'table', + }, + ), + ( + 'FixedIncrements', + { + 'column_name': 'column', + 'increment_value': 5, + 'table_name': 'table', + }, + ), + ( + 'Inequality', + { + 'low_column_name': 'low', + 'high_column_name': 'high', + 'strict_boundaries': True, + 'table_name': None, + }, + ), + ( + 'OneHotEncoding', + { + 'column_names': ['col1', 'col2', 'col3', 'col4', 'col5'], + 'table_name': 'table', + 'learning_strategy': 'one_hot', + }, + ), + ( + 'Range', + { + 'low_column_name': 'low', + 'middle_column_name': 'middle', + 'high_column_name': 'high', + 'strict_boundaries': False, + 'table_name': None, + }, + ), +] + + +@pytest.fixture() +def constraints_as_dicts(): + """Return a list of dictionaries (each dictionary is a constraint definition).""" + return CONSTRAINTS + + +@pytest.fixture(params=CONSTRAINTS, ids=[c_name for c_name, _ in CONSTRAINTS]) +def constraint_tuple(request): + """Return a constraint class name and constraint parameters as a tuple.""" + constraint_class_name, constraint_params = request.param + + return constraint_class_name, deepcopy(constraint_params) + + +@pytest.fixture() +def constraint_object(constraint_tuple): + """Return a constraint object.""" + constraint_class_name, constraint_params = constraint_tuple + constraint_class = getattr(sdv.cag, constraint_class_name) + + return constraint_class(**constraint_params) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 81956b227..b22ca9352 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -25,6 +25,7 @@ from sdv.metadata import MultiTableMetadata from sdv.metadata.metadata import Metadata from sdv.multi_table import HMASynthesizer +from sdv.utils import load_constraints from tests.integration.single_table.custom_constraints import MyConstraint from tests.utils import catch_sdv_logs @@ -2188,10 +2189,16 @@ def test_get_constraints_and_load_constraints(tmp_path): synthesizer = HMASynthesizer(metadata) synthesizer.add_constraints([inequality_constraint, fixed_combinations_constraint]) new_synthesizer = HMASynthesizer(metadata) + expected_warning = re.escape( + 'Warning: The `set_constraints` method is deprecated. ' + 'Please use the `load_constraints` utility function to load constraints from a file ' + 'and add them to the synthesizer with the `add_constraints` method.' + ) # Run synthesizer.get_constraints(filepath=filepath) - new_synthesizer.set_constraints(filepath=filepath) + with pytest.warns(FutureWarning, match=expected_warning): + new_synthesizer.set_constraints(filepath=filepath) # Assert assert [str(constraint) for constraint in synthesizer.get_constraints()] == [ @@ -2221,6 +2228,51 @@ def test_get_constraints_and_load_constraints(tmp_path): ] +def test_load_constraints_and_add_constraints(tmp_path): + """Test loading constraints from a JSON file and adding them to a synthesizer.""" + # Setup + filepath = tmp_path / 'constraints.json' + metadata = Metadata.load_from_dict({ + 'tables': { + 'guests': { + 'columns': { + 'checkin_date': {'sdtype': 'numerical'}, + 'checkout_date': {'sdtype': 'numerical'}, + }, + }, + 'hotels': { + 'columns': { + 'city': {'sdtype': 'categorical'}, + 'state': {'sdtype': 'categorical'}, + }, + }, + }, + }) + inequality_constraint = Inequality( + table_name='guests', + low_column_name='checkin_date', + high_column_name='checkout_date', + ) + fixed_combinations_constraint = FixedCombinations( + table_name='hotels', column_names=['city', 'state'] + ) + constraints = [inequality_constraint, fixed_combinations_constraint] + constraints = [c.get_constraint_dict() for c in constraints] + synthesizer = HMASynthesizer(metadata) + with open(filepath, 'w') as f: + json.dump(constraints, f) + + # Run + loaded_constraints = load_constraints(filepath) + synthesizer.add_constraints(loaded_constraints) + + # Assert + loaded_constraint_dicts = [c.get_constraint_dict() for c in loaded_constraints] + synthesizer_constraint_dicts = [c.get_constraint_dict() for c in synthesizer.get_constraints()] + assert loaded_constraint_dicts == constraints + assert synthesizer_constraint_dicts == constraints + + def test_save_and_load_with_downgraded_version(tmp_path): """Test that synthesizers are raising errors if loaded on a downgraded version.""" # Setup diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index dd00ddf71..40e34fa57 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -24,6 +24,7 @@ TVAESynthesizer, ) from sdv.single_table.base import BaseSingleTableSynthesizer +from sdv.utils import load_constraints METADATA = Metadata.load_from_dict({ 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', @@ -514,10 +515,16 @@ def test_get_constraints_and_load_constraints(tmp_path, synthesizer_class): synthesizer = synthesizer_class(metadata) synthesizer.add_constraints([inequality_constraint, fixed_combinations_constraint]) new_synthesizer = synthesizer_class(metadata) + expected_warning = re.escape( + 'Warning: The `set_constraints` method is deprecated. ' + 'Please use the `load_constraints` utility function to load constraints from a file ' + 'and add them to the synthesizer with the `add_constraints` method.' + ) # Run synthesizer.get_constraints(filepath=filepath) - new_synthesizer.set_constraints(filepath=filepath) + with pytest.warns(FutureWarning, match=expected_warning): + new_synthesizer.set_constraints(filepath=filepath) # Assert assert [str(constraint) for constraint in synthesizer.get_constraints()] == [ @@ -547,6 +554,45 @@ def test_get_constraints_and_load_constraints(tmp_path, synthesizer_class): ] +@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES) +def test_load_constraints_and_add_constraints(tmp_path, synthesizer_class): + """Test loading constraints from a JSON file and adding them to a synthesizer.""" + # Setup + filepath = tmp_path / 'constraints.json' + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'low': {'sdtype': 'numerical'}, + 'high': {'sdtype': 'numerical'}, + 'has_rewards': {'sdtype': 'categorical'}, + 'room_type': {'sdtype': 'categorical'}, + }, + }, + }, + }) + inequality_constraint = Inequality( + low_column_name='low', + high_column_name='high', + ) + fixed_combinations_constraint = FixedCombinations(column_names=['has_rewards', 'room_type']) + constraints = [inequality_constraint, fixed_combinations_constraint] + constraints = [c.get_constraint_dict() for c in constraints] + synthesizer = synthesizer_class(metadata) + with open(filepath, 'w') as f: + json.dump(constraints, f) + + # Run + loaded_constraints = load_constraints(filepath) + synthesizer.add_constraints(loaded_constraints) + + # Assert + loaded_constraint_dicts = [c.get_constraint_dict() for c in loaded_constraints] + synthesizer_constraint_dicts = [c.get_constraint_dict() for c in synthesizer.get_constraints()] + assert loaded_constraint_dicts == constraints + assert synthesizer_constraint_dicts == constraints + + def test_save_and_load(tmp_path): """Test that synthesizers can be saved and loaded properly.""" # Setup diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py index c42e635b2..e438d0634 100644 --- a/tests/integration/utils/test_utils.py +++ b/tests/integration/utils/test_utils.py @@ -1,3 +1,4 @@ +import json import re from copy import deepcopy @@ -9,7 +10,12 @@ from sdv.errors import InvalidDataError from sdv.metadata.metadata import Metadata from sdv.single_table import GaussianCopulaSynthesizer -from sdv.utils import drop_unknown_references, get_random_sequence_subset, load_synthesizer +from sdv.utils import ( + drop_unknown_references, + get_random_sequence_subset, + load_constraints, + load_synthesizer, +) @pytest.fixture @@ -214,3 +220,19 @@ def test_load_synthesizer(tmp_path): # Assert assert isinstance(loaded_synthesizer, GaussianCopulaSynthesizer) assert set(synthetic_data.columns) == set(data.columns) + + +def test_load_constraints(tmp_path, constraint_object): + """Test the `load_constraints` method.""" + # Setup + constraints = [constraint_object.get_constraint_dict()] + filepath = tmp_path / 'constraints.json' + with open(filepath, 'w') as f: + json.dump(constraints, f) + + # Run + loaded_constraints = load_constraints(filepath) + + # Assert + assert len(loaded_constraints) == 1 + assert loaded_constraints[0].get_constraint_dict() == constraints[0] diff --git a/tests/unit/cag/test__utils.py b/tests/unit/cag/test__utils.py index e8ebfc040..7aefe64de 100644 --- a/tests/unit/cag/test__utils.py +++ b/tests/unit/cag/test__utils.py @@ -11,6 +11,7 @@ _convert_to_snake_case, _filter_old_style_constraints, _is_list_of_type, + _load_constraints_from_file, _remove_columns_from_metadata, _validate_columns_not_primary_key, _validate_constraints, @@ -424,3 +425,53 @@ def test_load_constraints_from_dict(importlib_mock): assert sandbox_constraint == ( sandbox_mock.mock_sandbox_constraint.load_constraint_from_dict.return_value ) + + +@patch('sdv.cag._utils.open') +@patch('sdv.cag._utils.json') +@patch('sdv.cag._utils.load_constraint_from_dict') +def test__load_constraints_from_file( + mock_load_constraint_from_dict, + mock_json, + mock_open, +): + """Test loading a list of constraints from a JSON file.""" + # Setup + constraint_dict1 = {'class_name': 'ConstraintClass1', 'parameters': {}} + invalid_constraint_dict = {'class_name': 'UnknownConstraint', 'parameters': {}} + constraint_dict2 = {'class_name': 'ConstraintClass2', 'parameters': {}} + mock_json.load.return_value = [constraint_dict1, invalid_constraint_dict, constraint_dict2] + + mock_constraint1 = Mock() + mock_constraint2 = Mock() + mock_constraints = { + 'ConstraintClass1': mock_constraint1, + 'ConstraintClass2': mock_constraint2, + } + + def load_constraint_from_dict_mock(constraint_dict): + if constraint_dict['class_name'] == 'UnknownConstraint': + raise ValueError("Unknown `constraint_class` 'UnknownConstraint'.") + + return mock_constraints[constraint_dict['class_name']] + + mock_load_constraint_from_dict.side_effect = load_constraint_from_dict_mock + filepath = 'path/to/constraints.json' + expected_warning = re.escape( + "Could not load constraint ({'class_name': 'UnknownConstraint', 'parameters': {}}):\n" + " ValueError: Unknown `constraint_class` 'UnknownConstraint'." + ) + + # Run + with pytest.warns(UserWarning, match=expected_warning): + result = _load_constraints_from_file(filepath) + + # Assert + assert result == [mock_constraint1, mock_constraint2] + mock_open.assert_called_once_with(filepath, 'r') + mock_json.load.assert_called_once() + mock_load_constraint_from_dict.assert_has_calls([ + call(constraint_dict1), + call(invalid_constraint_dict), + call(constraint_dict2), + ]) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 6104ebe77..e2f8e68ff 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -1795,10 +1795,17 @@ def test_get_constraints_from_file(self, mock_Path, mock_json, mock_open): indent=4, ) - @patch('sdv.multi_table.base.open') - @patch('sdv.multi_table.base.json') - @patch('sdv.multi_table.base.load_constraint_from_dict') - def test_set_constraints(self, mock_load_constraint_from_dict, mock_json, mock_open): + @patch('sdv.multi_table.base.warn_set_constraints_deprecated') + @patch('sdv.cag._utils.open') + @patch('sdv.cag._utils.json') + @patch('sdv.cag._utils.load_constraint_from_dict') + def test_set_constraints( + self, + mock_load_constraint_from_dict, + mock_json, + mock_open, + mock_warn_set_constraints_deprecated, + ): """Test setting constraints from file.""" # Setup mock_json.load.return_value = [ @@ -1868,12 +1875,34 @@ def add_constraints_mock(mock_constraint): call([mock_invalid_constraint]), call([mock_constraint2]), ]) + mock_warn_set_constraints_deprecated.assert_called_once() + + @patch('sdv.multi_table.base._load_constraints_from_file') + def test_set_constraints_warns_deprecated(self, mock_load_constraints_from_file): + """Test ``set_constraints`` emits a deprecation warning.""" + # Setup + filepath = 'path/to/constraints.json' + instance = Mock() + instance.get_constraints.return_value = [] + mock_load_constraints_from_file.return_value = [] + expected_message = re.escape( + 'Warning: The `set_constraints` method is deprecated. ' + 'Please use the `load_constraints` utility function to load constraints from a file ' + 'and add them to the synthesizer with the `add_constraints` method.' + ) + + # Run + with pytest.warns(FutureWarning, match=expected_message): + BaseMultiTableSynthesizer.set_constraints(instance, filepath) + + # Assert + mock_load_constraints_from_file.assert_called_once_with(filepath) - def set_constraints_errors_with_existing_constraints(self): + def test__set_constraints_errors_with_existing_constraints(self): """Test ``set_constraints`` errors if constraints already applied.""" # Setup instance = Mock() - instance._get_all_constraints_list = [Mock()] + instance.get_constraints.return_value = [Mock()] # Run and Assert expected_msg = re.escape( diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 120913cb4..ce7e5de94 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -899,10 +899,17 @@ def test_get_constraints_to_file(self, mock_Path, mock_json, mock_open): indent=4, ) - @patch('sdv.single_table.base.open') - @patch('sdv.single_table.base.json') - @patch('sdv.single_table.base.load_constraint_from_dict') - def test_set_constraints(self, mock_load_constraint_from_dict, mock_json, mock_open): + @patch('sdv.single_table.base.warn_set_constraints_deprecated') + @patch('sdv.cag._utils.open') + @patch('sdv.cag._utils.json') + @patch('sdv.cag._utils.load_constraint_from_dict') + def test_set_constraints( + self, + mock_load_constraint_from_dict, + mock_json, + mock_open, + mock_warn_set_constraints_deprecated, + ): """Test setting constraints from file.""" # Setup mock_json.load.return_value = [ @@ -972,12 +979,34 @@ def add_constraints_mock(mock_constraint): call([mock_invalid_constraint]), call([mock_constraint2]), ]) + mock_warn_set_constraints_deprecated.assert_called_once() + + @patch('sdv.single_table.base._load_constraints_from_file') + def test_set_constraints_warns_deprecated(self, mock_load_constraints_from_file): + """Test ``set_constraints`` emits a deprecation warning.""" + # Setup + filepath = 'path/to/constraints.json' + instance = Mock() + instance.get_constraints.return_value = [] + mock_load_constraints_from_file.return_value = [] + expected_message = re.escape( + 'Warning: The `set_constraints` method is deprecated. ' + 'Please use the `load_constraints` utility function to load constraints from a file ' + 'and add them to the synthesizer with the `add_constraints` method.' + ) + + # Run + with pytest.warns(FutureWarning, match=expected_message): + BaseSynthesizer.set_constraints(instance, filepath) + + # Assert + mock_load_constraints_from_file.assert_called_once_with(filepath) - def set_constraints_errors_with_existing_constraints(self): + def test__set_constraints_errors_with_existing_constraints(self): """Test ``set_constraints`` errors if constraints already applied.""" # Setup instance = Mock() - instance._get_all_constraints_list = [Mock()] + instance.get_constraints.return_value = [Mock()] # Run and Assert expected_msg = re.escape( diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 9949e4d6b..c48a2f289 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -10,7 +10,12 @@ from sdv.errors import InvalidDataError from sdv.metadata import SingleTableMetadata from sdv.multi_table.base import BaseMultiTableSynthesizer -from sdv.utils.utils import drop_unknown_references, get_random_sequence_subset, load_synthesizer +from sdv.utils.utils import ( + drop_unknown_references, + get_random_sequence_subset, + load_constraints, + load_synthesizer, +) from tests.utils import catch_sdv_logs @@ -553,3 +558,18 @@ def test_load_synthesizer( 'SYNTHESIZER CLASS NAME': 'BaseMultiTableSynthesizer', 'SYNTHESIZER ID': 'HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', }) + + +@patch('sdv.utils.utils._load_constraints_from_file') +def test_load_constraints(mock_load_constraints_from_file): + """Test ``load_constraints`` function.""" + # Setup + mock_constraint = Mock() + mock_load_constraints_from_file.return_value = [mock_constraint] + + # Run + result = load_constraints('constraints.json') + + # Assert + assert result == [mock_constraint] + mock_load_constraints_from_file.assert_called_once_with('constraints.json')