Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 85 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,123 @@ PyGaborSTM is a Python library for extracting Rate-Scale-Frequency (RSF) represe

<!-- TODO: Publish to PyPI -->
```bash
pip install pygaborstm
pip install pygaborstm # not published yet
```
For now, install from source (see below).

### From source
```bash
git clone https://github.com/JHU-LCAP/PyGaborSTM.git
cd pygaborstm
cd PyGaborSTM
poetry install
```

### GPU Support (Optional, Linux/Windows only)
For GPU acceleration, you need:

1. **NVIDIA GPU** with CUDA support
2. **CUDA Toolkit** installed on your system

```bash
# Check your CUDA version
nvidia-smi
```

Download and install the CUDA Toolkit from NVIDIA:
https://developer.nvidia.com/cuda-toolkit

After installation (Linux shells), add the following lines to your `~/.bashrc` or `~/.zshrc`:

```bash
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
```

Verify installation:

```bash
nvcc --version
```

The library uses CuPy for GPU acceleration. Make sure your CuPy version matches your CUDA version:
- CUDA 11.x → `cupy-cuda11x`
- CUDA 12.x → `cupy-cuda12x`
- CUDA 13.x → `cupy-cuda13x`

## Quick Start
```python
import pygaborstm as stm

# One-liner
rsf = stm.compute_rsf("audio.wav")
# Create model (CPU)
model = stm.PyGaborSTM()

# Step by step
audio, sr = stm.load("audio.wav")
spectrogram = stm.auditory_spectrogram(audio)
rsf = stm.rsf(spectrogram)
# Create model (GPU)
model = stm.PyGaborSTM(config=stm.Config(use_gpu=True))

# Visualization
stm.plot_spectrogram(spectrogram)
stm.plot_rsf(rsf) # Unfolded
stm.plot_rsf(rsf, fold=True) # Symmetric
# Compute spectrogram and RSF
spec = model.spectrogram(audio)
rsf = model.rsf(spec)

# Access data
rs_matrix = rsf.rate_scale_matrix() # For visualization
rsf_3d = rsf.mean_over_time() # For TSVD input
# Visualization
stm.plot.spectrogram(spec)
stm.plot.rsf(rsf)
stm.plot.rsf(rsf, fold=True) # Symmetric folding
```

See `notebooks/example_usage.ipynb` for more examples.

## Configuration
```python
config = stm.Config(
# General
use_gpu=False, # Enable GPU acceleration
sample_rate=16000, # Audio sample rate

# Spectrogram
n_filters=128, # Number of frequency channels
f_min=180.0, # Minimum frequency (Hz)
octaves=5.3, # Frequency range in octaves

# RSF / Gabor
resolution="low", # "low", "medium", "high", "ultra"
)
```

## Directory Structure
```
pygaborstm/
PyGaborSTM/
├── pygaborstm/
│ ├── __init__.py # Public API
│ ├── config.py # SpectrogramConfig, GaborConfig, Config
│ ├── config.py # Config dataclass
│ ├── structs.py # Spectrogram, RSF dataclasses
│ ├── spectrogram.py # AuditorySpectrogram
│ ├── gabor.py # GaborFilterbank
│ ├── core.py # load(), compute_rsf()
│ └── plotting.py # plot_spectrogram(), plot_rsf(), plot_filterbank()
│ ├── core.py # PyGaborSTM class
│ ├── plot.py # Plotting functions
│ └── backend.py # NumPy/CuPy switching
├── notebooks/
│ ├── assets/
│ └── example_usage.ipynb
└── tests/
```

## Development
```bash
poetry run jupyter notebook # Run notebooks
poetry run pytest -v # Run tests
poetry install # Install all dependencies
poetry run jupyter notebook # Run notebooks
poetry run pytest -v # Run tests
poetry run ruff check --fix . # Lint and auto-fix
poetry run ruff format . # Format code
```

Note: Please lint and format before pushing, as CI will fail otherwise.

### Jupyter Kernel
Ensure your notebook uses the correct Poetry environment:
```bash
# Check Poetry env path
poetry env info --path

# Register kernel (if needed)
poetry run python -m ipykernel install --user --name pygaborstm
```

## References
Expand Down
855 changes: 855 additions & 0 deletions notebooks/chi2005_validation.ipynb

Large diffs are not rendered by default.

142 changes: 89 additions & 53 deletions notebooks/example_usage.ipynb

Large diffs are not rendered by default.

185 changes: 185 additions & 0 deletions notebooks/mvripfft_validation.ipynb

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions notebooks/nb_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from .audio_generator import (
generate_tone,
generate_three_tones,
generate_broadband_noise,
generate_harmonic_complex,
generate_moving_ripple,
generate_ripple_set,
save_three_tones,
save_noise,
save_harmonic_complexes,
SR,
DURATION,
)

# Public API of the package: exactly the names imported from
# .audio_generator above, so `from nb_utils import *` exposes them.
__all__ = [
"generate_tone",
"generate_three_tones",
"generate_broadband_noise",
"generate_harmonic_complex",
"generate_moving_ripple",
"generate_ripple_set",
"save_three_tones",
"save_noise",
"save_harmonic_complexes",
"SR",
"DURATION",
]
172 changes: 172 additions & 0 deletions notebooks/nb_utils/audio_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""
Stimulus generation for auditory spectrogram validation.

Based on Chi, Ru & Shamma (2005) "Multiresolution spectrotemporal analysis of complex sounds"
"""

import numpy as np
from pathlib import Path
import soundfile as sf

# Defaults used as parameter defaults by the generators and savers below.
# SR is the sample rate (samples per second); DURATION is in seconds.
SR = 16000
DURATION = 3.0


def generate_tone(freq, duration=DURATION, sr=SR, amplitude=0.5):
    """Return a zero-phase sinusoid of frequency ``freq`` (Hz).

    The tone has ``int(duration * sr)`` samples taken at sample rate ``sr``
    and is scaled by ``amplitude``.
    """
    n_samples = int(duration * sr)
    time_axis = np.arange(n_samples) / sr
    angular_freq = 2 * np.pi * freq
    return amplitude * np.sin(angular_freq * time_axis)


def generate_three_tones(duration=DURATION, sr=SR):
    """Return the 250 Hz, 1000 Hz and 4000 Hz tones (Chi 2005 Section III.B.1).

    Returns:
        A 3-tuple of arrays, ordered from the lowest to the highest frequency,
        each produced by ``generate_tone`` with its default amplitude.
    """
    tone_freqs_hz = (250, 1000, 4000)
    return tuple(generate_tone(hz, duration, sr) for hz in tone_freqs_hz)


def generate_broadband_noise(duration=DURATION, sr=SR, seed=42):
    """Return broadband noise built from 59 random-phase tones (Chi 2005 Section III.B.2).

    Tone frequencies are log-spaced between 135 Hz and 7465 Hz. Each tone gets
    a phase drawn uniformly from [0, 2π) by a generator seeded with ``seed``,
    so the output is reproducible. The sum is peak-normalized to 0.5.
    """
    n_tones = 59
    rng = np.random.default_rng(seed)
    tone_freqs = np.logspace(np.log2(135), np.log2(7465), n_tones, base=2.0)
    tone_phases = rng.uniform(0, 2 * np.pi, n_tones)

    n_samples = int(duration * sr)
    t = np.arange(n_samples) / sr

    # Accumulate tone by tone, in frequency order.
    mix = np.zeros_like(t)
    for hz, phi in zip(tone_freqs, tone_phases):
        mix = mix + np.sin(2 * np.pi * hz * t + phi)

    peak = np.max(np.abs(mix))
    return 0.5 * mix / peak


def generate_harmonic_complex(
    f0=80, duration=DURATION, sr=SR, phase_type="in_phase", seed=42, n_harmonics=50
):
    """Generate a harmonic complex (Chi 2005 Section III.B.3).

    Sums unit-amplitude sinusoids at ``f0, 2*f0, ...`` and stops at the first
    harmonic that reaches the Nyquist frequency ``sr / 2``.

    Args:
        f0: Fundamental frequency (Hz), default 80.
        duration: Duration in seconds.
        sr: Sample rate.
        phase_type: ``"in_phase"`` (or the short form ``"in"``) gives every
            harmonic zero starting phase. Any other value, e.g.
            ``"random_phase"`` or ``"random"``, draws each harmonic's phase
            uniformly from [0, 2π).
        seed: Seed for the random phases; unused for in-phase complexes.
        n_harmonics: Maximum number of harmonics to sum, default 50.

    Returns:
        Array of ``int(duration * sr)`` samples, peak-normalized to 0.5. If no
        harmonic lies below Nyquist (or the signal is empty) the all-zero
        signal is returned as is instead of dividing by zero.
    """
    rng = np.random.default_rng(seed)
    t = np.arange(int(duration * sr)) / sr
    signal = np.zeros_like(t)

    # Accept the short label too: callers that strip the "_phase" suffix used
    # to fall through silently to the random-phase branch.
    in_phase = phase_type in ("in_phase", "in")
    nyquist = sr / 2

    for n in range(1, n_harmonics + 1):
        freq = f0 * n
        if freq >= nyquist:
            break
        phase = 0.0 if in_phase else rng.uniform(0, 2 * np.pi)
        signal += np.sin(2 * np.pi * freq * t + phase)

    peak = np.max(np.abs(signal)) if signal.size else 0.0
    if peak == 0:
        # Nothing to normalize; avoid 0/0 producing NaNs.
        return signal
    return 0.5 * signal / peak


def generate_moving_ripple(
    rate,
    scale,
    duration=DURATION,
    sr=SR,
    mod_depth=0.9,
    f0=1000,
    bandwidth=5.3,
    df=1 / 16,
    seed=None,
):
    """
    Generate moving ripple (spectrotemporally modulated noise).

    Matches MATLAB mvripfft function parameters.

    Args:
        rate: Temporal modulation rate ω (Hz), negative = downward
        scale: Spectral modulation scale Ω (cycles/octave)
        duration: Duration in seconds
        sr: Sample rate
        mod_depth: Modulation depth Am (0-1), default 0.9
        f0: Center frequency (Hz), default 1000
        bandwidth: Bandwidth in octaves, default 5.3
        df: Frequency spacing in octaves, default 1/16
        seed: Seed for the random carrier phases. The default ``None`` keeps
            the previous behaviour (fresh entropy on every call); pass an
            integer to get a reproducible stimulus.

    Returns:
        Array of ``int(duration * sr)`` samples, peak-normalized to 1. An
        all-zero or empty signal is returned unscaled.
    """
    # Frequency axis (log-spaced): positions in octaves relative to f0
    n_freqs = int(bandwidth / df)
    x = np.linspace(-bandwidth / 2, bandwidth / 2, n_freqs)
    freqs = f0 * (2**x)

    # Time axis
    n_samples = int(duration * sr)
    t = np.arange(n_samples) / sr

    # One random starting phase per carrier
    rng = np.random.default_rng(seed)
    phases = rng.uniform(0, 2 * np.pi, n_freqs)

    # Generate ripple: sum of amplitude-modulated sinusoids
    signal = np.zeros(n_samples)
    for freq, xi, phi in zip(freqs, x, phases):
        # Ripple envelope: 1 + Am * sin(2π * Ω * x + 2π * ω * t)
        envelope = 1 + mod_depth * np.sin(2 * np.pi * scale * xi + 2 * np.pi * rate * t)
        carrier = np.sin(2 * np.pi * freq * t + phi)
        signal += envelope * carrier

    peak = np.max(np.abs(signal)) if signal.size else 0.0
    if peak == 0:
        # No carriers or no samples: avoid 0/0 producing NaNs.
        return signal
    return signal / peak


def generate_ripple_set(output_dir, rates=None, scales=None, duration=DURATION, sr=SR):
    """
    Write one moving-ripple WAV file per (rate, scale) combination.

    With the defaults this produces 10 rates × 6 scales = 60 ripples
    (matching MATLAB script). Files are numbered in rate-major order and
    named ``ripple_<NN>_R<rate>_S<scale>.wav`` inside ``output_dir``, which is
    created if missing. Progress is reported with ``print``.
    """
    rates = [-32, -16, -8, -4, -2, 2, 4, 8, 16, 32] if rates is None else rates
    scales = [0.25, 0.5, 1, 2, 4, 8] if scales is None else scales

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(
        f"Generating {len(rates)} × {len(scales)} = {len(rates) * len(scales)} ripples..."
    )

    # Rate is the outer loop, scale the inner one; numbering starts at 1.
    combos = [(rate, scale) for rate in rates for scale in scales]
    for counter, (rate, scale) in enumerate(combos, start=1):
        ripple = generate_moving_ripple(rate, scale, duration, sr)
        filename = f"ripple_{counter:02d}_R{rate:.2f}_S{scale:.2f}.wav"
        sf.write(output_dir / filename, ripple, sr)

    print(f"Saved {len(combos)} ripples to {output_dir}")


# === Save functions ===


def save_three_tones(output_dir, duration=DURATION, sr=SR):
    """Write the three test tones as ``tone_<freq>Hz.wav`` files.

    ``output_dir`` is created (including parents) if it does not exist.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    tones = generate_three_tones(duration, sr)
    # The labels follow the order in which generate_three_tones returns tones.
    for freq, audio in zip([250, 1000, 4000], tones):
        wav_path = target_dir / f"tone_{freq}Hz.wav"
        sf.write(wav_path, audio, sr)


