A post-hoc cascade adapter that extracts explicit low-rank spatial representations from the residuals of any frozen first-stage predictor, yielding closed-form covariance estimation and uncertainty quantification.
Requires Python 3.10+, Conda, and (optionally) CUDA 12.8+ for GPU.
git clone https://github.com/STLABTW/spatial-adapter.git
cd spatial-adapter
make conda-env # recommended: conda env + C++ extensions
# or
pip install -e ".[all]" # pip-only

For Blackwell / sm_120 GPUs (e.g. RTX 5070 Ti):
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128 --force-reinstall

SpatialAdapter.fit_tuned hides the (τ₁, τ₂) grid search behind a single call: it runs a short Optuna sweep, refits at the best weights, and returns the trained adapter together with the chosen τ values and the trials DataFrame.
import torch
from spatial_adapter import SpatialAdapter, TrendModel
trend = TrendModel(num_continuous_features=5, hidden_layer_sizes=[], n_locations=100)
result = SpatialAdapter.fit_tuned(
trend, train_loader,
val_cont=val_X, val_y=val_Y, locs=station_coords,
device=torch.device("cpu"),
latent_dim=10,
n_trials=10, # per-seed Optuna budget (defaults to 10)
seed=42,
criterion="auto", # see table below
)
result.adapter.reconstruct(val_X, val_Y) # trained; ready to use
result.tau1, result.tau2 # best weights chosen by Optuna
result.trials # pandas DataFrame of all trials

The validation criterion is what Optuna minimizes or maximizes over (τ₁, τ₂):

| criterion | Task | Direction | What it measures |
|---|---|---|---|
| "auto" (default) | any | n/a | regression → rmse, binary → accuracy |
| "rmse" | regression | min | point-prediction RMSE on val |
| "accuracy" | binary | max | classification accuracy on val |
| "auc" | binary | max | ROC AUC on val |
| "cov_frob" | regression | min | relative Frobenius covariance error, ‖Σ̂_pred − Σ̂_obs‖_F / ‖Σ̂_obs‖_F |
| "sv_score" | regression | min | weighted L² distance between empirical semivariograms of val and prediction (gstools) |
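To make the "cov_frob" row concrete, here is a minimal NumPy sketch of a relative Frobenius covariance error. It is an illustration only, not the library's implementation: the function name `relative_frobenius_cov_error` is hypothetical, and it assumes that Σ̂_pred and Σ̂_obs are sample covariances across the rows of arrays shaped (n_samples, n_locations), which may differ from how the package forms them.

```python
import numpy as np


def relative_frobenius_cov_error(pred: np.ndarray, obs: np.ndarray) -> float:
    """Relative Frobenius error between two sample covariance matrices.

    Both inputs are assumed to be shaped (n_samples, n_locations);
    this layout is an assumption of the sketch, not a documented contract.
    """
    cov_pred = np.cov(pred, rowvar=False)  # columns are locations
    cov_obs = np.cov(obs, rowvar=False)
    num = np.linalg.norm(cov_pred - cov_obs, "fro")
    den = np.linalg.norm(cov_obs, "fro")
    return float(num / den)


rng = np.random.default_rng(0)
obs = rng.normal(size=(200, 4))

# Identical fields have identical covariances, so the error is zero.
print(relative_frobenius_cov_error(obs, obs))

# Doubling the field multiplies its covariance by 4, so the
# relative error is |4 - 1| = 3 (up to floating-point rounding).
print(relative_frobenius_cov_error(2.0 * obs, obs))
```

Because the error is normalized by ‖Σ̂_obs‖_F, it is comparable across datasets with different variance scales, which is presumably why a relative rather than absolute norm is used as a selection criterion.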
For the paper's KAUST Table 4 ablation (RMSE vs CovFrob vs SV_score
selection), use these criteria directly. For domain-specific ground-truth
covariance (e.g. a known Matérn model), fall back to AdapterTuner with a
custom evaluate_fn.
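As a rough sketch of what a ground-truth comparison could compute, the snippet below builds a Matérn (ν = 3/2) covariance matrix from station coordinates and scores an estimated covariance against it. The helper names `matern32_cov` and `cov_error_vs_truth` are hypothetical, and the signature that AdapterTuner expects for evaluate_fn is not shown here, so this logic would still need to be wrapped to match it.

```python
import numpy as np


def matern32_cov(coords: np.ndarray, variance: float = 1.0,
                 length_scale: float = 1.0) -> np.ndarray:
    """Matern covariance with smoothness nu = 3/2 on Euclidean distances.

    coords is assumed to be shaped (n_locations, n_dims).
    """
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    scaled = np.sqrt(3.0) * dist / length_scale
    return variance * (1.0 + scaled) * np.exp(-scaled)


def cov_error_vs_truth(cov_est: np.ndarray, cov_true: np.ndarray) -> float:
    """Relative Frobenius error of an estimate against a known covariance."""
    return float(np.linalg.norm(cov_est - cov_true, "fro")
                 / np.linalg.norm(cov_true, "fro"))


# Illustrative coordinates and parameters (not taken from the paper).
coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
truth = matern32_cov(coords, variance=2.0, length_scale=1.5)

print(cov_error_vs_truth(truth, truth))                 # perfect estimate: 0.0
print(cov_error_vs_truth(np.zeros_like(truth), truth))  # all-zero estimate: 1.0
```

A custom evaluation along these lines scores the adapter against the known model rather than against a noisy empirical covariance from the validation set.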
Constructing SpatialAdapter directly bypasses tuning; use it when you already have (τ₁, τ₂) or want direct control of the ADMM loop.
import torch
from spatial_adapter import (
SpatialAdapter, SpatialAdapterConfig,
SpatialBasisLearner, TrendModel,
)
trend = TrendModel(num_continuous_features=5, hidden_layer_sizes=[], n_locations=100)
basis = SpatialBasisLearner(num_locations=100, latent_dim=10)
adapter = SpatialAdapter(
trend, basis, train_loader,
val_cont=val_X, val_y=val_Y, locs=station_coords,
config=SpatialAdapterConfig(),
device=torch.device("cpu"),
tau1=1.0, tau2=1.0,
)
adapter.pretrain_trend(epochs=5)
adapter.init_basis_dense()
adapter.run() # ADMM optimization

End-to-end scripts, configs, and notebooks live under examples/.
To reproduce the paper tables (data download, config → table mapping,
expected output) see examples/experiments/README.md.
- spatial_adapter/: core library, comprising models/, data/, metrics.py, prediction.py, tuning/, utils/, and cpp_extensions/ (pybind11 C++ kernels).
- examples/: runnable experiments (experiments/) and baselines (baselines/stdk/).
- tests/: pytest suite; see make test.
- data/: external datasets (git-ignored; see data/README.md).

make conda-env # set up environment
make build-cpp # build C++ extensions
make test # run tests
make test-cov # tests with HTML coverage

@misc{wang2026spatialadapterstructuredspatial,
title={Spatial Adapter: Structured Spatial Decomposition and Closed-Form Covariance for Frozen Predictors},
author={Wen-Ting Wang and Wei-Ying Wu and Hao-Yun Huang and Xuan-Chun Wang},
year={2026},
eprint={2605.11394},
archivePrefix={arXiv},
primaryClass={stat.ML},
url={https://arxiv.org/abs/2605.11394},
}

MIT. See LICENSE.