diff --git a/tests/test_model_utils/test_binary_ks_curve.py b/tests/test_model_utils/test_binary_ks_curve.py new file mode 100644 index 0000000..dc10142 --- /dev/null +++ b/tests/test_model_utils/test_binary_ks_curve.py @@ -0,0 +1,104 @@ +import pytest +import numpy as np +from dython.model_utils import _binary_ks_curve + + +def test_binary_ks_curve_basic(): + """Test basic binary_ks_curve functionality""" + y_true = np.array([0, 1, 0, 1, 1, 0]) + y_probas = np.array([0.1, 0.9, 0.3, 0.8, 0.7, 0.2]) + + thresholds, pct1, pct2, ks_stat, max_dist, classes = _binary_ks_curve(y_true, y_probas) + + assert len(thresholds) > 0 + assert len(pct1) == len(thresholds) + assert len(pct2) == len(thresholds) + assert 0 <= ks_stat <= 1 + assert len(classes) == 2 + + +def test_binary_ks_curve_multiclass_error(): + """Test binary_ks_curve with more than 2 classes""" + y_true = np.array([0, 1, 2, 0, 1, 2]) + y_probas = np.array([0.1, 0.5, 0.9, 0.2, 0.6, 0.8]) + + with pytest.raises(ValueError): + _binary_ks_curve(y_true, y_probas) + + +def test_binary_ks_curve_thresholds_start_with_zero(): + """Test that thresholds start with 0""" + y_true = np.array([0, 1, 0, 1]) + y_probas = np.array([0.2, 0.8, 0.3, 0.9]) + + thresholds, _, _, _, _, _ = _binary_ks_curve(y_true, y_probas) + + assert thresholds[0] == 0.0 + + +def test_binary_ks_curve_thresholds_end_with_one(): + """Test that thresholds end with 1""" + y_true = np.array([0, 1, 0, 1]) + y_probas = np.array([0.2, 0.8, 0.3, 0.7]) + + thresholds, _, _, _, _, _ = _binary_ks_curve(y_true, y_probas) + + assert thresholds[-1] == 1.0 + + +def test_binary_ks_curve_with_edge_probabilities(): + """Test binary_ks_curve with probabilities at 0 and 1""" + y_true = np.array([0, 1, 0, 1]) + y_probas = np.array([0.0, 1.0, 0.1, 0.9]) + + thresholds, pct1, pct2, ks_stat, max_dist, _ = _binary_ks_curve(y_true, y_probas) + + assert len(thresholds) > 0 + assert thresholds[0] == 0.0 + assert thresholds[-1] == 1.0 + + +def test_binary_ks_curve_data1_exhausted_first(): + """Test binary_ks_curve when data1 is exhausted before data2""" + y_true = np.array([0, 0, 1, 1, 1]) + y_probas = np.array([0.1, 0.2, 0.6, 0.7, 0.8]) + + thresholds, pct1, pct2, ks_stat, max_dist, _ = _binary_ks_curve(y_true, y_probas) + + assert len(thresholds) > 0 + assert pct1[-1] == 1.0 + assert pct2[-1] == 1.0 + + +def test_binary_ks_curve_data2_exhausted_first(): + """Test binary_ks_curve when data2 is exhausted before data1""" + y_true = np.array([0, 0, 0, 1, 1]) + y_probas = np.array([0.6, 0.7, 0.8, 0.1, 0.2]) + + thresholds, pct1, pct2, ks_stat, max_dist, _ = _binary_ks_curve(y_true, y_probas) + + assert len(thresholds) > 0 + assert pct1[-1] == 1.0 + assert pct2[-1] == 1.0 + + +def test_binary_ks_curve_equal_values(): + """Test binary_ks_curve with equal probability values""" + y_true = np.array([0, 0, 1, 1]) + y_probas = np.array([0.5, 0.5, 0.5, 0.5]) + + thresholds, pct1, pct2, ks_stat, max_dist, _ = _binary_ks_curve(y_true, y_probas) + + assert len(thresholds) > 0 + + +def test_binary_ks_curve_interleaved_values(): + """Test binary_ks_curve with interleaved probability values""" + y_true = np.array([0, 1, 0, 1, 0, 1]) + y_probas = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]) + + thresholds, pct1, pct2, ks_stat, max_dist, _ = _binary_ks_curve(y_true, y_probas) + + assert len(thresholds) > 0 + assert 0 <= ks_stat <= 1 + diff --git a/tests/test_model_utils/test_ks_abc_advanced.py b/tests/test_model_utils/test_ks_abc_advanced.py new file mode 100644 index 0000000..1b1bb6a --- /dev/null +++ b/tests/test_model_utils/test_ks_abc_advanced.py @@ -0,0 +1,121 @@ +import pytest +import numpy as np +import matplotlib.pyplot as plt +from dython.model_utils import ks_abc + + +def test_ks_abc_basic(): + """Test basic ks_abc functionality""" + y_true = [0, 1, 0, 1, 1, 0] + y_pred = [0.1, 0.9, 0.3, 0.8, 0.7, 0.2] + + result = ks_abc(y_true, y_pred, plot=False) + assert 'abc' in result + assert 'ks_stat' in result + assert 'eopt' in result + assert 'ax' in result + + +def test_ks_abc_mismatched_shapes(): + """Test ks_abc with mismatched shapes""" + y_true = [0, 1, 0] + y_pred = [0.1, 0.9, 0.3, 0.8] + + with pytest.raises(ValueError): + ks_abc(y_true, y_pred, plot=False) + + +def test_ks_abc_2d_binary(): + """Test ks_abc with 2D binary array""" + y_true = np.array([[1, 0], [0, 1], [1, 0], [0, 1]]) + y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]) + + result = ks_abc(y_true, y_pred, plot=False) + assert 'abc' in result + + +def test_ks_abc_single_column(): + """Test ks_abc with single column""" + y_true = np.array([[1], [0], [1], [0]]) + y_pred = np.array([[0.9], [0.2], [0.7], [0.3]]) + + result = ks_abc(y_true, y_pred, plot=False) + assert 'abc' in result + + +def test_ks_abc_multiclass_error(): + """Test ks_abc with multiclass (should raise error)""" + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + y_pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7]]) + + with pytest.raises(ValueError): + ks_abc(y_true, y_pred, plot=False) + + +def test_ks_abc_with_ax(): + """Test ks_abc with provided ax""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + fig, ax = plt.subplots() + result = ks_abc(y_true, y_pred, ax=ax, plot=False) + assert result['ax'] == ax + plt.close(fig) + + +def test_ks_abc_with_custom_params(): + """Test ks_abc with custom visualization parameters""" + y_true = [0, 1, 0, 1, 1, 0] + y_pred = [0.1, 0.9, 0.3, 0.8, 0.7, 0.2] + + result = ks_abc( + y_true, y_pred, + colors=('red', 'blue'), + title='Custom KS Title', + xlim=(0, 0.5), + ylim=(0, 0.5), + fmt='.3f', + lw=3, + legend='upper left', + plot=False + ) + assert 'abc' in result + + +def test_ks_abc_no_legend(): + """Test ks_abc without legend""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + result = ks_abc(y_true, y_pred, legend=None, plot=False) + assert 'ax' in result + + +def test_ks_abc_with_filename(tmp_path): + """Test ks_abc with filename""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + filename = tmp_path / "ks_plot.png" + result = ks_abc(y_true, y_pred, filename=str(filename), plot=False) + assert filename.exists() + plt.close('all') + + +def test_ks_abc_abc_value_range(): + """Test that ABC value is in valid range""" + y_true = [0, 1, 0, 1, 1, 0, 0, 1] + y_pred = [0.1, 0.9, 0.2, 0.8, 0.7, 0.3, 0.15, 0.85] + + result = ks_abc(y_true, y_pred, plot=False) + assert 0 <= result['abc'] <= 1 + + +def test_ks_abc_ks_stat_value_range(): + """Test that KS statistic is in valid range""" + y_true = [0, 1, 0, 1, 1, 0, 0, 1] + y_pred = [0.1, 0.9, 0.2, 0.8, 0.7, 0.3, 0.15, 0.85] + + result = ks_abc(y_true, y_pred, plot=False) + assert 0 <= result['ks_stat'] <= 1 + diff --git a/tests/test_model_utils/test_metric_graph_advanced.py b/tests/test_model_utils/test_metric_graph_advanced.py new file mode 100644 index 0000000..282ff0d --- /dev/null +++ b/tests/test_model_utils/test_metric_graph_advanced.py @@ -0,0 +1,195 @@ +import pytest +import numpy as np +import matplotlib.pyplot as plt +from dython.model_utils import metric_graph + + +def test_metric_graph_invalid_metric(): + """Test metric_graph with invalid metric""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + with pytest.raises(ValueError): + metric_graph(y_true, y_pred, metric='invalid', plot=False) + + +def test_metric_graph_none_metric(): + """Test metric_graph with None metric""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + with pytest.raises(ValueError): + metric_graph(y_true, y_pred, metric=None, plot=False) + + +def test_metric_graph_multiclass_with_class_names(): + """Test metric_graph with multiclass and class names""" + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]) + y_pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7], [0.6, 0.3, 0.1]]) + + result = metric_graph(y_true, y_pred, metric='roc', class_names=['A', 'B', 'C'], plot=False) + assert 'A' in result + assert 'B' in result + assert 'C' in result + + +def test_metric_graph_multiclass_wrong_class_names_type(): + """Test metric_graph with multiclass and wrong class names type""" + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]) + y_pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7], [0.6, 0.3, 0.1]]) + + # class_names as non-list/non-string for multiclass should raise error + with pytest.raises((ValueError, TypeError)): + metric_graph(y_true, y_pred, metric='roc', class_names=123, plot=False) + + +def test_metric_graph_multiclass_wrong_class_names_count(): + """Test metric_graph with multiclass and wrong number of class names""" + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]) + y_pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7], [0.6, 0.3, 0.1]]) + + with pytest.raises(ValueError): + metric_graph(y_true, y_pred, metric='roc', class_names=['A', 'B'], plot=False) + + +def test_metric_graph_pr_binary(): + """Test PR curve for binary classification""" + y_true = [0, 1, 0, 1, 1, 0] + y_pred = [0.1, 0.9, 0.3, 0.8, 0.7, 0.2] + + result = metric_graph(y_true, y_pred, metric='pr', plot=False) + assert 'auc' in result['0'] + assert 'naive' in result['0']['auc'] + + +def test_metric_graph_pr_multiclass(): + """Test PR curve for multiclass""" + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]) + y_pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7], [0.6, 0.3, 0.1]]) + + result = metric_graph(y_true, y_pred, metric='pr', plot=False) + assert '0' in result + + +def test_metric_graph_with_colors_string(): + """Test metric_graph with colors as string""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + result = metric_graph(y_true, y_pred, metric='roc', colors='red', plot=False) + assert 'ax' in result + + +def test_metric_graph_multiclass_no_micro(): + """Test metric_graph multiclass without micro""" + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]) + y_pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7], [0.6, 0.3, 0.1]]) + + result = metric_graph(y_true, y_pred, metric='roc', micro=False, plot=False) + assert '0' in result + + +def test_metric_graph_multiclass_no_macro(): + """Test metric_graph multiclass without macro""" + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]) + y_pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7], [0.6, 0.3, 0.1]]) + + result = metric_graph(y_true, y_pred, metric='roc', macro=False, plot=False) + assert '0' in result + + +def test_metric_graph_pr_multiclass_no_macro(): + """Test PR curve multiclass without macro (macro not applicable for PR)""" + y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]) + y_pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7], [0.6, 0.3, 0.1]]) + + result = metric_graph(y_true, y_pred, metric='pr', macro=False, plot=False) + assert '0' in result + + +def test_metric_graph_binary_no_eopt(): + """Test metric_graph binary without eopt""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + result = metric_graph(y_true, y_pred, metric='roc', eopt=False, plot=False) + assert result['0']['eopt']['val'] is None + + +def test_metric_graph_multiclass_force(): + """Test metric_graph with force_multiclass flag""" + y_true = np.array([[1, 0], [0, 1], [1, 0], [0, 1]]) + y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]) + + result = metric_graph(y_true, y_pred, metric='roc', force_multiclass=True, plot=False) + assert '0' in result + assert '1' in result + + +def test_metric_graph_binary_2d_array(): + """Test metric_graph binary with 2D array""" + y_true = np.array([[1, 0], [0, 1], [1, 0], [0, 1]]) + y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]) + + result = metric_graph(y_true, y_pred, metric='roc', plot=False) + assert 'auc' in result['0'] + + +def test_metric_graph_mismatched_shapes(): + """Test metric_graph with mismatched shapes""" + y_true = [0, 1, 0] + y_pred = [0.1, 0.9, 0.3, 0.8] + + with pytest.raises(ValueError): + metric_graph(y_true, y_pred, metric='roc', plot=False) + + +def test_metric_graph_with_ax(): + """Test metric_graph with provided ax""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + fig, ax = plt.subplots() + result = metric_graph(y_true, y_pred, metric='roc', ax=ax, plot=False) + assert result['ax'] == ax + plt.close(fig) + + +def test_metric_graph_binary_1d_arrays(): + """Test metric_graph binary with 1D arrays""" + y_true = np.array([1, 0, 1, 0]) + y_pred = np.array([0.9, 0.2, 0.7, 0.3]) + + result = metric_graph(y_true, y_pred, metric='roc', plot=False) + assert '0' in result + + +def test_metric_graph_with_custom_params(): + """Test metric_graph with custom visualization parameters""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + result = metric_graph( + y_true, y_pred, + metric='roc', + xlim=(0, 0.5), + ylim=(0.5, 1.0), + lw=3, + ls='--', + ms=15, + fmt='.3f', + legend='upper right', + title='Custom Title', + plot=False + ) + assert 'ax' in result + + +def test_metric_graph_with_class_name_string(): + """Test metric_graph with class_names as string for binary""" + y_true = [0, 1, 0, 1] + y_pred = [0.1, 0.9, 0.3, 0.8] + + result = metric_graph(y_true, y_pred, metric='roc', class_names='PositiveClass', plot=False) + assert 'PositiveClass' in result + diff --git a/tests/test_nominal/test_associations_advanced.py b/tests/test_nominal/test_associations_advanced.py new file mode 100644 index 0000000..32d9957 --- /dev/null +++ b/tests/test_nominal/test_associations_advanced.py @@ -0,0 +1,357 @@ +import pytest +import numpy as np +import pandas as pd +from dython.nominal import associations, replot_last_associations + + +def test_associations_with_numerical_columns(): + """Test associations with numerical_columns parameter""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num1': [1, 2, 3, 4], + 'num2': [5, 6, 7, 8] + }) + result = associations(df, numerical_columns=['num1', 'num2'], plot=False) + assert 'corr' in result + assert isinstance(result['corr'], pd.DataFrame) + + +def test_associations_with_numerical_columns_all(): + """Test associations with numerical_columns='all'""" + df = pd.DataFrame({ + 'num1': [1, 2, 3, 4], + 'num2': [5, 6, 7, 8] + }) + result = associations(df, numerical_columns='all', plot=False) + assert 'corr' in result + + +def test_associations_with_numerical_columns_auto(): + """Test associations with numerical_columns='auto'""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, numerical_columns='auto', plot=False) + assert 'corr' in result + + +def test_associations_drop_samples(): + """Test associations with drop_samples nan strategy""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', None, 'a'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1'], nan_strategy='drop_samples', plot=False) + assert 'corr' in result + + +def test_associations_drop_features(): + """Test associations with drop_features nan strategy""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', None, 'a'], + 'cat2': ['x', 'y', 'z', 'x'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1', 'cat2'], nan_strategy='drop_features', plot=False) + assert 'corr' in result + + +def test_associations_drop_sample_pairs(): + """Test associations with drop_sample_pairs nan strategy""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', None, 'a'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1'], nan_strategy='drop_sample_pairs', plot=False) + assert 'corr' in result + + +def test_associations_invalid_nan_strategy(): + """Test associations with invalid nan strategy""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'num1': [1, 2, 3, 4] + }) + with pytest.raises(ValueError): + associations(df, nominal_columns=['cat1'], nan_strategy='invalid', plot=False) + + +def test_associations_hide_rows(): + """Test associations with hide_rows parameter""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1', 'cat2'], hide_rows='cat1', plot=False) + assert 'corr' in result + assert 'cat1' not in result['corr'].index + + +def test_associations_hide_columns(): + """Test associations with hide_columns parameter""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1', 'cat2'], hide_columns='num1', plot=False) + assert 'corr' in result + assert 'num1' not in result['corr'].columns + + +def test_associations_hide_rows_list(): + """Test associations with hide_rows as list""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1', 'cat2'], hide_rows=['cat1', 'cat2'], plot=False) + assert 'corr' in result + + +def test_associations_hide_columns_list(): + """Test associations with hide_columns as list""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1', 'cat2'], hide_columns=['cat1', 'num1'], plot=False) + assert 'corr' in result + + +def test_associations_display_rows_single(): + """Test associations with display_rows as single column""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1', 'cat2'], display_rows=['cat1'], plot=False) + assert 'corr' in result + assert 'cat1' in result['corr'].index + + +def test_associations_display_columns_single(): + """Test associations with display_columns as single column""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1', 'cat2'], display_columns='num1', plot=False) + assert 'corr' in result + + +def test_associations_with_datetime(): + """Test associations with datetime columns""" + df = pd.DataFrame({ + 'date': pd.date_range('2020-01-01', periods=4), + 'num1': [1, 2, 3, 4], + 'cat1': ['a', 'b', 'c', 'a'] + }) + result = associations(df, nominal_columns=['cat1'], plot=False) + assert 'corr' in result + + +def test_associations_with_categorical_dtype(): + """Test associations with pandas categorical dtype""" + df = pd.DataFrame({ + 'cat1': pd.Categorical(['a', 'b', 'c', 'a']), + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns='auto', plot=False) + assert 'corr' in result + + +def test_associations_with_categorical_nan(): + """Test associations with categorical dtype and NaN values""" + df = pd.DataFrame({ + 'cat1': pd.Categorical(['a', 'b', None, 'a']), + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1'], nan_strategy='replace', nan_replace_value='missing', plot=False) + assert 'corr' in result + + +def test_associations_clustering(): + """Test associations with clustering enabled""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'] * 3, + 'cat2': ['x', 'y', 'x', 'y'] * 3, + 'num1': list(range(12)), + 'num2': list(range(12, 24)) + }) + result = associations(df, nominal_columns=['cat1', 'cat2'], clustering=True, plot=False) + assert 'corr' in result + + +def test_associations_mark_columns(): + """Test associations with mark_columns enabled""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1'], mark_columns=True, plot=False) + assert 'corr' in result + assert any('(nom)' in str(col) for col in result['corr'].columns) + + +def test_associations_theil(): + """Test associations with Theil's U""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'] + }) + result = associations(df, nominal_columns='all', nom_nom_assoc='theil', plot=False) + assert 'corr' in result + + +def test_associations_spearman(): + """Test associations with Spearman correlation""" + df = pd.DataFrame({ + 'num1': [1, 2, 3, 4], + 'num2': [5, 6, 7, 8] + }) + result = associations(df, nominal_columns=None, num_num_assoc='spearman', plot=False) + assert 'corr' in result + + +def test_associations_kendall(): + """Test associations with Kendall correlation""" + df = pd.DataFrame({ + 'num1': [1, 2, 3, 4], + 'num2': [5, 6, 7, 8] + }) + result = associations(df, nominal_columns=None, num_num_assoc='kendall', plot=False) + assert 'corr' in result + + +def test_associations_custom_nom_nom(): + """Test associations with custom nominal-nominal function""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'] + }) + + def custom_assoc(x, y): + return 0.5 + + result = associations(df, nominal_columns='all', nom_nom_assoc=custom_assoc, plot=False) + assert 'corr' in result + + +def test_associations_custom_nom_nom_asymmetric(): + """Test associations with custom asymmetric nominal-nominal function""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'] + }) + + def custom_assoc(x, y): + return 0.5 + + result = associations(df, nominal_columns='all', nom_nom_assoc=custom_assoc, symmetric_nom_nom=False, plot=False) + assert 'corr' in result + + +def test_associations_custom_num_num(): + """Test associations with custom numerical-numerical function""" + df = pd.DataFrame({ + 'num1': [1, 2, 3, 4], + 'num2': [5, 6, 7, 8] + }) + + def custom_corr(x, y): + return 0.9 + + result = associations(df, nominal_columns=None, num_num_assoc=custom_corr, plot=False) + assert 'corr' in result + + +def test_associations_custom_num_num_asymmetric(): + """Test associations with custom asymmetric numerical-numerical function""" + df = pd.DataFrame({ + 'num1': [1, 2, 3, 4], + 'num2': [5, 6, 7, 8] + }) + + def custom_corr(x, y): + return 0.9 + + result = associations(df, nominal_columns=None, num_num_assoc=custom_corr, symmetric_num_num=False, plot=False) + assert 'corr' in result + + +def test_associations_custom_nom_num(): + """Test associations with custom nominal-numerical function""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'num1': [1, 2, 3, 4] + }) + + def custom_assoc(x, y): + return 0.7 + + result = associations(df, nominal_columns=['cat1'], nom_num_assoc=custom_assoc, plot=False) + assert 'corr' in result + + +def test_associations_cramers_v_no_bias_correction(): + """Test associations with Cramer's V without bias correction""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'] + }) + result = associations(df, nominal_columns='all', cramers_v_bias_correction=False, plot=False) + assert 'corr' in result + + +def test_replot_last_associations(): + """Test replot_last_associations function""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'num1': [1, 2, 3, 4] + }) + # First create an association plot + associations(df, nominal_columns=['cat1'], plot=False) + + # Now replot + ax = replot_last_associations(plot=False) + assert ax is not None + + +def test_replot_last_associations_without_prior(): + """Test replot_last_associations without prior associations call""" + from dython.nominal import _ASSOC_PLOT_PARAMS + _ASSOC_PLOT_PARAMS.clear() + + with pytest.raises(RuntimeError): + replot_last_associations(plot=False) + + +def test_associations_compute_only(): + """Test associations with compute_only flag""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1'], compute_only=True) + assert 'corr' in result + assert result['ax'] is None + + +def test_associations_single_value_column(): + """Test associations with single-value column""" + df = pd.DataFrame({ + 'cat1': ['a', 'a', 'a', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num1': [1, 2, 3, 4] + }) + result = associations(df, nominal_columns=['cat1', 'cat2'], plot=False) + assert 'corr' in result + diff --git a/tests/test_nominal/test_cluster_correlations.py b/tests/test_nominal/test_cluster_correlations.py new file mode 100644 index 0000000..9137ead --- /dev/null +++ b/tests/test_nominal/test_cluster_correlations.py @@ -0,0 +1,60 @@ +import pytest +import numpy as np +import pandas as pd +from dython.nominal import cluster_correlations, associations + + +def test_cluster_correlations_with_dataframe(): + """Test cluster correlations with DataFrame input""" + corr_mat = pd.DataFrame( + [[1.0, 0.8, 0.3], + [0.8, 1.0, 0.4], + [0.3, 0.4, 1.0]], + columns=['A', 'B', 'C'], + index=['A', 'B', 'C'] + ) + result, indices = cluster_correlations(corr_mat) + assert isinstance(result, pd.DataFrame) + assert isinstance(indices, np.ndarray) + assert result.shape == corr_mat.shape + + +def test_cluster_correlations_with_numpy(): + """Test cluster correlations with numpy array input""" + corr_mat = np.array( + [[1.0, 0.8, 0.3], + [0.8, 1.0, 0.4], + [0.3, 0.4, 1.0]] + ) + result, indices = cluster_correlations(corr_mat) + assert isinstance(result, np.ndarray) + assert isinstance(indices, np.ndarray) + assert result.shape == corr_mat.shape + + +def test_cluster_correlations_with_provided_indices(): + """Test cluster correlations with provided indices""" + corr_mat = pd.DataFrame( + [[1.0, 0.8, 0.3], + [0.8, 1.0, 0.4], + [0.3, 0.4, 1.0]], + columns=['A', 'B', 'C'], + index=['A', 'B', 'C'] + ) + indices = np.array([1, 1, 2]) + result, result_indices = cluster_correlations(corr_mat, indices) + assert isinstance(result, pd.DataFrame) + assert np.array_equal(result_indices, indices) + + +def test_cluster_correlations_larger_matrix(): + """Test cluster correlations with larger matrix""" + n = 10 + corr_mat = np.random.rand(n, n) + corr_mat = (corr_mat + corr_mat.T) / 2 # Make symmetric + np.fill_diagonal(corr_mat, 1.0) + + result, indices = cluster_correlations(corr_mat) + assert isinstance(result, np.ndarray) + assert result.shape == corr_mat.shape + diff --git a/tests/test_nominal/test_conditional_entropy.py b/tests/test_nominal/test_conditional_entropy.py new file mode 100644 index 0000000..5cc635e --- /dev/null +++ b/tests/test_nominal/test_conditional_entropy.py @@ -0,0 +1,50 @@ +import pytest +import numpy as np +import pandas as pd +from dython.nominal import conditional_entropy + + +def test_conditional_entropy_basic(): + """Test basic conditional entropy calculation""" + x = [1, 1, 2, 2, 3, 3] + y = ['a', 'a', 'b', 'b', 'c', 'c'] + result = conditional_entropy(x, y) + assert isinstance(result, float) + assert result >= 0 + + +def test_conditional_entropy_with_drop_strategy(): + """Test conditional entropy with drop nan strategy""" + x = np.array([1.0, 1.0, 2.0, 2.0, 3.0, np.nan]) + y = np.array([1.0, 1.0, 2.0, 2.0, 3.0, 3.0]) + result = conditional_entropy(x, y, nan_strategy='drop') + assert isinstance(result, float) + assert result >= 0 + + +def test_conditional_entropy_with_replace_strategy(): + """Test conditional entropy with replace nan strategy""" + x = [1, 1, 2, 2, 3, None] + y = ['a', 'a', 'b', 'b', 'c', 'c'] + result = conditional_entropy(x, y, nan_strategy='replace', nan_replace_value=0) + assert isinstance(result, float) + assert result >= 0 + + +def test_conditional_entropy_custom_log_base(): + """Test conditional entropy with custom log base""" + x = [1, 1, 2, 2, 3, 3] + y = ['a', 'a', 'b', 'b', 'c', 'c'] + result = conditional_entropy(x, y, log_base=2) + assert isinstance(result, float) + assert result >= 0 + + +def test_conditional_entropy_with_pandas(): + """Test conditional entropy with pandas Series""" + x = pd.Series([1, 1, 2, 2, 3, 3]) + y = pd.Series(['a', 'a', 'b', 'b', 'c', 'c']) + result = conditional_entropy(x, y) + assert isinstance(result, float) + assert result >= 0 + diff --git a/tests/test_nominal/test_inf_nan_handling.py b/tests/test_nominal/test_inf_nan_handling.py new file mode 100644 index 0000000..25aab17 --- /dev/null +++ b/tests/test_nominal/test_inf_nan_handling.py @@ -0,0 +1,76 @@ +import pytest +import numpy as np +import pandas as pd +from dython.nominal import associations, theils_u, correlation_ratio, cramers_v + + +def test_theils_u_single_category(): + """Test Theil's U with single category returns 1""" + # When x and y are completely determined by each other + x = ['a'] * 100 + y = ['a'] * 100 + + result = theils_u(x, y) + assert result == 1.0 + + +def test_theils_u_with_replace_strategy(): + """Test Theil's U with replace nan strategy""" + x = pd.Series(['a', 'b', 'c', None, 'a']) + y = pd.Series(['x', 'y', 'z', 'w', 'x']) + + result = theils_u(x, y, nan_strategy='replace', nan_replace_value='missing') + assert isinstance(result, float) + + +def test_correlation_ratio_with_replace_strategy(): + """Test correlation ratio with replace nan strategy""" + categories = pd.Series(['a', 'b', 'c', None, 'a']) + measurements = pd.Series([1.0, 2.0, 3.0, 4.0, 1.5]) + + result = correlation_ratio(categories, measurements, nan_strategy='replace', nan_replace_value='missing') + assert isinstance(result, float) + + +def test_correlation_ratio_zero_numerator(): + """Test correlation ratio with zero numerator""" + # All measurements are the same + categories = ['a', 'b', 'c', 'a'] + measurements = [5.0, 5.0, 5.0, 5.0] + + result = correlation_ratio(categories, measurements) + assert result == 0.0 + + +def test_correlation_ratio_precision_warning(): + """Test correlation ratio with values that need precision rounding""" + # Create data that produces eta slightly > 1 + np.random.seed(42) + categories = ['a'] * 50 + ['b'] * 50 + measurements = [1.0] * 50 + [2.0] * 50 + + result = correlation_ratio(categories, measurements) + assert 0.0 <= result <= 1.0 + + +def test_cramers_v_without_bias_correction(): + """Test Cramer's V without bias correction""" + x = ['a', 'b', 'c', 'a'] + y = ['x', 'y', 'x', 'y'] + + result = cramers_v(x, y, bias_correction=False) + assert isinstance(result, float) + assert 0.0 <= result <= 1.0 + + +def test_associations_with_inf_nan_values(): + """Test associations with inf/nan producing columns""" + df = pd.DataFrame({ + 'cat1': ['a'] * 10, # Single value + 'cat2': ['x', 'y'] * 5, + 'num1': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + }) + + result = associations(df, nominal_columns=['cat1', 'cat2'], plot=False) + assert 'corr' in result + diff --git a/tests/test_nominal/test_numerical_encoding.py b/tests/test_nominal/test_numerical_encoding.py new file mode 100644 index 0000000..89b5ee0 --- /dev/null +++ b/tests/test_nominal/test_numerical_encoding.py @@ -0,0 +1,145 @@ +import pytest +import numpy as np +import pandas as pd +from dython.nominal import numerical_encoding + + +def test_numerical_encoding_basic(): + """Test basic numerical encoding""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num': [1, 2, 3, 4] + }) + result = numerical_encoding(df, nominal_columns=['cat1', 'cat2']) + assert isinstance(result, pd.DataFrame) + assert len(result.columns) >= len(df.columns) + + +def test_numerical_encoding_auto(): + """Test numerical encoding with auto detection""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'num': [1, 2, 3, 4] + }) + result = numerical_encoding(df, nominal_columns='auto') + assert isinstance(result, pd.DataFrame) + + +def test_numerical_encoding_all(): + """Test numerical encoding with all columns""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'a'], + 'cat2': ['x', 'y', 'x', 'y'] + }) + result = numerical_encoding(df, nominal_columns='all') + assert isinstance(result, pd.DataFrame) + + +def test_numerical_encoding_none(): + """Test numerical encoding with None nominal columns""" + df = pd.DataFrame({ + 'num1': [1, 2, 3, 4], + 'num2': [5, 6, 7, 8] + }) + result = numerical_encoding(df, nominal_columns=None) + assert isinstance(result, pd.DataFrame) + pd.testing.assert_frame_equal(result, df) + + +def test_numerical_encoding_single_value_with_numeric(): + """Test numerical encoding with single value column alongside numeric""" + df = pd.DataFrame({ + 'num': [1, 2, 3, 4], + 'cat1': ['a', 'a', 'a', 'a'], + 'cat2': ['x', 'y', 'x', 'y'] + }) + result = numerical_encoding(df, nominal_columns=['cat1', 'cat2']) + assert isinstance(result, pd.DataFrame) + # cat1 should be encoded as 0 (single value) + assert 'cat1' in result.columns + assert (result['cat1'] == 0).all() + + +def test_numerical_encoding_drop_single_label(): + """Test numerical encoding with drop_single_label=True""" + df = pd.DataFrame({ + 'cat1': ['a', 'a', 'a', 'a'], + 'cat2': ['x', 'y', 'x', 'y'], + 'num': [1, 2, 3, 4] + }) + result = numerical_encoding(df, nominal_columns=['cat1', 'cat2'], drop_single_label=True) + assert isinstance(result, pd.DataFrame) + assert 'cat1' not in result.columns + + +def test_numerical_encoding_return_dict(): + """Test numerical encoding with drop_fact_dict=False""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'a', 'b'], + 'num': [1, 2, 3, 4] + }) + result = numerical_encoding(df, nominal_columns=['cat1'], drop_fact_dict=False) + assert isinstance(result, tuple) + assert len(result) == 2 + assert isinstance(result[0], pd.DataFrame) + assert isinstance(result[1], dict) + + +def test_numerical_encoding_with_nan_replace(): + """Test numerical encoding with nan replace strategy""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', None, 'a'], + 'num': [1, 2, 3, np.nan] + }) + result = numerical_encoding(df, nominal_columns=['cat1'], nan_strategy='replace', nan_replace_value=0) + assert isinstance(result, pd.DataFrame) + assert not result.isnull().any().any() + + +def test_numerical_encoding_with_nan_drop_samples(): + """Test numerical encoding with nan drop_samples strategy""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', None, 'a'], + 'num': [1, 2, 3, 4] + }) + result = numerical_encoding(df, nominal_columns=['cat1'], nan_strategy='drop_samples') + assert isinstance(result, pd.DataFrame) + assert len(result) < len(df) + + +def test_numerical_encoding_with_nan_drop_features(): + """Test numerical encoding with nan drop_features strategy""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', None, 'a'], + 'cat2': ['x', 'y', 'z', 'x'], + 'num': [1, 2, 3, 4] + }) + result = numerical_encoding(df, nominal_columns=['cat1', 'cat2'], nan_strategy='drop_features') + assert isinstance(result, pd.DataFrame) + assert 'cat1' not in result.columns + + +def test_numerical_encoding_three_plus_values(): + """Test numerical encoding with more than two values (get_dummies)""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'c', 'd'], + 'num': [1, 2, 3, 4] + }) + result = numerical_encoding(df, nominal_columns=['cat1']) + assert isinstance(result, pd.DataFrame) + # Should have dummy columns for cat1 + assert len(result.columns) > 2 + + +def test_numerical_encoding_two_values(): + """Test numerical encoding with exactly two values (factorize)""" + df = pd.DataFrame({ + 'cat1': ['a', 'b', 'a', 'b'], + 'num': [1, 2, 3, 4] + }) + result, fact_dict = numerical_encoding(df, nominal_columns=['cat1'], drop_fact_dict=False) + assert isinstance(result, pd.DataFrame) + assert 'cat1' in fact_dict + assert len(fact_dict['cat1']) == 2 +