Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.DS_Store
.DS_Store
.idea/
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dev = [

[project.optional-dependencies]
testing = [
"pytest>=8.3.2"
"pytest>=8.3.2",
"ipython>=8.29.0",
]

[build-system]
Expand Down
123 changes: 113 additions & 10 deletions src/bayesmbar/bayesmbar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Literal
from typing import Literal, Union

import numpy as np
from numpy.typing import NDArray
Expand All @@ -16,6 +16,7 @@
fmin_newton,
_compute_log_likelihood_of_dF,
_compute_log_likelihood_of_F,
_compute_log_likelihood_of_F_fast,
_compute_loss_likelihood_of_dF,
)

Expand All @@ -39,6 +40,9 @@ def __init__(
verbose: bool = True,
random_seed: int = 0,
method: Literal["Newton", "L-BFGS-B"] = "Newton",
elbo_samples: Union[int, Literal["auto"]] = "auto",
early_stopping_patience: int = 100,
early_stopping_tol: float = 1e-4,
) -> None:
"""
Args:
Expand All @@ -53,11 +57,16 @@ def __init__(
optimize_steps (int, optional): Number of optimization steps used to learn the hyperparameters when normal priors are used. Defaults to 10000.
verbose (bool, optional): Whether to print the progress bar for the optimization and sampling. Defaults to True.
random_seed (int, optional): Random seed. Defaults to 0.
elbo_samples (int or "auto", optional): Number of samples used in ELBO estimation.
If "auto", automatically determined based on problem size. Defaults to "auto".
early_stopping_patience (int, optional): Number of steps without improvement before stopping optimization. Defaults to 100.
early_stopping_tol (float, optional): Minimum improvement required to reset patience counter. Defaults to 1e-4.

"""

self._energy = jnp.float64(energy)
self._num_conf = jnp.int32(num_conf)
self._log_num_conf = jnp.log(self._num_conf) # Pre-compute for efficiency

self._prior = prior
self._mean_name = mean
Expand All @@ -69,13 +78,27 @@ def __init__(
self._sample_size = sample_size
self._warmup_steps = warmup_steps
self._optimize_steps = optimize_steps
self._early_stopping_patience = early_stopping_patience
self._early_stopping_tol = early_stopping_tol

self._verbose = verbose
self._method = method
self._rng_key = jax.random.PRNGKey(random_seed)

self._m = self._energy.shape[0]
self._n = self._energy.shape[1]

# Validate and set elbo_samples
if elbo_samples == "auto":
self._elbo_samples = _compute_optimal_elbo_samples(
self._m, self._n, verbose
)
elif isinstance(elbo_samples, int) and elbo_samples > 0:
self._elbo_samples = elbo_samples
else:
raise ValueError(
f"elbo_samples must be 'auto' or a positive integer, got {elbo_samples!r}"
)

# We first compute the mode estimate based on the likelihood
# because it is used in both the uniform and normal priors.
Expand Down Expand Up @@ -138,9 +161,11 @@ def logdensity(dF):
_data = {
"energy": self._energy,
"num_conf": self._num_conf,
"log_num_conf": self._log_num_conf,
"dF_mean_ll": self._dF_mean_ll,
"dF_prec_ll": self._dF_prec_ll,
"state_cv": self._state_cv,
"elbo_samples": self._elbo_samples,
}

## mean function of the prior
Expand Down Expand Up @@ -174,26 +199,47 @@ def logdensity(dF):
)
raw_params = _params_to_raw(params)

## optimize the hyperparameters
## optimize the hyperparameters with early stopping
self._rng_key, subkey = random.split(self._rng_key)
optimizer = sgd(learning_rate=1e-3, momentum=0.9, nesterov=True)
opt_state = optimizer.init(raw_params)

@partial(jit, static_argnames=["mean", "kernel"])
def step(key, raw_params, opt_state, mean, kernel, data):
loss, grads = _compute_elbo_loss(key, raw_params, mean, kernel, data)
@partial(jit, static_argnames=["mean", "kernel", "elbo_samples"])
def step(key, raw_params, opt_state, mean, kernel, data, elbo_samples):
loss, grads = _compute_elbo_loss(key, raw_params, mean, kernel, data, elbo_samples)
update, opt_state = optimizer.update(grads, opt_state)
raw_params = optax.apply_updates(raw_params, update)
return loss, raw_params, opt_state

# Early stopping variables
best_loss = float('inf')
patience_counter = 0
best_raw_params = raw_params

for i in range(optimize_steps):
loss, raw_params, opt_state = step(
subkey, raw_params, opt_state, self.mean, self.kernel, _data
subkey, raw_params, opt_state, self.mean, self.kernel, _data, self._elbo_samples
)
self._rng_key, subkey = random.split(self._rng_key)

# Early stopping check
loss_val = float(loss)
if loss_val < best_loss - self._early_stopping_tol:
best_loss = loss_val
best_raw_params = raw_params
patience_counter = 0
else:
patience_counter += 1

