Skip to content

Small fixes, backend robustness, and bootstrap speedups#55

Merged
remlapmot merged 32 commits into
mainfrom
devel-2026-06-09
Jun 16, 2026
Merged

Small fixes, backend robustness, and bootstrap speedups#55
remlapmot merged 32 commits into
mainfrom
devel-2026-06-09

Conversation

@remlapmot

Copy link
Copy Markdown
Contributor

With the last PR and this one pySEQTarget (using glum and scikit-survival) is now faster than SEQTaRget (on my machine at least)!

This,

Correctnes

  • Bootstrap weights under weight_preexpansion were corrupted when a subject was sampled more than once — replicate copies collapsed into one (id, trial) group and their cumulative weights interleaved, biasing bootstrap CIs in the default weighted setup. Each copy now accumulates its own weight.
  • selection_random with non-[0,1] treatment levels silently dropped the entire control arm (an operator-precedence bug). Fixed.
  • Marginal hazard-ratio g-formula now takes the first-event row in the survival reduction; the old logic recorded single events as censored and inflated HR variance.
  • A batch of crash fixes: time_varying_cols/fixed_cols=None, collect() before fit(), retrieve_data("…switches") (Python-2 has_key), short excused_colnames, and to_md() reports for weighted-ITT / offloaded models.

Backend robustness

  • glum and jax models are now picklable, so the existing alternative GLM backends work end-to-end with offload and parallel (previously they crossed the process/disk boundary fine in serial only).
  • Froze categorical levels into the glum pickle reference frame so models with string covariates round-trip correctly.
  • Offload now stores all weight models under their real attribute names (some were being missed).

Performance

  • Skip refitting weight models on every bootstrap replicate under weight_preexpansion — they're identical, verified bit-for-bit. The biggest win for weighted analyses.
  • Parallelize the bootstrap hazard simulation, and ship data to parallel workers once per worker instead of per task.
  • Share the polars→pandas conversion across weight-model predictions.

Behaviour change

  • Hazard estimation now raises for method="dose-response" rather than silently returning HR ≈ 1.

remlapmot and others added 27 commits June 9, 2026 11:52
The hazard g-formula reduced the simulated counterfactual grid to one survival row per (id, trial) with `cum_sum(outcome) <= 1` then `.last()`. Because outcomes are simulated independently at every follow-up row, that keeps post-event rows (the cumulative count stays at 1 until a second event) and `.last()` returns the final follow-up row, so single events were silently recorded as censored. On the short-course data this dropped ~99% of simulated events (28 vs 2156) and biased survivors late, inflating the marginal-HR variance ~8x and shifting the mean relative to SEQTaRget (R), which correctly takes the first-event row.

Extract `_truncate_to_first_event()`: keep rows whose cumulative event count strictly before the current row is 0 (`cum_sum(x) - x == 0`), then `.last()` to get the first-event row (or the max-follow-up row when there is no event). Apply to both the plain and competing-event branches.