def save_noise(output_dir, duration=DURATION, sr=SR, seed=42):
    """Write the broadband noise stimulus to ``broadband_noise.wav``.

    ``output_dir`` is created (including parents) if it does not exist.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    wav_path = target_dir / "broadband_noise.wav"
    noise = generate_broadband_noise(duration, sr, seed)
    sf.write(wav_path, noise, sr)


def save_harmonic_complexes(output_dir, f0=80, duration=DURATION, sr=SR, seed=42):
    """Save harmonic complexes (in-phase and random-phase).

    Writes ``harmonic_complex_F0<f0>_in_phase.wav`` and
    ``harmonic_complex_F0<f0>_random_phase.wav`` into ``output_dir``, which is
    created (including parents) if it does not exist.

    Args:
        output_dir: Destination directory.
        f0: Fundamental frequency (Hz); truncated to an int in the file name.
        duration: Duration in seconds.
        sr: Sample rate, also used as the WAV sample rate.
        seed: Seed forwarded to ``generate_harmonic_complex``.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for phase_type in ("in_phase", "random_phase"):
        # Pass the label through unchanged. generate_harmonic_complex selects
        # zero phase via the label "in_phase"; stripping the "_phase" suffix
        # made the "in_phase" file come out with random phases as well.
        audio = generate_harmonic_complex(f0, duration, sr, phase_type, seed)
        sf.write(
            output_dir / f"harmonic_complex_F0{int(f0)}_{phase_type}.wav", audio, sr
        )
Loading