Skip to content

Batch uncached ESM embeddings#42

Merged
Vedasheersh merged 1 commit into
maranasgroup:mainfrom
Vedasheersh:codex/inference-speedup
Jun 14, 2026
Merged

Batch uncached ESM embeddings#42
Vedasheersh merged 1 commit into
maranasgroup:mainfrom
Vedasheersh:codex/inference-speedup

Conversation

@Vedasheersh

Copy link
Copy Markdown
Contributor

Summary

  • Batch uncached unique ESM sequences before dataset construction instead of materializing one row at a time.
  • Reuse the existing sequence-content cache before ESM inference, and save generated cache entries as CPU tensors for CPU/GPU portability.
  • Default conservatively to CATPRED_ESM_BATCH_SIZE=1 on CPU to preserve exact prior outputs and avoid padding overhead; default to 4 on GPU for throughput, with recursive OOM fallback.
  • Add focused tests for data-loader batching, fallback deduplication, and cache-aware ESM batching.

Verification

  • Fresh local CPU ESM-cache run for 2-row kcat, batch size 1: 7.733 s, matching prior subprocess output exactly up to existing SD_aleatoric noise (1.11e-16).
  • Fresh local CPU ESM-cache run for 2-row kcat, batch size 4: 10.564 s; slower for short CPU rows, which is why CPU defaults to 1.
  • Fresh Modal GPU smoke for 2-row kcat: Tesla T4, cuda_available=true, row_count=2, approximately 16 s, and wrote 2 ESM cache entries.

Tests

  • PYTHONDONTWRITEBYTECODE=1 python3 -m unittest tests/test_inference_fast_path.py tests/test_postprocess_predictions.py tests/test_esm_batching.py -v
  • PYTHONDONTWRITEBYTECODE=1 conda run -n esm python -m unittest tests/test_postprocess_predictions.py tests/test_esm_batching.py -v
  • PYTHONDONTWRITEBYTECODE=1 python3 -m py_compile catpred/args.py catpred/data/cache_utils.py catpred/data/esm_utils.py catpred/data/utils.py catpred/inference/__init__.py catpred/inference/backends.py catpred/inference/service.py catpred/train/make_predictions.py catpred/train/predict.py catpred/uncertainty/uncertainty_estimator.py catpred/uncertainty/uncertainty_predictor.py modal_app.py scripts/benchmark_inference.py tests/test_inference_fast_path.py tests/test_postprocess_predictions.py tests/test_esm_batching.py
  • PYTHONDONTWRITEBYTECODE=1 conda run -n esm python scripts/benchmark_inference.py --help

@Vedasheersh Vedasheersh merged commit 2e70135 into maranasgroup:main Jun 14, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant