Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.DS_Store
.DS_Store
.idea/
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dev = [

[project.optional-dependencies]
testing = [
"pytest>=8.3.2"
"pytest>=8.3.2",
"ipython>=8.29.0",
]

[build-system]
Expand Down
74 changes: 59 additions & 15 deletions src/bayesmbar/bayesmbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ def __init__(
mean: Literal["constant", "linear", "quadratic"] = "constant",
kernel: Literal["SE", "Matern52", "Matern32", "RQ"] = "SE",
state_cv: np.ndarray = None,
sample_size: int = 1000,
warmup_steps: int = 500,
optimize_steps: int = 10000,
sample_size: int = None,
warmup_steps: int = None,
max_optimize_steps: int = 10000,
optimize_patience: int = 500,
optimize_min_delta: float = 1e-4,
verbose: bool = True,
random_seed: int = 0,
method: Literal["Newton", "L-BFGS-B"] = "Newton",
Expand All @@ -48,9 +50,11 @@ def __init__(
mean (str, optional): Mean function of the prior. It can be either "constant", "linear", or "quadratic". Defaults to "constant".
kernel (str, optional): Kernel function of the prior. It can be either "SE", "Matern52", "Matern32", or "RQ". Defaults to "SE".
state_cv (np.ndarray, optional): State collective variables. It is a 2D array of shape (m, d), where m is the number of states and d is the dimension of the collective variables. Defaults to None.
sample_size (int, optional): Number of samples drawn from the posterior distribution. Defaults to 1000.
warmup_steps (int, optional): Number of warmup steps used to find the step size and mass matrix of the NUTS sampler. Defaults to 500.
optimize_steps (int, optional): Number of optimization steps used to learn the hyperparameters when normal priors are used. Defaults to 10000.
sample_size (int, optional): Number of samples drawn from the posterior distribution. If None, defaults to max(1000, 100 * (m - 1)) where m is the number of states.
warmup_steps (int, optional): Number of warmup steps used to find the step size and mass matrix of the NUTS sampler. If None, defaults to max(500, sample_size // 2).
max_optimize_steps (int, optional): Maximum number of optimization steps used to learn the hyperparameters when normal priors are used. Defaults to 10000.
optimize_patience (int, optional): Number of steps without improvement before early stopping. Defaults to 500.
optimize_min_delta (float, optional): Minimum change in loss to qualify as an improvement. Defaults to 1e-4.
verbose (bool, optional): Whether to print the progress bar for the optimization and sampling. Defaults to True.
random_seed (int, optional): Random seed. Defaults to 0.

Expand All @@ -66,17 +70,29 @@ def __init__(
if state_cv is not None:
self._state_cv = state_cv[1:]

self._sample_size = sample_size
self._warmup_steps = warmup_steps
self._optimize_steps = optimize_steps
self._m = self._energy.shape[0]
self._n = self._energy.shape[1]

# Adaptive sample_size: scale with dimension (m-1)
if sample_size is None:
self._sample_size = max(1000, 100 * (self._m - 1))
else:
self._sample_size = sample_size

# Adaptive warmup_steps: scale with sample_size
if warmup_steps is None:
self._warmup_steps = max(500, self._sample_size // 2)
else:
self._warmup_steps = warmup_steps

self._max_optimize_steps = max_optimize_steps
self._optimize_patience = optimize_patience
self._optimize_min_delta = optimize_min_delta

self._verbose = verbose
self._method = method
self._rng_key = jax.random.PRNGKey(random_seed)

self._m = self._energy.shape[0]
self._n = self._energy.shape[1]

# We first compute the mode estimate based on the likelihood
# because it is used in both the uniform and normal priors.
# The mode estimate based on the likelihood is the solution to the MBAR equation.
Expand Down Expand Up @@ -174,7 +190,7 @@ def logdensity(dF):
)
raw_params = _params_to_raw(params)

## optimize the hyperparameters
## optimize the hyperparameters with early stopping
self._rng_key, subkey = random.split(self._rng_key)
optimizer = sgd(learning_rate=1e-3, momentum=0.9, nesterov=True)
opt_state = optimizer.init(raw_params)
Expand All @@ -186,15 +202,43 @@ def step(key, raw_params, opt_state, mean, kernel, data):
raw_params = optax.apply_updates(raw_params, update)
return loss, raw_params, opt_state

for i in range(optimize_steps):
# Early stopping state
best_loss = float("inf")
best_raw_params = raw_params
steps_without_improvement = 0

for i in range(self._max_optimize_steps):
loss, raw_params, opt_state = step(
subkey, raw_params, opt_state, self.mean, self.kernel, _data
)
self._rng_key, subkey = random.split(self._rng_key)
if i % 100 == 0:

# Check for improvement
loss_val = float(loss)
if loss_val < best_loss - self._optimize_min_delta:
best_loss = loss_val
best_raw_params = raw_params
steps_without_improvement = 0
else:
steps_without_improvement += 1

if self._verbose and i % 100 == 0:
params = _params_from_raw(raw_params)
print(f"step: {i:>10d}, loss: {loss:10.4f}", _print_params(params))

# Early stopping check
if steps_without_improvement >= self._optimize_patience:
if self._verbose:
print(
f"Early stopping at step {i}: no improvement for {self._optimize_patience} steps"
)
break

# Use best parameters found
raw_params = best_raw_params
if self._verbose:
print(f"Optimization finished. Best loss: {best_loss:.4f}")

self._params = _params_from_raw(raw_params)
self._dF_mean_prior = self.mean(self._params["mean"], self._state_cv)
self._dF_cov_prior = self.kernel(
Expand Down
124 changes: 118 additions & 6 deletions src/test/test_bayesmbar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
from pytest import approx
import numpy as np
from bayesmbar import BayesMBAR


@pytest.mark.parametrize("method", ["Newton", "L-BFGS-B"])
def test_BayesMBAR(setup_mbar_data, method):
energy, num_conf, F_ref, energy_p, F_ref_p = setup_mbar_data
Expand All @@ -11,12 +13,122 @@ def test_BayesMBAR(setup_mbar_data, method):
verbose=True,
method=method,
)
assert mbar.F_mode == approx(F_ref, abs = 1e-1)
assert mbar.F_mean == approx(F_ref, abs = 1e-1)

#results = fastmbar.calculate_free_energies_of_perturbed_states(energy_p)
#results['F'] = results['F'] - results['F'].mean()
#assert results['F'] == approx(F_ref_p, abs = 1e-1)
assert mbar.F_mode == approx(F_ref, abs=1e-1)
assert mbar.F_mean == approx(F_ref, abs=1e-1)

# results = fastmbar.calculate_free_energies_of_perturbed_states(energy_p)
# results['F'] = results['F'] - results['F'].mean()
# assert results['F'] == approx(F_ref_p, abs = 1e-1)


class TestAdaptiveParameters:
    """Tests for adaptive sample_size, warmup_steps, and early stopping parameters."""

    @staticmethod
    def _synthetic_problem(m=15, confs_per_state=10, seed=42):
        """Build a random energy matrix together with a matching ``num_conf``.

        Args:
            m: Number of states (rows of the energy matrix).
            confs_per_state: Number of configurations attributed to each state.
            seed: Seed for the local random generator.

        Returns:
            A tuple ``(energy, num_conf)`` where ``energy`` has shape
            ``(m, m * confs_per_state)`` and ``num_conf`` has shape ``(m,)``.
        """
        # A local Generator keeps the test deterministic without reseeding
        # NumPy's global RNG, which would leak state into other tests.
        rng = np.random.default_rng(seed)
        num_conf = np.full(m, confs_per_state)
        # The number of columns is derived from num_conf so that the
        # per-state counts always add up to the number of configurations.
        energy = rng.standard_normal((m, int(num_conf.sum())))
        return energy, num_conf

    def test_adaptive_sample_size_small_problem(self, setup_mbar_data):
        """Test that sample_size defaults to 1000 for small problems (m-1 < 10)."""
        energy, num_conf, F_ref, _, _ = setup_mbar_data
        # NOTE(review): the fixture presumably provides 5 states; the literal
        # check against 1000 below holds for any m <= 11.
        m = energy.shape[0]
        mbar = BayesMBAR(energy, num_conf, verbose=False)

        # For 5 states: max(1000, 100 * 4) = 1000
        expected_sample_size = max(1000, 100 * (m - 1))
        assert mbar._sample_size == expected_sample_size
        assert mbar._sample_size == 1000

    def test_adaptive_sample_size_large_problem(self):
        """Test that sample_size scales with dimension for larger problems."""
        # Create a larger problem with 15 states
        m = 15
        energy, num_conf = self._synthetic_problem(m=m)

        mbar = BayesMBAR(energy, num_conf, verbose=False)

        # For 15 states: max(1000, 100 * 14) = 1400
        expected_sample_size = max(1000, 100 * (m - 1))
        assert mbar._sample_size == expected_sample_size
        assert mbar._sample_size == 1400

    def test_adaptive_warmup_steps_default(self, setup_mbar_data):
        """Test that warmup_steps defaults to max(500, sample_size // 2)."""
        energy, num_conf, _, _, _ = setup_mbar_data
        mbar = BayesMBAR(energy, num_conf, verbose=False)

        # sample_size=1000, so warmup = max(500, 500) = 500
        expected_warmup = max(500, mbar._sample_size // 2)
        assert mbar._warmup_steps == expected_warmup

    def test_adaptive_warmup_steps_scales_with_sample_size(self):
        """Test that warmup_steps scales when sample_size is large."""
        energy, num_conf = self._synthetic_problem(m=15)

        mbar = BayesMBAR(energy, num_conf, verbose=False)

        # sample_size=1400, so warmup = max(500, 700) = 700
        expected_warmup = max(500, mbar._sample_size // 2)
        assert mbar._warmup_steps == expected_warmup
        assert mbar._warmup_steps == 700

    def test_explicit_sample_size_overrides_adaptive(self, setup_mbar_data):
        """Test that explicit sample_size overrides the adaptive default."""
        energy, num_conf, _, _, _ = setup_mbar_data
        explicit_sample_size = 2000

        mbar = BayesMBAR(
            energy,
            num_conf,
            sample_size=explicit_sample_size,
            verbose=False,
        )

        assert mbar._sample_size == explicit_sample_size

    def test_explicit_warmup_steps_overrides_adaptive(self, setup_mbar_data):
        """Test that explicit warmup_steps overrides the adaptive default."""
        energy, num_conf, _, _, _ = setup_mbar_data
        explicit_warmup = 300

        mbar = BayesMBAR(
            energy,
            num_conf,
            warmup_steps=explicit_warmup,
            verbose=False,
        )

        assert mbar._warmup_steps == explicit_warmup

    def test_early_stopping_parameters_defaults(self, setup_mbar_data):
        """Test that early stopping parameters have correct defaults."""
        energy, num_conf, _, _, _ = setup_mbar_data
        mbar = BayesMBAR(energy, num_conf, verbose=False)

        assert mbar._max_optimize_steps == 10000
        assert mbar._optimize_patience == 500
        assert mbar._optimize_min_delta == 1e-4

    def test_early_stopping_parameters_custom(self, setup_mbar_data):
        """Test that custom early stopping parameters are set correctly."""
        energy, num_conf, _, _, _ = setup_mbar_data
        mbar = BayesMBAR(
            energy,
            num_conf,
            max_optimize_steps=5000,
            optimize_patience=200,
            optimize_min_delta=1e-5,
            verbose=False,
        )

        assert mbar._max_optimize_steps == 5000
        assert mbar._optimize_patience == 200
        assert mbar._optimize_min_delta == 1e-5

    def test_combined_explicit_and_adaptive(self, setup_mbar_data):
        """Test mixing explicit and adaptive parameters."""
        energy, num_conf, _, _, _ = setup_mbar_data

        # Set explicit sample_size, let warmup_steps be adaptive
        mbar = BayesMBAR(energy, num_conf, sample_size=1500, verbose=False)

        assert mbar._sample_size == 1500
        # warmup should adapt to the explicit sample_size
        assert mbar._warmup_steps == max(500, 1500 // 2)
        assert mbar._warmup_steps == 750

# import pytest
# from pytest import approx
Expand Down