Skip to content

Speed up warm CatPred inference path#41

Merged
Vedasheersh merged 1 commit into
maranasgroup:mainfrom
Vedasheersh:codex/inference-speedup
Jun 14, 2026
Merged

Speed up warm CatPred inference path#41
Vedasheersh merged 1 commit into
maranasgroup:mainfrom
Vedasheersh:codex/inference-speedup

Conversation

@Vedasheersh

@Vedasheersh Vedasheersh commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Summary

  • Add a warm in-process inference path that reuses loaded checkpoint models/scalers instead of spawning a fresh Python process for each request.
  • Preserve exact MVE uncertainty component outputs without requiring per-model ensemble prediction columns.
  • Route local and Modal inference through the fast path by default, while keeping subprocess compatibility via CATPRED_LOCAL_SUBPROCESS=1.
  • Add a benchmark CLI and focused tests for routing, model-cache reuse, uncertainty postprocessing, and ESM batching/cache behavior.
  • Add Modal GPU smoke support and persist ESM/Torch caches on the Modal checkpoint volume; ESM cache entries are saved as CPU tensors so CPU and GPU workers can both reuse them.
  • Batch uncached unique ESM sequences before dataset construction. Defaults are conservative on CPU (CATPRED_ESM_BATCH_SIZE=1) and batched on GPU (4), with recursive OOM fallback and an env override.

Benchmarks

Local macOS CPU, conda env esm, 2-row demo inputs, production checkpoints, 3 measured runs with 1 in-process warmup:

Parameter Legacy subprocess median Warm in-process median Median speedup Latency reduction
kcat 2.6674 s 0.2906 s 9.18x 89.10%
km 2.9151 s 0.3042 s 9.58x 89.57%
ki 2.5334 s 0.0614 s 41.24x 97.57%

Output equivalence vs legacy subprocess was verified for predictions and uncertainty fields:

  • kcat: max abs diff 1.11e-16
  • km: max abs diff 3.89e-16
  • ki: max abs diff 2.22e-16

Additional ESM-cache checks:

  • Fresh local CPU ESM-cache run for 2-row kcat, batch size 1: 7.733 s, matching prior subprocess output exactly up to existing SD_aleatoric noise (1.11e-16).
  • Fresh local CPU ESM-cache run for 2-row kcat, batch size 4: 10.564 s; this was slower for short CPU rows, so CPU defaults to 1.
  • Fresh Modal GPU smoke for 2-row kcat: Tesla T4, cuda_available=true, row_count=2, approximately 16 s, and wrote 2 ESM cache entries.

Modal GPU smoke verification:

  • kcat, km, and ki each completed 1-row inference on Tesla T4 with cuda_available=true.
  • Returned exact uncertainty component columns such as *_mve_uncal_aleatoric_var and *_mve_uncal_epistemic_var.

Tests

  • PYTHONDONTWRITEBYTECODE=1 python3 -m unittest tests/test_inference_fast_path.py tests/test_postprocess_predictions.py tests/test_esm_batching.py -v
  • PYTHONDONTWRITEBYTECODE=1 conda run -n esm python -m unittest tests/test_postprocess_predictions.py tests/test_esm_batching.py -v
  • PYTHONDONTWRITEBYTECODE=1 python3 -m py_compile catpred/args.py catpred/data/cache_utils.py catpred/data/esm_utils.py catpred/data/utils.py catpred/inference/__init__.py catpred/inference/backends.py catpred/inference/service.py catpred/train/make_predictions.py catpred/train/predict.py catpred/uncertainty/uncertainty_estimator.py catpred/uncertainty/uncertainty_predictor.py modal_app.py scripts/benchmark_inference.py tests/test_inference_fast_path.py tests/test_postprocess_predictions.py tests/test_esm_batching.py
  • PYTHONDONTWRITEBYTECODE=1 conda run -n esm python scripts/benchmark_inference.py --help

Notes

This now covers the first deeper ESM pass: persistent caches, CPU/GPU-shareable cache values, unique-sequence materialization, GPU batching, and OOM fallback. A larger future step would be a dedicated long-lived ESM service/worker for high-throughput API deployments.

@Vedasheersh Vedasheersh merged commit e5b9b5c into maranasgroup:main Jun 14, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant