diff --git a/.gitignore b/.gitignore index f875d65..f561942 100755 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -.DS_Store \ No newline at end of file +.DS_Store +.idea/ diff --git a/pyproject.toml b/pyproject.toml index 91d9a8c..9243b08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dev = [ [project.optional-dependencies] testing = [ - "pytest>=8.3.2" + "pytest>=8.3.2", + "ipython>=8.29.0", ] [build-system] diff --git a/src/bayesmbar/bayesmbar.py b/src/bayesmbar/bayesmbar.py index 1036bc8..694d976 100755 --- a/src/bayesmbar/bayesmbar.py +++ b/src/bayesmbar/bayesmbar.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Literal +from typing import Literal, Union import numpy as np from numpy.typing import NDArray @@ -16,6 +16,7 @@ fmin_newton, _compute_log_likelihood_of_dF, _compute_log_likelihood_of_F, + _compute_log_likelihood_of_F_fast, _compute_loss_likelihood_of_dF, ) @@ -39,6 +40,9 @@ def __init__( verbose: bool = True, random_seed: int = 0, method: Literal["Newton", "L-BFGS-B"] = "Newton", + elbo_samples: Union[int, Literal["auto"]] = "auto", + early_stopping_patience: int = 100, + early_stopping_tol: float = 1e-4, ) -> None: """ Args: @@ -53,11 +57,16 @@ def __init__( optimize_steps (int, optional): Number of optimization steps used to learn the hyperparameters when normal priors are used. Defaults to 10000. verbose (bool, optional): Whether to print the progress bar for the optimization and sampling. Defaults to True. random_seed (int, optional): Random seed. Defaults to 0. + elbo_samples (int or "auto", optional): Number of samples used in ELBO estimation. + If "auto", automatically determined based on problem size. Defaults to "auto". + early_stopping_patience (int, optional): Number of steps without improvement before stopping optimization. Defaults to 100. + early_stopping_tol (float, optional): Minimum improvement required to reset patience counter. Defaults to 1e-4. """ self._energy = jnp.float64(energy) self._num_conf = jnp.int32(num_conf) + self._log_num_conf = jnp.log(self._num_conf) # Pre-compute for efficiency self._prior = prior self._mean_name = mean @@ -69,6 +78,8 @@ def __init__( self._sample_size = sample_size self._warmup_steps = warmup_steps self._optimize_steps = optimize_steps + self._early_stopping_patience = early_stopping_patience + self._early_stopping_tol = early_stopping_tol self._verbose = verbose self._method = method @@ -76,6 +87,18 @@ def __init__( self._m = self._energy.shape[0] self._n = self._energy.shape[1] + + # Validate and set elbo_samples + if elbo_samples == "auto": + self._elbo_samples = _compute_optimal_elbo_samples( + self._m, self._n, verbose + ) + elif isinstance(elbo_samples, int) and elbo_samples > 0: + self._elbo_samples = elbo_samples + else: + raise ValueError( + f"elbo_samples must be 'auto' or a positive integer, got {elbo_samples!r}" + ) # We first compute the mode estimate based on the likelihood # because it is used in both the uniform and normal priors. @@ -138,9 +161,11 @@ def logdensity(dF): _data = { "energy": self._energy, "num_conf": self._num_conf, + "log_num_conf": self._log_num_conf, "dF_mean_ll": self._dF_mean_ll, "dF_prec_ll": self._dF_prec_ll, "state_cv": self._state_cv, + "elbo_samples": self._elbo_samples, } ## mean function of the prior @@ -174,26 +199,47 @@ def logdensity(dF): ) raw_params = _params_to_raw(params) - ## optimize the hyperparameters + ## optimize the hyperparameters with early stopping self._rng_key, subkey = random.split(self._rng_key) optimizer = sgd(learning_rate=1e-3, momentum=0.9, nesterov=True) opt_state = optimizer.init(raw_params) - @partial(jit, static_argnames=["mean", "kernel"]) - def step(key, raw_params, opt_state, mean, kernel, data): - loss, grads = _compute_elbo_loss(key, raw_params, mean, kernel, data) + @partial(jit, static_argnames=["mean", "kernel", "elbo_samples"]) + def step(key, raw_params, opt_state, mean, kernel, data, elbo_samples): + loss, grads = _compute_elbo_loss(key, raw_params, mean, kernel, data, elbo_samples) update, opt_state = optimizer.update(grads, opt_state) raw_params = optax.apply_updates(raw_params, update) return loss, raw_params, opt_state + # Early stopping variables + best_loss = float('inf') + patience_counter = 0 + best_raw_params = raw_params + for i in range(optimize_steps): loss, raw_params, opt_state = step( - subkey, raw_params, opt_state, self.mean, self.kernel, _data + subkey, raw_params, opt_state, self.mean, self.kernel, _data, self._elbo_samples ) self._rng_key, subkey = random.split(self._rng_key) + + # Early stopping check + loss_val = float(loss) + if loss_val < best_loss - self._early_stopping_tol: + best_loss = loss_val + best_raw_params = raw_params + patience_counter = 0 + else: + patience_counter += 1 + if i % 100 == 0: params = _params_from_raw(raw_params) print(f"step: {i:>10d}, loss: {loss:10.4f}", _print_params(params)) + + if patience_counter >= self._early_stopping_patience: + if self._verbose: + print(f"Early stopping at step {i} (no improvement for {self._early_stopping_patience} steps)") + raw_params = best_raw_params + break self._params = _params_from_raw(raw_params) self._dF_mean_prior = self.mean(self._params["mean"], self._state_cv) @@ -334,6 +380,62 @@ def _dF_to_F(dF, num_conf): return F +def _compute_optimal_elbo_samples(m: int, n: int, verbose: bool = True) -> int: + """ + Automatically determine the optimal number of ELBO samples based on problem characteristics. + + The optimal number balances: + 1. Gradient quality: More samples = lower variance gradient estimates + 2. Computational cost: Each sample requires O(m * n) likelihood computation + 3. Dimensionality: dF has (m-1) dimensions, higher dimensions need more samples + + Heuristics: + - Base: Scale with sqrt(m-1) to account for dimensionality + - Adjust down for large n (more configurations = more expensive per sample) + - Clamp to [32, 256] for practical bounds + + Args: + m: Number of states + n: Total number of configurations + verbose: Whether to print the determined value + + Returns: + Optimal number of ELBO samples + """ + dim = m - 1 # Dimensionality of dF + + # Base scaling: ~4 samples per dimension, with sqrt scaling for high dimensions + # This ensures good coverage of the (m-1)-dimensional space + if dim <= 16: + base = 4 * dim # Linear for small problems + else: + base = int(16 * np.sqrt(dim)) # Sublinear for larger problems + + # Adjust for computational cost: reduce samples for very large datasets + # The idea: with more configurations, each likelihood evaluation is more expensive + # but also more informative (lower variance), so we can use fewer samples + if n > 100000: + cost_factor = 0.5 # Large dataset: halve the samples + elif n > 50000: + cost_factor = 0.7 + elif n > 10000: + cost_factor = 0.85 + else: + cost_factor = 1.0 # Small dataset: use full samples + + optimal = int(base * cost_factor) + + # Clamp to reasonable bounds + # Min 32: ensures reasonable gradient quality even with momentum + # Max 256: diminishing returns beyond this + optimal = max(32, min(256, optimal)) + + if verbose: + print(f"Auto-determined elbo_samples={optimal} (states={m}, configs={n})") + + return optimal + + def _compute_loss_joint_likelihood_of_dF(dF, energy, num_conf, mean_prior, prec_prior): """ Compute the loss function of dF based on the joint likelihood. @@ -374,9 +476,10 @@ def _compute_log_joint_likelihood_of_dF(dF, energy, num_conf, mean_prior, prec_p @partial(value_and_grad, argnums=1) -def _compute_elbo_loss(rng_key, raw_params, mean, kernel, data): +def _compute_elbo_loss(rng_key, raw_params, mean, kernel, data, elbo_samples): energy = data["energy"] num_conf = data["num_conf"] + log_num_conf = data.get("log_num_conf", jnp.log(num_conf)) # Use pre-computed if available state_cv = data["state_cv"] dF_prec_ll = data["dF_prec_ll"] dF_mean_ll = data["dF_mean_ll"] @@ -387,10 +490,10 @@ def _compute_elbo_loss(rng_key, raw_params, mean, kernel, data): mu_prop, cov_prop = _compute_proposal_dist( mean_prior, cov_prior, dF_mean_ll, dF_prec_ll ) - dFs = random.multivariate_normal(rng_key, mu_prop, cov_prop, shape=(1024,)) + dFs = random.multivariate_normal(rng_key, mu_prop, cov_prop, shape=(elbo_samples,)) Fs = jnp.concatenate([jnp.zeros((dFs.shape[0], 1)), dFs], axis=1) - elbo = jax.vmap(_compute_log_likelihood_of_F, in_axes=(0, None, None))( - Fs, energy, num_conf + elbo = jax.vmap(_compute_log_likelihood_of_F_fast, in_axes=(0, None, None, None))( + Fs, energy, num_conf, log_num_conf ) elbo = jnp.mean(elbo) elbo = elbo - _compute_kl_divergence(mu_prop, cov_prop, mean_prior, cov_prior) diff --git a/src/bayesmbar/utils.py b/src/bayesmbar/utils.py index e0dccc0..baa210d 100755 --- a/src/bayesmbar/utils.py +++ b/src/bayesmbar/utils.py @@ -47,6 +47,27 @@ def _compute_log_likelihood_of_F(F, energy, num_conf): return L +def _compute_log_likelihood_of_F_fast(F, energy, num_conf, log_num_conf): + """ + Compute the log likelihood of F with pre-computed log(num_conf). + + This is an optimized version that avoids recomputing log(num_conf) on each call. + + Arguments: + F (jnp.ndarray): Free energies of the states + energy (jnp.ndarray): Energy matrix + num_conf (jnp.ndarray): Number of configurations in each state + log_num_conf (jnp.ndarray): Pre-computed log(num_conf) + + Returns: + jnp.ndarray: Log likelihood of F + + """ + u = energy.T - F - log_num_conf + L = jnp.sum(num_conf * F) - logsumexp(-u, axis=1).sum() + return L + + def _compute_log_likelihood_of_dF(dF, energy, num_conf): """ Compute the log likelihood of dF. diff --git a/src/test/test_bayesmbar.py b/src/test/test_bayesmbar.py index a2b3d2f..4af6d89 100755 --- a/src/test/test_bayesmbar.py +++ b/src/test/test_bayesmbar.py @@ -1,4 +1,5 @@ import pytest +import numpy as np from pytest import approx from bayesmbar import BayesMBAR @@ -14,6 +15,75 @@ def test_BayesMBAR(setup_mbar_data, method): assert mbar.F_mode == approx(F_ref, abs = 1e-1) assert mbar.F_mean == approx(F_ref, abs = 1e-1) + +def test_BayesMBAR_uniform_prior_accuracy(setup_mbar_data): + """Fast test: verify uniform prior gives correct results. + + This is a quick sanity check that runs in <5s. + """ + energy, num_conf, F_ref, energy_p, F_ref_p = setup_mbar_data + + mbar = BayesMBAR( + energy, num_conf, + prior='uniform', + sample_size=20, warmup_steps=10, + random_seed=42, + verbose=False, + ) + + # Should recover reference free energies + assert mbar.F_mode == approx(F_ref, abs=0.2) + + +def test_auto_elbo_samples_scaling(): + """Test that auto elbo_samples scales appropriately with problem size.""" + from bayesmbar.bayesmbar import _compute_optimal_elbo_samples + + # Small problem: should use ~4*dim samples + samples_small = _compute_optimal_elbo_samples(m=5, n=500, verbose=False) + assert 16 <= samples_small <= 64 # 4*4=16 for dim=4 + + # Medium problem + samples_medium = _compute_optimal_elbo_samples(m=20, n=10000, verbose=False) + assert 50 <= samples_medium <= 150 + + # Large problem with many configs: should reduce samples + samples_large = _compute_optimal_elbo_samples(m=50, n=200000, verbose=False) + assert 32 <= samples_large <= 100 # Reduced due to large n + + # Verify bounds are respected + samples_min = _compute_optimal_elbo_samples(m=3, n=100, verbose=False) + assert samples_min >= 32 # Minimum bound + + samples_max = _compute_optimal_elbo_samples(m=200, n=1000, verbose=False) + assert samples_max <= 256 # Maximum bound + + +def test_elbo_samples_validation(setup_mbar_data): + """Test that invalid elbo_samples values raise clear errors.""" + energy, num_conf, F_ref, energy_p, F_ref_p = setup_mbar_data + + # Valid values should work + mbar = BayesMBAR(energy, num_conf, elbo_samples="auto", verbose=False) + assert mbar._elbo_samples > 0 + + mbar = BayesMBAR(energy, num_conf, elbo_samples=64, verbose=False) + assert mbar._elbo_samples == 64 + + # Invalid values should raise ValueError + with pytest.raises(ValueError, match="must be 'auto' or a positive integer"): + BayesMBAR(energy, num_conf, elbo_samples=0, verbose=False) + + with pytest.raises(ValueError, match="must be 'auto' or a positive integer"): + BayesMBAR(energy, num_conf, elbo_samples=-10, verbose=False) + + with pytest.raises(ValueError, match="must be 'auto' or a positive integer"): + BayesMBAR(energy, num_conf, elbo_samples=3.14, verbose=False) + + with pytest.raises(ValueError, match="must be 'auto' or a positive integer"): + BayesMBAR(energy, num_conf, elbo_samples="invalid", verbose=False) + + #results = fastmbar.calculate_free_energies_of_perturbed_states(energy_p) #results['F'] = results['F'] - results['F'].mean() #assert results['F'] == approx(F_ref_p, abs = 1e-1)