if i % 100 == 0:
params = _params_from_raw(raw_params)
print(f"step: {i:>10d}, loss: {loss:10.4f}", _print_params(params))

if patience_counter >= self._early_stopping_patience:
if self._verbose:
print(f"Early stopping at step {i} (no improvement for {self._early_stopping_patience} steps)")
raw_params = best_raw_params
break

self._params = _params_from_raw(raw_params)
self._dF_mean_prior = self.mean(self._params["mean"], self._state_cv)
Expand Down Expand Up @@ -334,6 +380,62 @@ def _dF_to_F(dF, num_conf):
return F


def _compute_optimal_elbo_samples(m: int, n: int, verbose: bool = True) -> int:
    """
    Choose the number of ELBO samples from the size of the problem.

    The value is built in three stages:

    1. A base count derived from ``m - 1`` (treated here as the
       dimensionality of dF): ``4 * (m - 1)`` while ``m - 1 <= 16``,
       otherwise ``int(16 * sqrt(m - 1))``. Both branches give 64 at
       ``m - 1 == 16``, so the base has no jump at the switch-over.
    2. A cost factor that shrinks the base as ``n`` grows:
       0.5 for ``n > 100000``, 0.7 for ``n > 50000``, 0.85 for
       ``n > 10000`` and 1.0 otherwise. The product is truncated to int.
    3. A clamp of the result to the closed range [32, 256].

    Args:
        m: Number of states
        n: Total number of configurations
        verbose: Whether to print the determined value

    Returns:
        The chosen number of ELBO samples, always between 32 and 256.
    """
    n_dims = m - 1

    # Linear growth for small problems, square-root growth afterwards.
    # The conditional expression keeps np.sqrt from ever seeing n_dims <= 16
    # (and therefore from ever seeing a non-positive value).
    base = 4 * n_dims if n_dims <= 16 else int(16 * np.sqrt(n_dims))

    # Thresholds are listed from largest to smallest; the first one that
    # n exceeds decides the factor, anything at or below 10000 keeps 1.0.
    cost_factor = 1.0
    for threshold, factor in ((100000, 0.5), (50000, 0.7), (10000, 0.85)):
        if n > threshold:
            cost_factor = factor
            break

    # Truncate towards zero, then force the result into [32, 256].
    optimal = int(base * cost_factor)
    optimal = min(max(optimal, 32), 256)

    if verbose:
        print(f"Auto-determined elbo_samples={optimal} (states={m}, configs={n})")

    return optimal


def _compute_loss_joint_likelihood_of_dF(dF, energy, num_conf, mean_prior, prec_prior):
"""
Compute the loss function of dF based on the joint likelihood.
Expand Down Expand Up @@ -374,9 +476,10 @@ def _compute_log_joint_likelihood_of_dF(dF, energy, num_conf, mean_prior, prec_p


@partial(value_and_grad, argnums=1)
def _compute_elbo_loss(rng_key, raw_params, mean, kernel, data):
def _compute_elbo_loss(rng_key, raw_params, mean, kernel, data, elbo_samples):
energy = data["energy"]
num_conf = data["num_conf"]
log_num_conf = data.get("log_num_conf", jnp.log(num_conf)) # Use pre-computed if available
state_cv = data["state_cv"]
dF_prec_ll = data["dF_prec_ll"]
dF_mean_ll = data["dF_mean_ll"]
Expand All @@ -387,10 +490,10 @@ def _compute_elbo_loss(rng_key, raw_params, mean, kernel, data):
mu_prop, cov_prop = _compute_proposal_dist(
mean_prior, cov_prior, dF_mean_ll, dF_prec_ll
)
dFs = random.multivariate_normal(rng_key, mu_prop, cov_prop, shape=(1024,))
dFs = random.multivariate_normal(rng_key, mu_prop, cov_prop, shape=(elbo_samples,))
Fs = jnp.concatenate([jnp.zeros((dFs.shape[0], 1)), dFs], axis=1)
elbo = jax.vmap(_compute_log_likelihood_of_F, in_axes=(0, None, None))(
Fs, energy, num_conf
elbo = jax.vmap(_compute_log_likelihood_of_F_fast, in_axes=(0, None, None, None))(
Fs, energy, num_conf, log_num_conf
)
elbo = jnp.mean(elbo)
elbo = elbo - _compute_kl_divergence(mu_prop, cov_prop, mean_prior, cov_prior)
Expand Down
21 changes: 21 additions & 0 deletions src/bayesmbar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,27 @@ def _compute_log_likelihood_of_F(F, energy, num_conf):
return L


def _compute_log_likelihood_of_F_fast(F, energy, num_conf, log_num_conf):
    """
    Evaluate the log likelihood of F using a caller-supplied log(num_conf).

    The caller passes ``log_num_conf`` in, so this function never takes the
    logarithm of ``num_conf`` itself.

    Arguments:
        F (jnp.ndarray): Free energies of the states
        energy (jnp.ndarray): Energy matrix. Its transpose is broadcast
            against ``F``, so its first axis has to line up with ``F``.
        num_conf (jnp.ndarray): Number of configurations in each state
        log_num_conf (jnp.ndarray): Pre-computed log(num_conf)

    Returns:
        jnp.ndarray: Log likelihood of F

    """
    # Same left-to-right subtraction order as the reference expression so
    # the floating-point result is unchanged.
    shifted_energy = energy.T - F - log_num_conf

    # One log-normaliser per row of the transposed energy matrix.
    log_normalisers = logsumexp(-shifted_energy, axis=1)

    weighted_free_energy = jnp.sum(num_conf * F)
    return weighted_free_energy - log_normalisers.sum()


def _compute_log_likelihood_of_dF(dF, energy, num_conf):
"""
Compute the log likelihood of dF.
Expand Down
70 changes: 70 additions & 0 deletions src/test/test_bayesmbar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import numpy as np
from pytest import approx
from bayesmbar import BayesMBAR

