Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
History
=======

(unreleased)
------------

Fixed:
^^^^^^
- ``pyrato.edc.intersection_time_lundeby`` now works for multi-dimensional Signals (PR #170)


1.0.0 (2026-03-19)
------------------

Expand Down
4 changes: 3 additions & 1 deletion pyrato/dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def _smooth_rir(
The time vector fitting the original data.

"""
cshape = data.shape[:-1]
data = np.atleast_2d(data)
n_samples = data.shape[-1]
n_samples_nan = np.count_nonzero(np.isnan(data), axis=-1)
Expand All @@ -164,11 +165,12 @@ def _smooth_rir(
np.floor((n_samples-n_samples_nan)/n_samples_per_block),
dtype=int)

# average data in blocks
n_blocks_min = int(np.min(n_blocks))
n_samples_actual = int(n_blocks_min*n_samples_per_block)
reshaped_array = np.reshape(
data[..., :n_samples_actual],
(-1, n_blocks_min, n_samples_per_block))
(*cshape, n_blocks_min, n_samples_per_block))
time_window_data = np.mean(reshaped_array, axis=-1)

# Use average time instances corresponding to the average energy level
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_smoothed_rir_1D():
rir,
sampling_rate=3000,
smooth_block_length=0.075)[0]
npt.assert_allclose(actual, expected)
npt.assert_allclose(np.atleast_2d(actual), expected)


def test_smoothed_rir_2D():
Expand Down
34 changes: 32 additions & 2 deletions tests/test_edc_noise_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,17 @@ def test_edc_chu_2D():
npt.assert_allclose(actual.time, expected)


def test_intersection_time_1D():
def test_intersection_time_lundeby_single():
"""Test with a single-channel Signal."""
rir = pf.Signal(genfromtxt(
os.path.join(test_data_path, 'analytic_rir_psnr50_1D.csv'),
delimiter=','), 3000)
expected = np.atleast_2d(genfromtxt(
os.path.join(test_data_path, 'intersection_time_1D.csv'),
delimiter=',')).T

assert rir.cshape == (1, )

actual = enh.intersection_time_lundeby(
rir,
freq='broadband',
Expand All @@ -248,14 +251,41 @@ def test_intersection_time_1D():
npt.assert_allclose(actual, expected)


def test_intersection_time_2D():
def test_intersection_time_lundeby_multichannel():
"""Test with a two-channel Signal."""
rir = pf.Signal(genfromtxt(
os.path.join(test_data_path, 'analytic_rir_psnr50_2D.csv'),
delimiter=','), 3000)
expected = np.atleast_2d(genfromtxt(
os.path.join(test_data_path, 'intersection_time_2D.csv'),
delimiter=','))

assert rir.cshape == (2, )

actual = enh.intersection_time_lundeby(
rir,
freq='broadband',
is_energy=False,
time_shift=False,
channel_independent=False,
plot=False)
npt.assert_allclose(actual, expected)


def test_intersection_time_lundeby_multi_dimensional():
"""Test with a Signal of cshape = (2, 2)."""
rir = pf.Signal(genfromtxt(
os.path.join(test_data_path, 'analytic_rir_psnr50_2D.csv'),
delimiter=','), 3000)
rir = pf.utils.concatenate_channels((rir, rir), 0).reshape((2, 2))

expected = np.atleast_2d(genfromtxt(
os.path.join(test_data_path, 'intersection_time_2D.csv'),
delimiter=','))
expected = [np.vstack((e, e)) for e in expected]

assert rir.cshape == (2, 2)

actual = enh.intersection_time_lundeby(
rir,
freq='broadband',
Expand Down