diff --git a/HISTORY.rst b/HISTORY.rst index 9948734..dba841a 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -2,6 +2,14 @@ History ======= +(unreleased) +------------ + +Fixed: +^^^^^^ +- ``pyrato.edc.intersection_time_lundeby`` now works for multi-dimensional Signals (PR #170) + + 1.0.0 (2026-03-19) ------------------ diff --git a/pyrato/dsp.py b/pyrato/dsp.py index cc1aa9c..d917de4 100644 --- a/pyrato/dsp.py +++ b/pyrato/dsp.py @@ -155,6 +155,7 @@ def _smooth_rir( The time vector fitting the original data. """ + cshape = data.shape[:-1] data = np.atleast_2d(data) n_samples = data.shape[-1] n_samples_nan = np.count_nonzero(np.isnan(data), axis=-1) @@ -164,11 +165,12 @@ def _smooth_rir( np.floor((n_samples-n_samples_nan)/n_samples_per_block), dtype=int) + # average data in blocks n_blocks_min = int(np.min(n_blocks)) n_samples_actual = int(n_blocks_min*n_samples_per_block) reshaped_array = np.reshape( data[..., :n_samples_actual], - (-1, n_blocks_min, n_samples_per_block)) + (*cshape, n_blocks_min, n_samples_per_block)) time_window_data = np.mean(reshaped_array, axis=-1) # Use average time instances corresponding to the average energy level diff --git a/tests/test_dsp.py b/tests/test_dsp.py index 2b2a4e4..48e91b1 100644 --- a/tests/test_dsp.py +++ b/tests/test_dsp.py @@ -222,7 +222,7 @@ def test_smoothed_rir_1D(): rir, sampling_rate=3000, smooth_block_length=0.075)[0] - npt.assert_allclose(actual, expected) + npt.assert_allclose(np.atleast_2d(actual), expected) def test_smoothed_rir_2D(): diff --git a/tests/test_edc_noise_handling.py b/tests/test_edc_noise_handling.py index bb1c25d..5b38875 100644 --- a/tests/test_edc_noise_handling.py +++ b/tests/test_edc_noise_handling.py @@ -230,7 +230,8 @@ def test_edc_chu_2D(): npt.assert_allclose(actual.time, expected) -def test_intersection_time_1D(): +def test_intersection_time_lundeby_single(): + """Test with a single-channel Signal.""" rir = pf.Signal(genfromtxt( os.path.join(test_data_path, 'analytic_rir_psnr50_1D.csv'), delimiter=','), 3000) @@ -238,6 +239,8 @@ def test_intersection_time_1D(): os.path.join(test_data_path, 'intersection_time_1D.csv'), delimiter=',')).T + assert rir.cshape == (1, ) + actual = enh.intersection_time_lundeby( rir, freq='broadband', @@ -248,7 +251,8 @@ def test_intersection_time_1D(): npt.assert_allclose(actual, expected) -def test_intersection_time_2D(): +def test_intersection_time_lundeby_multichannel(): + """Test with a two-channel Signal.""" rir = pf.Signal(genfromtxt( os.path.join(test_data_path, 'analytic_rir_psnr50_2D.csv'), delimiter=','), 3000) @@ -256,6 +260,32 @@ def test_intersection_time_2D(): os.path.join(test_data_path, 'intersection_time_2D.csv'), delimiter=',')) + assert rir.cshape == (2, ) + + actual = enh.intersection_time_lundeby( + rir, + freq='broadband', + is_energy=False, + time_shift=False, + channel_independent=False, + plot=False) + npt.assert_allclose(actual, expected) + + +def test_intersection_time_lundeby_multi_dimensional(): + """Test with a Signal of cshape = (2, 2).""" + rir = pf.Signal(genfromtxt( + os.path.join(test_data_path, 'analytic_rir_psnr50_2D.csv'), + delimiter=','), 3000) + rir = pf.utils.concatenate_channels((rir, rir), 0).reshape((2, 2)) + + expected = np.atleast_2d(genfromtxt( + os.path.join(test_data_path, 'intersection_time_2D.csv'), + delimiter=',')) + expected = [np.vstack((e, e)) for e in expected] + + assert rir.cshape == (2, 2) + actual = enh.intersection_time_lundeby( rir, freq='broadband',