Expand All @@ -14,6 +15,75 @@ def test_BayesMBAR(setup_mbar_data, method):
assert mbar.F_mode == approx(F_ref, abs = 1e-1)
assert mbar.F_mean == approx(F_ref, abs = 1e-1)


def test_BayesMBAR_uniform_prior_accuracy(setup_mbar_data):
    """Sanity check: the uniform prior reproduces the reference free energies.

    The run is kept deliberately short (20 samples after 10 warmup steps)
    so that only the mode estimate is checked, with a loose tolerance.
    """
    energy, num_conf, F_ref, _energy_p, _F_ref_p = setup_mbar_data

    # Short, seeded, silent run with the uniform prior.
    options = dict(
        prior="uniform",
        sample_size=20,
        warmup_steps=10,
        random_seed=42,
        verbose=False,
    )
    mbar = BayesMBAR(energy, num_conf, **options)

    # The mode estimate has to agree with the reference to within 0.2.
    assert mbar.F_mode == approx(F_ref, abs=0.2)


def test_auto_elbo_samples_scaling():
    """Test that auto elbo_samples scales appropriately with problem size.

    Covers the three stages of the heuristic: the dimension-based base
    count, the reduction applied for large numbers of configurations, and
    the final clamp to [32, 256]. Both clamp edges are exercised with
    inputs whose unclamped value actually falls outside the range.
    """
    from bayesmbar.bayesmbar import _compute_optimal_elbo_samples

    # Small problem: dim=4 gives a raw value of 4*4=16, which is below the
    # floor, so the floor of 32 is what must come back.
    samples_small = _compute_optimal_elbo_samples(m=5, n=500, verbose=False)
    assert samples_small == 32

    # Medium problem
    samples_medium = _compute_optimal_elbo_samples(m=20, n=10000, verbose=False)
    assert 50 <= samples_medium <= 150

    # Large problem with many configs: should reduce samples
    samples_large = _compute_optimal_elbo_samples(m=50, n=200000, verbose=False)
    assert 32 <= samples_large <= 100  # Reduced due to large n

    # The range check alone does not prove a reduction happened; compare
    # against the same number of states with few configurations.
    samples_large_few_configs = _compute_optimal_elbo_samples(
        m=50, n=1000, verbose=False
    )
    assert samples_large < samples_large_few_configs

    # Verify bounds are respected
    samples_min = _compute_optimal_elbo_samples(m=3, n=100, verbose=False)
    assert samples_min == 32  # Minimum bound (raw value is 8)

    samples_max = _compute_optimal_elbo_samples(m=200, n=1000, verbose=False)
    assert samples_max <= 256  # Maximum bound

    # m=200 stays below the cap on its own, so also use a state count whose
    # raw value exceeds 256 to make sure the cap is really applied.
    samples_capped = _compute_optimal_elbo_samples(m=1000, n=1000, verbose=False)
    assert samples_capped == 256


def test_elbo_samples_validation(setup_mbar_data):
    """Check which elbo_samples values the constructor accepts and rejects.

    "auto" and a positive integer must be accepted; zero, a negative
    integer, a float and an unrecognised string must each raise a
    ValueError carrying the documented message.
    """
    energy, num_conf, _F_ref, _energy_p, _F_ref_p = setup_mbar_data

    # Accepted: automatic selection yields some positive count.
    auto_model = BayesMBAR(energy, num_conf, elbo_samples="auto", verbose=False)
    assert auto_model._elbo_samples > 0

    # Accepted: an explicit positive integer is stored unchanged.
    explicit_model = BayesMBAR(energy, num_conf, elbo_samples=64, verbose=False)
    assert explicit_model._elbo_samples == 64

    # Rejected: every one of these must fail with the same message.
    expected_message = "must be 'auto' or a positive integer"
    for bad_value in (0, -10, 3.14, "invalid"):
        with pytest.raises(ValueError, match=expected_message):
            BayesMBAR(energy, num_conf, elbo_samples=bad_value, verbose=False)


#results = fastmbar.calculate_free_energies_of_perturbed_states(energy_p)
#results['F'] = results['F'] - results['F'].mean()
#assert results['F'] == approx(F_ref_p, abs = 1e-1)
Expand Down