Add tests/test_hazard_truncation.py (per-pattern correctness, no-events-dropped invariant, per-(id, trial) grouping, regression contrast vs the old idiom). Fix the stale `_safe_predict` docstring: it raises on NaN rather than imputing 0.5.
The hazard ratio is computed by a g-formula Monte-Carlo simulation that needs the fitted outcome model, so running the bootstrap across a process pool requires those models to survive a pickle roundtrip. glum's _GlumFit
held a patsy DesignInfo, which raises NotImplementedError on pickle (patsy #26), so parallel=True and offload=True both crashed under glm_package="glum".

Three fixes make the process-pool route work:

- _GlumFit now records the formula and a small reference frame and rebuilds
  its DesignInfo in __setstate__, instead of pickling the DesignInfo. Safe here because _cast_categories freezes categorical level order and the models use only stateless transforms (precomputed squares, explicit-knot splines). Roundtrip preserves params, bse, and predictions exactly.

- SEQuential.__getstate__ drops the glum-only _patsy_design_cache, which also holds (unpicklable) DesignInfo objects and otherwise rode along when the object crossed a process boundary.

- _bootstrap_worker now calls the raw, undecorated fit body via __wrapped__ rather than the @bootstrap_loop-wrapped method. Going through the wrapper re-entered bootstrap_loop and returned [model_dict] (a list) instead of model_dict, crashing the hazard/survival consumers that index outcome_model[i]["outcome"]. This was a pre-existing bug affecting both the statsmodels and glum backends.

parallel=True now bit-matches the serial hazard ratio + CI for both backends; offload=True + glum and fit(parallel=True) + glum both work.

Tests:
- test_glum.py: _GlumFit pickle roundtrip preserves params and predictions.
- test_parallel.py: parallel hazard matches serial (statsmodels and glum).
The per-replicate hazard step (g-formula simulation + Cox fit) is GIL-bound — the patsy design-matrix build dominates — so threads can't speed it up. With glum models now picklable, run it across a process pool when parallel=True instead.

_calculate_hazard_single gains a parallel branch (_parallel_boot_log_hrs) that submits each replicate to a ProcessPoolExecutor. A pool initializer ships the analysis frame (via the offloader ref) and a slimmed SEQuential copy (DT/data nulled, fitted models kept) once per worker process, so each task carries only small integers. The serial body is factored into _one_boot_log_hr and shared by both paths; the RNG is rebuilt from seed + sample_idx + 1 exactly as before, so results are bit-identical to the serial loop.

To keep the pooled models cheap to ship, _GlumFit no longer pickles its full design matrix (_X_design can be ~100s of MB): __getstate__ caches the small coefficient covariance and stores the observation count, then drops _X_design. bse/summary still work after unpickle; predict never needed it. This also lightens offload and fit(parallel=True).

Benchmark (496k rows, 20 bootstraps, glum, 8 cores): hazard 73.3s -> 40.8s (1.80x), total 137.2s -> 73.4s (1.87x), hazard ratio bit-identical.

Tests: extend the _GlumFit pickle test to confirm bse survives the _X_design drop. test_parallel_hazard_matches_serial now exercises the parallel hazard path for both backends.
_fit_numerator/_fit_denominator store the weight models with
`self.numerator_model = fits`, overwriting on every fit. In a serial bootstrap the replicates run in-process, so after bootstrapping `self.numerator_model`/`self.denominator_model` hold the last resampled replicate's models — and the numerator/denominator summaries then report a bootstrap replicate's observation count instead of the main fit's, shifting with bootstrap conditions. (The parallel path was unaffected, since replicates run in worker copies, so it already showed the main fit.)

bootstrap_loop now snapshots the main-fit weight-model attributes (numerator/denominator/cense/visit models and weight_stats) after the main fit and restores them after the replicate loop. Replicates still refit and use their own weights during fitting; only the stored models used for display are pinned to the main fit, so serial and parallel agree. Estimates are unchanged.

Add tests/test_weight_main_fit_preserved.py (statsmodels + glum, serial and parallel).
Under method="censoring" the verbose output now prints how many expanded rows enter the outcome model (switch != 1) versus how many are artificially censored, alongside the total. This makes the count line up with implementations that report only the modelled rows (e.g. Stata seqtte), where the difference is just the retained censored rows. ITT prints no split (no artificial censoring). Add tests/test_verbose_censored_counts.py.
By joining weights on a recovered original-ID key instead of collapsing the resampled IDs, so each replicate copy of a multiply-sampled subject accumulates its own weight product.
It previously silently dropped all sampled control trials when treatment_level[0] != 0
To empty lists in the constructor so omitting these documented-Optional arguments no longer crashes
So that switch diagnostics are retrievable, and add the missing f-prefix to its error message
By deriving outcome and compevent model lists under a single guard with a None fallback
And raise a clear error for missing columns instead of a confusing polars TypeError
By rebuilding the patsy design info from formula and reference frame on unpickle so glm_package jax works with offload and parallel bootstrap
Instead of silently simulating identical arms and returning HR near 1
So models fit on plain string covariates survive offload and parallel round-trips
…expansion

By caching the main fit's predicted weight frame, since the pre-expansion data is never resampled and the refits were bit-identical
And make SEQoutput.summary load offloaded refs and tolerate absent model lists, and drop the duplicate _DT parquet write in the serial bootstrap path
…orkers once per worker

This is via a pool initializer instead of pickling them into every task
…d cense/visit weight predictions

Instead of reconverting the same rows for every model
So reports survive absent weight models and offloaded refs, fix the summary and retrieve_data default annotations, and name all three followup flags in the param checker error.
@remlapmot remlapmot requested a review from ryan-odea June 15, 2026 09:22
Because the GHA runner has no GPU these tests don't do anything different to the macOS tests.

@ryan-odea ryan-odea left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this overall looks good. There are some PR/bug fix specific comments which could be removed prior to the merge, e.g.

# Parentheses matter: `|` binds tighter than `!=`, so without them
# this parses as `(is_in | col) != level`, which silently drops
# every sampled control trial when treatment_level[0] != 0.

@remlapmot

Copy link
Copy Markdown
Contributor Author

Have removed that comment.

@remlapmot

Copy link
Copy Markdown
Contributor Author

I had a flaky failing test on my M3 Mac today, so switched that test tolerance.

@remlapmot remlapmot merged commit 3749af5 into main Jun 16, 2026
6 checks passed
@remlapmot remlapmot deleted the devel-2026-06-09 branch June 16, 2026 12:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants