From 1bf595dc6e4f0076be5987cf70c5137dffc2e45b Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 9 Jun 2026 11:52:19 +0100 Subject: [PATCH 01/32] Fix marginal-HR g-formula: take first-event row in survival reduction The hazard g-formula reduced the simulated counterfactual grid to one survival row per (id, trial) with `cum_sum(outcome) <= 1` then `.last()`. Because outcomes are simulated independently at every follow-up row, that keeps post-event rows (the cumulative count stays at 1 until a second event) and `.last()` returns the final follow-up row, so single events were silently recorded as censored. On the short-course data this dropped ~99% of simulated events (28 vs 2156) and biased survivors late, inflating the marginal-HR variance ~8x and shifting the mean relative to SEQTaRget (R), which correctly takes the first-event row. Extract `_truncate_to_first_event()`: keep rows whose cumulative event count strictly before the current row is 0 (`cum_sum(x) - x == 0`), then `.last()` to get the first-event row (or the max-follow-up row when there is no event). Apply to both the plain and competing-event branches. Add tests/test_hazard_truncation.py (per-pattern correctness, no-events-dropped invariant, per-(id, trial) grouping, regression contrast vs the old idiom). Fix the stale `_safe_predict` docstring: it raises on NaN rather than imputing 0.5. --- pySEQTarget/analysis/_hazard.py | 69 ++++++++++------ pySEQTarget/helpers/_predict_model.py | 4 +- tests/test_hazard_truncation.py | 109 ++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 27 deletions(-) create mode 100644 tests/test_hazard_truncation.py diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index a8331bf..f8388af 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -87,6 +87,41 @@ def _calculate_hazard_single(self, data, idx=None, val=None): return _create_hazard_output(np.exp(full_log_hr), lci, uci, val, self) +def _truncate_to_first_event(tmp, id_col, event_col): + """Reduce a simulated counterfactual grid to one survival row per (id, trial). + + Keeps the FIRST row in which ``event_col`` fires (status 1 at the first event + time); if the unit never has an event it keeps the final follow-up row + (status 0, censored at max follow-up). + + Outcomes are simulated independently at every follow-up row, so a unit may + have ``event_col == 1`` at several rows. We therefore keep only rows whose + cumulative event count *strictly before* the current row is 0 — i.e. every + row up to and including the first event — and then take the last of those, + which is the first-event row (or the max-follow-up row when there is no + event). + + NOTE: the inclusive form ``cum_sum(event_col) <= 1`` is WRONG here: it + retains post-event rows (the cumulative count stays at 1 until a second + event), so ``.last()`` returns the final follow-up row and a single event is + silently recorded as censored. That dropped ~99% of simulated events and + inflated the marginal-HR variance ~8x relative to SEQTaRget (R). See + tests/test_hazard_truncation.py. + """ + return ( + tmp.with_columns( + ( + pl.col(event_col).cum_sum().over([id_col, "trial"]) + - pl.col(event_col) + ).alias("_event_prior") + ) + .filter(pl.col("_event_prior") == 0) + .group_by([id_col, "trial"]) + .last() + .drop("_event_prior") + ) + + def _hazard_handler(self, data, idx, boot_idx, rng): exclude_cols = [ "followup", @@ -140,36 +175,18 @@ def _hazard_handler(self, data, idx, boot_idx, rng): ce_sim = rng.binomial(1, ce_prob) tmp = tmp.with_columns([pl.Series("ce", ce_sim)]) - tmp = ( - tmp.with_columns( - [ - pl.when((pl.col("outcome") == 1) | (pl.col("ce") == 1)) - .then(1) - .otherwise(0) - .alias("any_event") - ] - ) - .with_columns( - [ - pl.col("any_event") - .cum_sum() - .over([self.id_col, "trial"]) - .alias("event_cumsum") - ] - ) - .filter(pl.col("event_cumsum") <= 1) - ) - else: tmp = tmp.with_columns( [ - pl.col("outcome") - .cum_sum() - .over([self.id_col, "trial"]) - .alias("event_cumsum") + pl.when((pl.col("outcome") == 1) | (pl.col("ce") == 1)) + .then(1) + .otherwise(0) + .alias("any_event") ] - ).filter(pl.col("event_cumsum") <= 1) + ) + tmp = _truncate_to_first_event(tmp, self.id_col, "any_event") + else: + tmp = _truncate_to_first_event(tmp, self.id_col, "outcome") - tmp = tmp.group_by([self.id_col, "trial"]).last() all_treatments.append(tmp) sim_data = pl.concat(all_treatments) diff --git a/pySEQTarget/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py index 7ccb3b1..672f4f3 100644 --- a/pySEQTarget/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -14,7 +14,9 @@ def _safe_predict(model, data, clip_probs=True): data : pandas DataFrame Data to predict on clip_probs : bool - If True, clip probabilities to [0, 1] and replace NaN with 0.5 + If True, clip probabilities to [0, 1]. Raises ValueError if any + predicted probability is NaN (this signals a train/predict dtype + mismatch or coefficient overflow, not a value to silently impute). """ try: probs = model.predict(data) diff --git a/tests/test_hazard_truncation.py b/tests/test_hazard_truncation.py new file mode 100644 index 0000000..1cbccc0 --- /dev/null +++ b/tests/test_hazard_truncation.py @@ -0,0 +1,109 @@ +"""Regression tests for the survival-time reduction in the hazard g-formula. + +`_truncate_to_first_event` collapses the simulated counterfactual grid (outcomes +drawn independently at every follow-up row) to one survival row per (id, trial): +the first-event row, or the max-follow-up row when there is no event. + +The earlier implementation used the inclusive `cum_sum(outcome) <= 1` then +`.last()`, which kept post-event rows and returned the final follow-up row, +silently recording single events as censored. That dropped ~99% of simulated +events and inflated the marginal-HR variance ~8x relative to SEQTaRget (R). +""" + +import polars as pl + +from pySEQTarget.analysis._hazard import _truncate_to_first_event + + +def _grid(rows): + # rows: list of (id, trial, [outcome per follow-up 0..T]) + recs = [] + for uid, trial, outs in rows: + for f, o in enumerate(outs): + recs.append((uid, trial, f, o)) + return pl.DataFrame(recs, schema=["id", "trial", "followup", "outcome"], orient="row") + + +def test_first_event_row_is_kept_for_each_pattern(): + grid = _grid( + [ + (1, 0, [0, 0, 1, 0, 0]), # single interior event -> (followup=2, event=1) + (2, 0, [0, 0, 0, 0, 0]), # no event -> (followup=4, event=0) + (3, 0, [0, 1, 0, 1, 0]), # two events; first -> (followup=1, event=1) + (4, 0, [1, 0, 0, 0, 0]), # event at time 0 -> (followup=0, event=1) + (5, 0, [0, 0, 0, 0, 1]), # event at last row -> (followup=4, event=1) + ] + ) + + out = ( + _truncate_to_first_event(grid, "id", "outcome") + .sort("id") + .select(["id", "followup", "outcome"]) + ) + + assert out.to_dicts() == [ + {"id": 1, "followup": 2, "outcome": 1}, + {"id": 2, "followup": 4, "outcome": 0}, + {"id": 3, "followup": 1, "outcome": 1}, + {"id": 4, "followup": 0, "outcome": 1}, + {"id": 5, "followup": 4, "outcome": 1}, + ] + + +def test_no_events_are_dropped(): + # Every unit that has >=1 simulated outcome must end up with event=1; only the + # all-zero unit (id=2) is censored. This is the property the old idiom broke. + grid = _grid( + [ + (1, 0, [0, 0, 1, 0, 0]), + (2, 0, [0, 0, 0, 0, 0]), + (3, 0, [0, 1, 0, 1, 0]), + (4, 0, [1, 0, 0, 0, 0]), + (5, 0, [0, 0, 0, 0, 1]), + ] + ) + out = _truncate_to_first_event(grid, "id", "outcome") + true_units_with_event = ( + grid.group_by("id").agg(pl.col("outcome").max().alias("ever"))["ever"].sum() + ) + assert out["outcome"].sum() == true_units_with_event == 4 + + +def test_grouping_is_per_id_and_trial(): + # Same id, two trials with different first-event times must be reduced + # independently. + grid = _grid( + [ + (1, 0, [0, 0, 1, 0]), # trial 0: event at 2 + (1, 1, [1, 0, 0, 0]), # trial 1: event at 0 + ] + ) + out = ( + _truncate_to_first_event(grid, "id", "outcome") + .sort(["id", "trial"]) + .select(["id", "trial", "followup", "outcome"]) + ) + assert out.to_dicts() == [ + {"id": 1, "trial": 0, "followup": 2, "outcome": 1}, + {"id": 1, "trial": 1, "followup": 0, "outcome": 1}, + ] + + +def test_beats_the_old_buggy_idiom(): + # Lock the regression: the previous `cum_sum <= 1` then `.last()` loses the + # single interior events that the fixed helper retains. + grid = _grid([(uid, 0, [0, 0, 1, 0, 0]) for uid in range(1, 11)]) + + fixed = _truncate_to_first_event(grid, "id", "outcome")["outcome"].sum() + + old = ( + grid.with_columns(pl.col("outcome").cum_sum().over(["id", "trial"]).alias("cs")) + .filter(pl.col("cs") <= 1) + .group_by(["id", "trial"]) + .last()["outcome"] + .sum() + ) + + assert fixed == 10 # every unit's single event retained + assert old == 0 # old idiom dropped all of them + assert fixed > old From be14ad74f3e9cf7a5511dba4bceee468610d596c Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 9 Jun 2026 11:53:46 +0100 Subject: [PATCH 02/32] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 046f356..b310c67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.13.7" +version = "0.13.8" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} From 7e12f393aec7410c6aafeb450b779cc2c2d0f83f Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 9 Jun 2026 11:53:55 +0100 Subject: [PATCH 03/32] Amend TP email --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b310c67..a8dd6b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ authors = [ {name = "Ryan O'Dea", email = "ryan.odea@psi.ch"}, {name = "Alejandro Szmulewicz", email = "aszmulewicz@hsph.harvard.edu"}, - {name = "Tom Palmer", email = "tom.palmer@bristol.ac.uk"}, + {name = "Tom Palmer", email = "remlapmot@hotmail.com"}, {name = "Miguel Hernán", email = "mhernan@hsph.harvard.edu"}, ] From af444b2c20764169ee6268891fe954a516a98ef4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 9 Jun 2026 10:54:25 +0000 Subject: [PATCH 04/32] Auto-format code --- pySEQTarget/analysis/_hazard.py | 3 +-- tests/test_hazard_truncation.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index f8388af..72f22a5 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -111,8 +111,7 @@ def _truncate_to_first_event(tmp, id_col, event_col): return ( tmp.with_columns( ( - pl.col(event_col).cum_sum().over([id_col, "trial"]) - - pl.col(event_col) + pl.col(event_col).cum_sum().over([id_col, "trial"]) - pl.col(event_col) ).alias("_event_prior") ) .filter(pl.col("_event_prior") == 0) diff --git a/tests/test_hazard_truncation.py b/tests/test_hazard_truncation.py index 1cbccc0..a106d3e 100644 --- a/tests/test_hazard_truncation.py +++ b/tests/test_hazard_truncation.py @@ -21,7 +21,9 @@ def _grid(rows): for uid, trial, outs in rows: for f, o in enumerate(outs): recs.append((uid, trial, f, o)) - return pl.DataFrame(recs, schema=["id", "trial", "followup", "outcome"], orient="row") + return pl.DataFrame( + recs, schema=["id", "trial", "followup", "outcome"], orient="row" + ) def test_first_event_row_is_kept_for_each_pattern(): From 374f78e6c9d5bbd11eaf34d391752cae155b9e0e Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 9 Jun 2026 14:21:56 +0100 Subject: [PATCH 05/32] Make glum models picklable so parallel/offload work end-to-end The hazard ratio is computed by a g-formula Monte-Carlo simulation that needs the fitted outcome model, so running the bootstrap across a process pool requires those models to survive a pickle roundtrip. glum's _GlumFit held a patsy DesignInfo, which raises NotImplementedError on pickle (patsy #26), so parallel=True and offload=True both crashed under glm_package="glum". Three fixes make the process-pool route work: - _GlumFit now records the formula and a small reference frame and rebuilds its DesignInfo in __setstate__, instead of pickling the DesignInfo. Safe here because _cast_categories freezes categorical level order and the models use only stateless transforms (precomputed squares, explicit-knot splines). Roundtrip preserves params, bse, and predictions exactly. - SEQuential.__getstate__ drops the glum-only _patsy_design_cache, which also holds (unpicklable) DesignInfo objects and otherwise rode along when the object crossed a process boundary. - _bootstrap_worker now calls the raw, undecorated fit body via __wrapped__ rather than the @bootstrap_loop-wrapped method. Going through the wrapper re-entered bootstrap_loop and returned [model_dict] (a list) instead of model_dict, crashing the hazard/survival consumers that index outcome_model[i]["outcome"]. This was a pre-existing bug affecting both the statsmodels and glum backends. parallel=True now bit-matches the serial hazard ratio + CI for both backends; offload=True + glum and fit(parallel=True) + glum both work. Tests: - test_glum.py: _GlumFit pickle roundtrip preserves params and predictions. - test_parallel.py: parallel hazard matches serial (statsmodels and glum). --- pySEQTarget/SEQuential.py | 10 +++++ pySEQTarget/helpers/_bootstrap.py | 11 ++++- pySEQTarget/helpers/_glum_fit.py | 75 ++++++++++++++++++++++++++++--- tests/test_glum.py | 43 ++++++++++++++++++ tests/test_parallel.py | 44 ++++++++++++++++++ 5 files changed, 176 insertions(+), 7 deletions(-) diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 111e92b..a5761d8 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -111,6 +111,16 @@ def __init__( _param_checker(self) _data_checker(self) + def __getstate__(self): + # The glum design-info cache (_outcome_fit) holds patsy DesignInfo + # objects, which can't be pickled (patsy #26). It is a per-process speed + # cache rebuilt lazily on first fit, so drop it when crossing a process + # boundary (parallel bootstrap / offload); workers repopulate it. Without + # this, parallel=True + glm_package="glum" crashes on pickling. + state = self.__dict__.copy() + state.pop("_patsy_design_cache", None) + return state + def expand(self): """ Creates the sequentially nested, emulated target trial structure. diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 733be5e..72a5cf7 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -80,8 +80,17 @@ def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs): # Disable bootstrapping to prevent recursion obj.bootstrap_nboot = 0 + # Call the raw, undecorated fit body — not the @bootstrap_loop-wrapped + # method — so it returns this replicate's single model dict. Going through + # the wrapper would re-enter bootstrap_loop and return a list ([model_dict]), + # which the serial path never does, breaking the hazard/survival consumers + # that index outcome_model[i]["outcome"]. method = getattr(obj, method_name) - result = method(*args, **kwargs) + raw = getattr(method, "__wrapped__", None) + if raw is not None: + result = raw(obj, *args, **kwargs) + else: + result = method(*args, **kwargs) obj._rng = None return result diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py index b9f2a73..e2ffa97 100644 --- a/pySEQTarget/helpers/_glum_fit.py +++ b/pySEQTarget/helpers/_glum_fit.py @@ -40,21 +40,69 @@ class _GlumFit: just like statsmodels keeps model.exog, so memory use is comparable. """ - def __init__(self, glum_model, design_info, feature_names, X_design, sample_weight): + def __init__( + self, + glum_model, + design_info, + feature_names, + X_design, + sample_weight, + formula=None, + ref_frame=None, + ): self._glum = glum_model self._design_info = design_info self._X_design = X_design # includes the intercept column self._sample_weight = sample_weight + # Inputs to rebuild ``design_info`` on unpickle: the patsy DesignInfo + # itself cannot be pickled (patsy #26), so we keep the formula and a + # tiny reference frame (which preserves each categorical column's full, + # ordered dtype categories) and re-parse on __setstate__. + self._formula = formula + self._ref_frame = ref_frame + + self._build_model_namespace(design_info, feature_names) + self.exog_names = feature_names + # statsmodels convention: intercept first + all_coefs = np.concatenate([[glum_model.intercept_], glum_model.coef_]) + self.params = pd.Series(all_coefs, index=feature_names) + + def _build_model_namespace(self, design_info, feature_names): self.model = types.SimpleNamespace( exog_names=feature_names, data=types.SimpleNamespace(design_info=design_info), ) - self.exog_names = feature_names - # statsmodels convention: intercept first - all_coefs = np.concatenate([[glum_model.intercept_], glum_model.coef_]) - self.params = pd.Series(all_coefs, index=feature_names) + def __getstate__(self): + # Drop the unpicklable patsy DesignInfo and the SimpleNamespaces that + # reference it; __setstate__ rebuilds them from the formula + ref_frame. + state = self.__dict__.copy() + state.pop("_design_info", None) + state.pop("model", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + if self._formula is None or self._ref_frame is None: + raise RuntimeError( + "Cannot unpickle _GlumFit fitted before formula/ref_frame were " + "recorded; refit with the current pySEQTarget version." + ) + _, X_mat = patsy.dmatrices( + self._formula, self._ref_frame, return_type="dataframe" + ) + if list(X_mat.columns) != list(self.exog_names): + # The reference frame's categorical ordering must reproduce the + # frozen column structure exactly, or glum's coefficients would be + # paired with the wrong design columns on predict. Fail loudly + # rather than return silently wrong predictions. + raise RuntimeError( + "_GlumFit design columns changed on unpickle: " + f"{list(X_mat.columns)} != {list(self.exog_names)}" + ) + self._design_info = X_mat.design_info + self._build_model_namespace(self._design_info, self.exog_names) def predict(self, data, transform=True): if transform: @@ -185,4 +233,19 @@ def _fit_glum(formula, data, var_weights=None, start_params=None, design_cache=N fit_kwargs["sample_weight"] = sample_weight glm.fit(X_arr, y_arr, **fit_kwargs) - return _GlumFit(glm, design_info, feature_names, X_design, sample_weight) + + # Keep a minimal reference frame so the (unpicklable) design_info can be + # rebuilt on unpickle. Two rows suffice: patsy derives categorical contrasts + # from each column's full dtype categories, not the observed values, and the + # codebase uses only stateless transforms (precomputed squares, explicit-knot + # splines), so no fit-time state needs preserving. + ref_frame = data.head(2).copy() + return _GlumFit( + glm, + design_info, + feature_names, + X_design, + sample_weight, + formula=formula, + ref_frame=ref_frame, + ) diff --git a/tests/test_glum.py b/tests/test_glum.py index cb3b3c5..b875ebe 100644 --- a/tests/test_glum.py +++ b/tests/test_glum.py @@ -358,6 +358,49 @@ def test_glum_design_cache_handles_categorical_level_reordering(): assert list(mb.params.index) == list(m.params.index) +def test_glum_model_pickle_roundtrip_preserves_predictions(): + # _GlumFit holds a patsy DesignInfo, which cannot be pickled (patsy #26). + # It must rebuild the DesignInfo on unpickle from the stored formula + + # reference frame so fitted models can cross a process boundary (parallel + # bootstrap, offload). The roundtrip must preserve params and predictions + # exactly and yield a usable design_info. + import pickle + + import numpy as np + import pandas as pd + + from pySEQTarget.helpers._glum_fit import _fit_glum + + rng = np.random.default_rng(0) + n = 500 + df = pd.DataFrame( + { + "y": (rng.random(n) < 0.4).astype(int), + "x": rng.standard_normal(n), + "g": pd.Categorical(rng.choice(["a", "b", "c"], n), categories=["a", "b", "c"]), + } + ) + + m = _fit_glum("y ~ x + g", df) + pred = m.predict(df) + + m2 = pickle.loads(pickle.dumps(m)) + + assert list(m2.params.index) == list(m.params.index) + assert list(m2.params.values) == approx(list(m.params.values), rel=1e-12, abs=1e-12) + assert list(m2.exog_names) == list(m.exog_names) + # design_info is rebuilt and reproduces the frozen column structure + assert list(m2.model.data.design_info.column_names) == list( + m.model.data.design_info.column_names + ) + # predictions are bit-identical through both predict paths + np.testing.assert_array_equal(m2.predict(df), pred) + np.testing.assert_array_equal( + m2.predict(m2._X_design, transform=False), + m.predict(m._X_design, transform=False), + ) + + def test_glum_warm_start_dropped_when_design_columns_mismatch(): # The defensive guard in _fit_glum: a (values, names) tuple whose names # don't line up with the patsy design matrix must be ignored, falling back diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 907121d..a6384f9 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -43,3 +43,47 @@ def test_parallel_ITT(): ], abs=1e-6, ) + + +def _hazard_run(parallel, glm_package): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + glm_package=glm_package, + hazard_estimate=True, + bootstrap_nboot=4, + ncores=2, + parallel=parallel, + seed=42, + ), + ) + s.expand() + s.bootstrap() + s.fit() + s.hazard() + hr = s.hazard_ratio + return (hr["Hazard ratio"][0], hr["LCI"][0], hr["UCI"][0]) + + +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Parallelism test hangs in CI environment" +) +@pytest.mark.parametrize("glm_package", ["statsmodels", "glum"]) +def test_parallel_hazard_matches_serial(glm_package): + # The process-pool bootstrap must produce the same hazard ratio + CI as the + # serial loop. Locks two fixes: (1) glum's _GlumFit is now picklable so the + # fitted models survive crossing the process boundary, and (2) the worker + # calls the raw fit body, so outcome_model[i] is a model dict (not a list) + # for the hazard consumer to index. Previously crashed for both backends. + serial = _hazard_run(parallel=False, glm_package=glm_package) + parallel = _hazard_run(parallel=True, glm_package=glm_package) + assert parallel == pytest.approx(serial, rel=1e-9, abs=1e-12) From 107f0a0b648c037f147259c5ae0278d605a8548b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 9 Jun 2026 13:22:21 +0000 Subject: [PATCH 06/32] Auto-format code --- tests/test_glum.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_glum.py b/tests/test_glum.py index b875ebe..4132c94 100644 --- a/tests/test_glum.py +++ b/tests/test_glum.py @@ -377,7 +377,9 @@ def test_glum_model_pickle_roundtrip_preserves_predictions(): { "y": (rng.random(n) < 0.4).astype(int), "x": rng.standard_normal(n), - "g": pd.Categorical(rng.choice(["a", "b", "c"], n), categories=["a", "b", "c"]), + "g": pd.Categorical( + rng.choice(["a", "b", "c"], n), categories=["a", "b", "c"] + ), } ) From 2f5d59425b8974580e741d6c1ccc8cb16a7d60c0 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 9 Jun 2026 14:47:36 +0100 Subject: [PATCH 07/32] Parallelize the bootstrap hazard simulation over a process pool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The per-replicate hazard step (g-formula simulation + Cox fit) is GIL-bound — the patsy design-matrix build dominates — so threads can't speed it up. With glum models now picklable, run it across a process pool when parallel=True instead. _calculate_hazard_single gains a parallel branch (_parallel_boot_log_hrs) that submits each replicate to a ProcessPoolExecutor. A pool initializer ships the analysis frame (via the offloader ref) and a slimmed SEQuential copy (DT/data nulled, fitted models kept) once per worker process, so each task carries only small integers. The serial body is factored into _one_boot_log_hr and shared by both paths; the RNG is rebuilt from seed + sample_idx + 1 exactly as before, so results are bit-identical to the serial loop. To keep the pooled models cheap to ship, _GlumFit no longer pickles its full design matrix (_X_design can be ~100s of MB): __getstate__ caches the small coefficient covariance and stores the observation count, then drops _X_design. bse/summary still work after unpickle; predict never needed it. This also lightens offload and fit(parallel=True). Benchmark (496k rows, 20 bootstraps, glum, 8 cores): hazard 73.3s -> 40.8s (1.80x), total 137.2s -> 73.4s (1.87x), hazard ratio bit-identical. Tests: extend the _GlumFit pickle test to confirm bse survives the _X_design drop. test_parallel_hazard_matches_serial now exercises the parallel hazard path for both backends. --- pySEQTarget/analysis/_hazard.py | 121 ++++++++++++++++++++++++------- pySEQTarget/helpers/_glum_fit.py | 24 +++++- tests/test_glum.py | 10 ++- 3 files changed, 124 insertions(+), 31 deletions(-) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 72f22a5..1886d6e 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -1,4 +1,6 @@ +import copy import warnings +from concurrent.futures import ProcessPoolExecutor import numpy as np import polars as pl @@ -33,8 +35,6 @@ def _calculate_hazard_single(self, data, idx=None, val=None): return _create_hazard_output(None, None, None, val, self) if self.bootstrap_nboot > 0: - boot_log_hrs = [] - # outcome_model[model_pos + 1] was fit on _boot_samples[sample_idx]; # skipped replicates make this mapping non-identity, so iterate it # explicitly rather than assuming model index == sample index. @@ -42,31 +42,18 @@ def _calculate_hazard_single(self, data, idx=None, val=None): if boot_sample_idx is None: boot_sample_idx = list(range(len(self._boot_samples))) - for model_pos, sample_idx in enumerate(boot_sample_idx): - if self.seed is not None: - self._rng = np.random.RandomState(self.seed + sample_idx + 1) - id_counts = self._boot_samples[sample_idx] - - counts = pl.DataFrame( - { - self.id_col: list(id_counts.keys()), - "_count": list(id_counts.values()), - } - ) - boot_data = ( - data.lazy() - .join(counts.lazy(), on=self.id_col, how="inner") - .with_columns(pl.int_ranges(0, pl.col("_count")).alias("_rep")) - .explode("_rep") - .drop("_count", "_rep") - .collect() - ) - - boot_log_hr = _hazard_handler( - self, boot_data, idx, model_pos + 1, self._rng - ) - if boot_log_hr is not None and not np.isnan(boot_log_hr): - boot_log_hrs.append(boot_log_hr) + # The per-replicate hazard simulation is GIL-bound (patsy design build), + # so spread it over a process pool when parallel=True. Needs a concrete + # seed (always set since SEQuential pins a default) so each replicate's + # RNG — and therefore the result — is identical to the serial path. + if getattr(self, "parallel", False) and self.seed is not None: + boot_log_hrs = _parallel_boot_log_hrs(self, data, idx, boot_sample_idx) + else: + boot_log_hrs = [] + for model_pos, sample_idx in enumerate(boot_sample_idx): + boot_log_hr = _one_boot_log_hr(self, data, idx, model_pos, sample_idx) + if boot_log_hr is not None and not np.isnan(boot_log_hr): + boot_log_hrs.append(boot_log_hr) if len(boot_log_hrs) == 0: return _create_hazard_output(np.exp(full_log_hr), None, None, val, self) @@ -121,6 +108,86 @@ def _truncate_to_first_event(tmp, id_col, event_col): ) +def _one_boot_log_hr(self, data, idx, model_pos, sample_idx): + """Build one bootstrap resample of ``data`` and return its log hazard ratio. + + The RNG is rebuilt from ``seed + sample_idx + 1`` (matching the serial loop + exactly), so this is bit-identical whether called serially or in a worker. + """ + seed = getattr(self, "seed", None) + rng = np.random.RandomState(seed + sample_idx + 1) if seed is not None else self._rng + + id_counts = self._boot_samples[sample_idx] + counts = pl.DataFrame( + { + self.id_col: list(id_counts.keys()), + "_count": list(id_counts.values()), + } + ) + boot_data = ( + data.lazy() + .join(counts.lazy(), on=self.id_col, how="inner") + .with_columns(pl.int_ranges(0, pl.col("_count")).alias("_rep")) + .explode("_rep") + .drop("_count", "_rep") + .collect() + ) + return _hazard_handler(self, boot_data, idx, model_pos + 1, rng) + + +# Process-pool worker state. Set once per worker process by the initializer so +# each task ships only small integers, not the (slimmed) SEQuential object or +# the analysis frame. +_HZ_WORKER_OBJ = None +_HZ_WORKER_DATA = None + + +def _hazard_pool_init(obj, data_ref): + global _HZ_WORKER_OBJ, _HZ_WORKER_DATA + _HZ_WORKER_OBJ = obj + _HZ_WORKER_DATA = obj._offloader.load_dataframe(data_ref) + + +def _hazard_pool_task(idx, model_pos, sample_idx): + return _one_boot_log_hr(_HZ_WORKER_OBJ, _HZ_WORKER_DATA, idx, model_pos, sample_idx) + + +def _parallel_boot_log_hrs(self, data, idx, boot_sample_idx): + """Run the bootstrap hazard simulations over a process pool. + + The analysis frame is handed to each worker process once (via the offloader + ref + pool initializer), and a slimmed copy of ``self`` carries the fitted + models. Results are gathered in submission order, matching the serial loop; + NaN/None replicates are dropped the same way. + """ + data_ref = self._offloader.save_dataframe(data, f"_haz_DT_{idx}") + + # Slim copy for the pool: drop the large frames workers reload from data_ref; + # keep the fitted models, bootstrap samples, and config. _GlumFit and + # SEQuential each drop their unpicklable / heavy state on pickle. + slim = copy.copy(self) + slim.DT = None + slim.data = None + slim._rng = None + + boot_log_hrs = [] + with ProcessPoolExecutor( + max_workers=self.ncores, + initializer=_hazard_pool_init, + initargs=(slim, data_ref), + ) as executor: + futures = [ + executor.submit(_hazard_pool_task, idx, model_pos, sample_idx) + for model_pos, sample_idx in enumerate(boot_sample_idx) + ] + for future in futures: + result = future.result() + if result is not None and not np.isnan(result): + boot_log_hrs.append(result) + + return boot_log_hrs + + def _hazard_handler(self, data, idx, boot_idx, rng): exclude_cols = [ "followup", diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py index e2ffa97..74a4b03 100644 --- a/pySEQTarget/helpers/_glum_fit.py +++ b/pySEQTarget/helpers/_glum_fit.py @@ -53,7 +53,13 @@ def __init__( self._glum = glum_model self._design_info = design_info self._X_design = X_design # includes the intercept column + self._nobs = X_design.shape[0] self._sample_weight = sample_weight + # Lazily-filled cache of the (small) coefficient covariance matrix. It + # lets __getstate__ drop the full design matrix (_X_design can be 100s + # of MB) while keeping bse/summary working after unpickle — important + # for the process pool and offload, which ship many fitted models. + self._cov_cached = None # Inputs to rebuild ``design_info`` on unpickle: the patsy DesignInfo # itself cannot be pickled (patsy #26), so we keep the formula and a # tiny reference frame (which preserves each categorical column's full, @@ -80,6 +86,12 @@ def __getstate__(self): state = self.__dict__.copy() state.pop("_design_info", None) state.pop("model", None) + # Replace the full design matrix with the small cached covariance so the + # pickled model stays lightweight (the design matrix can be 100s of MB). + # bse/summary still work via _cov_cached; predict never needs _X_design. + if state.get("_cov_cached") is None: + state["_cov_cached"] = self.cov_params() + state["_X_design"] = None return state def __setstate__(self, state): @@ -117,12 +129,20 @@ def predict(self, data, transform=True): return self._glum.predict(X_arr) def cov_params(self): + if self._cov_cached is not None: + return self._cov_cached X = self._X_design + if X is None: + raise RuntimeError( + "cov_params unavailable: design matrix was dropped on pickle and " + "no covariance was cached." + ) mu = self._glum.predict(X[:, 1:]) w = mu * (1.0 - mu) if self._sample_weight is not None: w = w * np.asarray(self._sample_weight) - return np.linalg.pinv(X.T @ (w[:, None] * X)) + self._cov_cached = np.linalg.pinv(X.T @ (w[:, None] * X)) + return self._cov_cached @property def bse(self): @@ -158,7 +178,7 @@ def summary(self): "GLM (glum backend)", "Binomial", "logit", - str(self._X_design.shape[0]), + str(self._nobs), ] }, index=["Model:", "Family:", "Link:", "No. Observations:"], diff --git a/tests/test_glum.py b/tests/test_glum.py index 4132c94..f9a71c5 100644 --- a/tests/test_glum.py +++ b/tests/test_glum.py @@ -395,12 +395,18 @@ def test_glum_model_pickle_roundtrip_preserves_predictions(): assert list(m2.model.data.design_info.column_names) == list( m.model.data.design_info.column_names ) - # predictions are bit-identical through both predict paths + # predictions are bit-identical through both predict paths. The design + # matrix is an external input (the unpickled model drops its own _X_design + # to stay lightweight), so feed the original to both for the transform=False + # path. np.testing.assert_array_equal(m2.predict(df), pred) np.testing.assert_array_equal( - m2.predict(m2._X_design, transform=False), + m2.predict(m._X_design, transform=False), m.predict(m._X_design, transform=False), ) + # bse still works after unpickle even though _X_design was dropped (cached cov) + assert m2._X_design is None + np.testing.assert_allclose(m2.bse.values, m.bse.values, rtol=0, atol=0) def test_glum_warm_start_dropped_when_design_columns_mismatch(): From b9a3dc62f4202bf8e8b2cb54ddb2eb7a885a7cf2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 9 Jun 2026 13:47:56 +0000 Subject: [PATCH 08/32] Auto-format code --- pySEQTarget/analysis/_hazard.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 1886d6e..eb16eba 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -115,7 +115,9 @@ def _one_boot_log_hr(self, data, idx, model_pos, sample_idx): exactly), so this is bit-identical whether called serially or in a worker. """ seed = getattr(self, "seed", None) - rng = np.random.RandomState(seed + sample_idx + 1) if seed is not None else self._rng + rng = ( + np.random.RandomState(seed + sample_idx + 1) if seed is not None else self._rng + ) id_counts = self._boot_samples[sample_idx] counts = pl.DataFrame( From 8e8854540628815b11359246f1d8fba38245c449 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Wed, 10 Jun 2026 12:29:14 +0100 Subject: [PATCH 09/32] Preserve main-fit weight models across the bootstrap loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _fit_numerator/_fit_denominator store the weight models with `self.numerator_model = fits`, overwriting on every fit. In a serial bootstrap the replicates run in-process, so after bootstrapping `self.numerator_model`/`self.denominator_model` hold the last resampled replicate's models — and the numerator/denominator summaries then report a bootstrap replicate's observation count instead of the main fit's, shifting with bootstrap conditions. (The parallel path was unaffected, since replicates run in worker copies, so it already showed the main fit.) bootstrap_loop now snapshots the main-fit weight-model attributes (numerator/denominator/cense/visit models and weight_stats) after the main fit and restores them after the replicate loop. Replicates still refit and use their own weights during fitting; only the stored models used for display are pinned to the main fit, so serial and parallel agree. Estimates are unchanged. Add tests/test_weight_main_fit_preserved.py (statsmodels + glum, serial and parallel). --- pySEQTarget/helpers/_bootstrap.py | 30 ++++++++++ tests/test_weight_main_fit_preserved.py | 76 +++++++++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 tests/test_weight_main_fit_preserved.py diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 72a5cf7..a8a8fe7 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -11,6 +11,23 @@ from ._format_time import _format_time +# Side-effect attributes set by the main fit that bootstrap replicates overwrite +# when they run in-process (the serial path: each replicate calls the fit body +# again, and _fit_numerator/_fit_denominator do `self.X_model = fits`). Snapshot +# them after the main fit and restore after the replicate loop so summaries +# reflect the main fit, not the last replicate. The parallel path already keeps +# them (replicates run in worker copies), so restore is a no-op there. +_MAIN_FIT_ATTRS = ( + "numerator_model", + "denominator_model", + "cense_numerator_model", + "cense_denominator_model", + "visit_numerator_model", + "visit_denominator_model", + "weight_stats", +) + + def _prepare_boot_data(self, data, boot_id): id_counts = self._boot_samples[boot_id] @@ -113,6 +130,14 @@ def wrapper(self, *args, **kwargs): full = method(self, *args, **kwargs) results.append(full) + # Snapshot the main-fit weight models before any in-process replicate + # can overwrite them; restored just before returning. + main_fit_state = { + attr: getattr(self, attr) + for attr in _MAIN_FIT_ATTRS + if hasattr(self, attr) + } + if getattr(self, "bootstrap_nboot") > 0 and getattr( self, "_boot_samples", None ): @@ -214,6 +239,11 @@ def wrapper(self, *args, **kwargs): end = time.perf_counter() self._model_time = _format_time(start, end) + # Restore the main-fit weight models so numerator/denominator summaries + # reflect the main fit rather than the last in-process replicate. + for attr, value in main_fit_state.items(): + setattr(self, attr, value) + self.outcome_model = results return results diff --git a/tests/test_weight_main_fit_preserved.py b/tests/test_weight_main_fit_preserved.py new file mode 100644 index 0000000..3e05c1c --- /dev/null +++ b/tests/test_weight_main_fit_preserved.py @@ -0,0 +1,76 @@ +import os + +import pytest +from pytest import approx + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _fit_weighted(bootstrap_nboot, parallel=False, glm_package="statsmodels", seed=42): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + glm_package=glm_package, + weighted=True, + # post-expansion weights are refit on every (resampled) replicate, so + # this is the case where an in-process bootstrap clobbers the stored + # main-fit weight models. + weight_preexpansion=False, + bootstrap_nboot=bootstrap_nboot, + parallel=parallel, + ncores=2, + seed=seed, + ), + ) + s.expand() + if bootstrap_nboot > 0: + s.bootstrap() + s.fit() + return s + + +def _weight_params(s): + return { + "numerator": [list(m.params) for m in s.numerator_model if m is not None], + "denominator": [list(m.params) for m in s.denominator_model if m is not None], + } + + +def _assert_same(a, b): + for kind in ("numerator", "denominator"): + assert len(a[kind]) == len(b[kind]) + for pa, pb in zip(a[kind], b[kind]): + assert pb == approx(pa, rel=1e-9, abs=1e-9) + + +@pytest.mark.parametrize("glm_package", ["statsmodels", "glum"]) +def test_main_weight_models_preserved_after_bootstrap(glm_package): + # _fit_numerator/_fit_denominator overwrite self.X_model on every fit, and a + # serial bootstrap runs replicates in-process — so without preservation the + # stored weight model (used by the numerator/denominator summary) would be + # the last resampled replicate, not the main fit. The main-fit models must + # survive the bootstrap loop. + main = _weight_params(_fit_weighted(0, glm_package=glm_package)) + booted = _weight_params(_fit_weighted(3, glm_package=glm_package)) + _assert_same(main, booted) + + +@pytest.mark.skipif( + os.getenv("CI") == "true", reason="Parallelism test hangs in CI environment" +) +def test_main_weight_models_match_across_serial_and_parallel(): + main = _weight_params(_fit_weighted(0, glm_package="glum")) + serial = _weight_params(_fit_weighted(3, parallel=False, glm_package="glum")) + parallel = _weight_params(_fit_weighted(3, parallel=True, glm_package="glum")) + _assert_same(main, serial) + _assert_same(main, parallel) From e8eab6498eaaf6c757f81d04c01a369d61ae4044 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 10 Jun 2026 11:29:31 +0000 Subject: [PATCH 10/32] Auto-format code --- pySEQTarget/helpers/_bootstrap.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index a8a8fe7..17542b3 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -10,7 +10,6 @@ from ._format_time import _format_time - # Side-effect attributes set by the main fit that bootstrap replicates overwrite # when they run in-process (the serial path: each replicate calls the fit body # again, and _fit_numerator/_fit_denominator do `self.X_model = fits`). Snapshot @@ -133,9 +132,7 @@ def wrapper(self, *args, **kwargs): # Snapshot the main-fit weight models before any in-process replicate # can overwrite them; restored just before returning. main_fit_state = { - attr: getattr(self, attr) - for attr in _MAIN_FIT_ATTRS - if hasattr(self, attr) + attr: getattr(self, attr) for attr in _MAIN_FIT_ATTRS if hasattr(self, attr) } if getattr(self, "bootstrap_nboot") > 0 and getattr( From 1f3a627d4e36066e5c4d5611f6fa0a07a12c0c57 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Wed, 10 Jun 2026 12:49:46 +0100 Subject: [PATCH 11/32] Report the censored/uncensored split in verbose expansion output Under method="censoring" the verbose output now prints how many expanded rows enter the outcome model (switch != 1) versus how many are artificially censored, alongside the total. This makes the count line up with implementations that report only the modelled rows (e.g. Stata seqtte), where the difference is just the retained censored rows. ITT prints no split (no artificial censoring). Add tests/test_verbose_censored_counts.py. --- pySEQTarget/SEQuential.py | 11 +++++++ tests/test_verbose_censored_counts.py | 47 +++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 tests/test_verbose_censored_counts.py diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index a5761d8..ef2e11e 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -197,6 +197,17 @@ def expand(self): if self.verbose: n, m = self.DT.shape print(f"Final analysis dataset: {n:,} observations, {m} variables") + # Under censoring the outcome model is fit only on the un-censored + # rows (switch != 1, matching _outcome_fit); the rest are retained in + # the dataset but artificially censored. Report the split so the + # count lines up with implementations that print only the modelled + # rows (e.g. Stata seqtte). + if self.method == "censoring" and "switch" in self.DT.columns: + n_censored = self.DT.filter(pl.col("switch") == 1).height + print( + f" entering outcome model (uncensored): {n - n_censored:,}\n" + f" artificially censored (treatment switch): {n_censored:,}" + ) end = time.perf_counter() self._expansion_time = _format_time(start, end) diff --git a/tests/test_verbose_censored_counts.py b/tests/test_verbose_censored_counts.py new file mode 100644 index 0000000..4ede282 --- /dev/null +++ b/tests/test_verbose_censored_counts.py @@ -0,0 +1,47 @@ +import re + +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _expand(method, capsys): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method=method, + parameters=SEQopts(verbose=True), + ) + s.expand() + return s, capsys.readouterr().out + + +def test_verbose_reports_uncensored_censored_split(capsys): + # Under censoring the verbose output reports how many expanded rows enter the + # outcome model (un-censored) vs are artificially censored, so the count + # lines up with implementations that print only the modelled rows. + s, out = _expand("censoring", capsys) + + total = int(re.search(r"Final analysis dataset: ([\d,]+)", out).group(1).replace(",", "")) + unc = int(re.search(r"uncensored\): ([\d,]+)", out).group(1).replace(",", "")) + cen = int(re.search(r"treatment switch\): ([\d,]+)", out).group(1).replace(",", "")) + + assert unc + cen == total + assert cen > 0 # SEQdata has treatment switches, so some rows are censored + # The reported un-censored count must equal the rows _outcome_fit fits on. + assert unc == s.DT.filter(pl.col("switch") != 1).height + + +def test_verbose_no_censored_split_for_itt(capsys): + # ITT applies no artificial censoring, so the split is not reported. + _, out = _expand("ITT", capsys) + assert "uncensored" not in out + assert "artificially censored" not in out From dad6e2804e1bd784ac3d5d9b4e44bb526a061b7f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 10 Jun 2026 11:50:01 +0000 Subject: [PATCH 12/32] Auto-format code --- tests/test_verbose_censored_counts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_verbose_censored_counts.py b/tests/test_verbose_censored_counts.py index 4ede282..b007957 100644 --- a/tests/test_verbose_censored_counts.py +++ b/tests/test_verbose_censored_counts.py @@ -30,7 +30,9 @@ def test_verbose_reports_uncensored_censored_split(capsys): # lines up with implementations that print only the modelled rows. s, out = _expand("censoring", capsys) - total = int(re.search(r"Final analysis dataset: ([\d,]+)", out).group(1).replace(",", "")) + total = int( + re.search(r"Final analysis dataset: ([\d,]+)", out).group(1).replace(",", "") + ) unc = int(re.search(r"uncensored\): ([\d,]+)", out).group(1).replace(",", "")) cen = int(re.search(r"treatment switch\): ([\d,]+)", out).group(1).replace(",", "")) From 0cc419bc5ab68167130ec4ca341548df008e1305 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 11 Jun 2026 10:11:53 +0100 Subject: [PATCH 13/32] Fix bootstrap weight corruption under weight_preexpansion By joining weights on a recovered original-ID key instead of collapsing the resampled IDs, so each replicate copy of a multiply-sampled subject accumulates its own weight product. --- pySEQTarget/weighting/_weight_bind.py | 23 +++++--- tests/test_bootstrap_weights.py | 76 +++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 8 deletions(-) create mode 100644 tests/test_bootstrap_weights.py diff --git a/pySEQTarget/weighting/_weight_bind.py b/pySEQTarget/weighting/_weight_bind.py index d159af5..0269a1c 100644 --- a/pySEQTarget/weighting/_weight_bind.py +++ b/pySEQTarget/weighting/_weight_bind.py @@ -2,6 +2,7 @@ def _weight_bind(self, WDT): + drop_after_join = [] if self.weight_preexpansion: join = "inner" on = [self.id_col, "period"] @@ -9,23 +10,29 @@ def _weight_bind(self, WDT): # On a bootstrap pass _prepare_boot_data transformed id_col so that # each replicate has a unique value -- integer math (orig_id * id_mult # + replicate) for int IDs, "{orig_id}_{replicate}" for string IDs. - # Recover the original ID here so the join to WDT (which still carries - # un-resampled originals) lines up. No-op on the main fit pass. + # WDT still carries the un-resampled originals, so join on a recovered + # original-ID key. Do NOT overwrite id_col itself: the weight cum_prod + # below groups on (id_col, trial), and collapsing replicate copies of a + # multiply-sampled subject into one group would interleave their rows + # and corrupt the cumulative weights (each copy must accumulate its own + # product independently). No-op on the main fit pass. is_boot = getattr(self, "_current_boot_idx", None) is not None if is_boot: if self.DT.schema[self.id_col].is_integer(): - self.DT = self.DT.with_columns( - (pl.col(self.id_col) // self._boot_id_mult).alias(self.id_col) - ) + orig_id = pl.col(self.id_col) // self._boot_id_mult else: - self.DT = self.DT.with_columns( - pl.col(self.id_col).str.replace(r"_\d+$", "").alias(self.id_col) - ) + orig_id = pl.col(self.id_col).str.replace(r"_\d+$", "") + self.DT = self.DT.with_columns(orig_id.alias("_orig_id")) + WDT = WDT.rename({self.id_col: "_orig_id"}) + on = ["_orig_id", "period"] + drop_after_join = ["_orig_id"] else: join = "left" on = [self.id_col, "trial", "followup"] WDT = self.DT.join(WDT, on=on, how=join) + if drop_after_join: + WDT = WDT.drop(drop_after_join) if self.visit_colname is not None: visit = pl.col(self.visit_colname) == 0 diff --git a/tests/test_bootstrap_weights.py b/tests/test_bootstrap_weights.py new file mode 100644 index 0000000..4ebfb33 --- /dev/null +++ b/tests/test_bootstrap_weights.py @@ -0,0 +1,76 @@ +"""Regression test: bootstrap weights with weight_preexpansion=True. + +_weight_bind joins the pre-expansion weight frame (un-resampled original IDs) +onto the bootstrap-resampled DT. It must do so WITHOUT collapsing the resampled +IDs back to originals: the weight cum_prod groups on (id, trial), and merging +the replicate copies of a multiply-sampled subject into one group interleaves +their rows — turning weights a, ab into a, a², a²b, a²b²… (each copy compounds +the other's). Every replicate copy duplicates the same source rows, so the +correct cumulative weights are identical across copies. +""" + +import sys + +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def test_boot_weights_identical_across_replicate_copies(monkeypatch): + # The package __init__ re-exports shadow the submodule names, so patch the + # name inside the SEQuential module via sys.modules. + seq_mod = sys.modules["pySEQTarget.SEQuential"] + wb_mod = sys.modules["pySEQTarget.weighting._weight_bind"] + + captured = [] + orig = wb_mod._weight_bind + + def spy(self, WDT): + result = orig(self, WDT) + if getattr(self, "_current_boot_idx", None) is not None: + captured.append(self.DT) + return result + + monkeypatch.setattr(seq_mod, "_weight_bind", spy) + + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + bootstrap_nboot=1, + bootstrap_sample=1.0, + seed=42, + ), + ) + s.expand() + s.bootstrap() + s.fit() + + assert len(captured) == 1 + DT = captured[0] + + # The resampled encoded IDs must survive the bind (replicate copies stay + # distinct groups for the cum_prod) ... + id_mult = s._boot_id_mult + orig_ids = set(s.data["ID"].unique().to_list()) + assert not set(DT["ID"].unique().to_list()) <= orig_ids + + # ... and with replicate sampling (sample=1.0 guarantees duplicated + # subjects), every copy of the same original (id, trial, followup) row must + # carry the SAME cumulative weight. + decoded = DT.with_columns((pl.col("ID") // id_mult).alias("_orig")) + dup = decoded.group_by(["_orig", "trial", "followup"]).agg( + [pl.len().alias("n"), pl.col("weight").n_unique().alias("n_weights")] + ) + assert dup.filter(pl.col("n") > 1).height > 0 # duplicates actually present + assert dup.filter(pl.col("n_weights") > 1).height == 0 From 71d4802426108062dab97b110fe144d88bf6a5c8 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 11 Jun 2026 11:01:43 +0100 Subject: [PATCH 14/32] Fix operator precedence in selection_random filter It previously silently dropped all sampled control trials when treatment_level[0] != 0 --- pySEQTarget/expansion/_selection.py | 9 ++++++-- tests/test_selection_random.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/expansion/_selection.py b/pySEQTarget/expansion/_selection.py index 63b7361..4cc1336 100644 --- a/pySEQTarget/expansion/_selection.py +++ b/pySEQTarget/expansion/_selection.py @@ -37,9 +37,14 @@ def _random_selection(self): ).alias("trialID") ) .filter( + # Parentheses matter: `|` binds tighter than `!=`, so without them + # this parses as `(is_in | col) != level`, which silently drops + # every sampled control trial when treatment_level[0] != 0. pl.col("trialID").is_in(sample) - | pl.col(f"{self.treatment_col}{self.indicator_baseline}") - != self.treatment_level[0] + | ( + pl.col(f"{self.treatment_col}{self.indicator_baseline}") + != self.treatment_level[0] + ) ) .drop("trialID") ) diff --git a/tests/test_selection_random.py b/tests/test_selection_random.py index 6dbe2b8..3607dcc 100644 --- a/tests/test_selection_random.py +++ b/tests/test_selection_random.py @@ -54,3 +54,36 @@ def test_selection_random_is_reproducible_with_fixed_seed(): a = _build(selection_random=True, selection_sample=0.5, seed=7) b = _build(selection_random=True, selection_sample=0.5, seed=7) assert a.DT.equals(b.DT) + + +def test_selection_random_nonzero_control_level(): + # Regression: the filter used `is_in(sample) | col != level`, which parses + # as `(is_in | col) != level` and silently dropped every sampled control + # trial whenever treatment_level[0] != 0 (e.g. [1, 2]) — the whole control + # arm vanished. Sampled controls must be retained. + prob = 0.5 + + def build(**opts): + opts.setdefault("seed", 1) + s = SEQuential( + load_data("SEQdata_multitreatment"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(treatment_level=[1, 2], **opts), + ) + s.expand() + return s + + base = _arm_trial_starts(build().DT) + sel = _arm_trial_starts(build(selection_random=True, selection_sample=prob).DT) + + # Non-control arm (level 2) fully retained; control arm (level 1) + # subsampled to the requested fraction — not dropped entirely. + assert sel[2] == base[2] + assert sel[1] == int(prob * base[1]) From 5fca21a7ebc9bf0a2ca5f4efb3bb6a1c0e5a8dd9 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 11 Jun 2026 16:11:53 +0100 Subject: [PATCH 15/32] Normalize time_varying_cols and fixed_cols To empty lists in the constructor so omitting these documented-Optional arguments no longer crashes --- pySEQTarget/SEQuential.py | 7 +++++-- tests/test_optional_covariate_args.py | 30 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 tests/test_optional_covariate_args.py diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index ef2e11e..3153010 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -64,8 +64,11 @@ def __init__( self.eligible_col = eligible_col self.treatment_col = treatment_col self.outcome_col = outcome_col - self.time_varying_cols = time_varying_cols - self.fixed_cols = fixed_cols + # Normalize the documented-Optional covariate lists to [] once, so the + # many downstream `for col in self.fixed_cols` / set() sites need no + # None guards. + self.time_varying_cols = time_varying_cols if time_varying_cols else [] + self.fixed_cols = fixed_cols if fixed_cols else [] self.method = method self._time_initialized = datetime.datetime.now() diff --git a/tests/test_optional_covariate_args.py b/tests/test_optional_covariate_args.py new file mode 100644 index 0000000..7660c89 --- /dev/null +++ b/tests/test_optional_covariate_args.py @@ -0,0 +1,30 @@ +"""Regression test: time_varying_cols and fixed_cols are documented Optional. + +Constructing without them used to crash in _param_checker (`set(None)`), and +several downstream sites iterate self.fixed_cols directly. Omitting both must +work through the whole pipeline. +""" + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def test_pipeline_runs_without_covariate_args(): + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + method="ITT", + parameters=SEQopts(km_curves=True, seed=42), + ) + s.expand() + s.fit() + s.survival() + + # The auto-built outcome formula contains no covariate terms beyond the + # treatment/followup/trial defaults. + assert "sex" not in s.covariates + assert s.km_data.height > 0 From 592da9229d0c53d566427e4dd7f1c86a6c28e091 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 11 Jun 2026 16:15:45 +0100 Subject: [PATCH 16/32] Replace Python 2 dict.has_key with dict.get in SEQoutput.retrieve_data So that switch diagnostics are retrievable, and add the missing f-prefix to its error message --- pySEQTarget/SEQoutput.py | 12 +++--------- tests/test_accessor.py | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index cecf28c..96c02c8 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -141,19 +141,13 @@ def retrieve_data( case "nonunique_compevent": data = self.diagnostic_tables.get("nonunique_compevent") case "unique_switches": - if self.diagnostic_tables.has_key("unique_switches"): - data = self.diagnostic_tables["unique_switches"] - else: - data = None + data = self.diagnostic_tables.get("unique_switches") case "nonunique_switches": - if self.diagnostic_tables.has_key("nonunique_switches"): - data = self.diagnostic_tables["nonunique_switches"] - else: - data = None + data = self.diagnostic_tables.get("nonunique_switches") case _: data = self.km_data if data is None: - raise ValueError("Data {type} was not created in the SEQuential process") + raise ValueError(f"Data {type} was not created in the SEQuential process") return data def to_md(self, filename="SEQuential_results.md") -> None: diff --git a/tests/test_accessor.py b/tests/test_accessor.py index ab9b796..1a5ff81 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -25,3 +25,29 @@ def test_ITT_collector(): collector.retrieve_data("unique_outcomes") with pytest.raises(ValueError): collector.retrieve_data("km_data") + # ITT produces no switch diagnostics: a clean ValueError, not the + # Python-2 dict.has_key AttributeError this used to raise. + with pytest.raises(ValueError, match="not created"): + collector.retrieve_data("unique_switches") + + +def test_censoring_collector_switch_diagnostics(): + # Under method="censoring" the switch diagnostics exist and must be + # retrievable (regression for dict.has_key). + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts(), + ) + s.expand() + s.fit() + collector = s.collect() + assert collector.retrieve_data("unique_switches").height > 0 + assert collector.retrieve_data("nonunique_switches").height > 0 From c0d568cc28190e6929f732f9ed23694e7e425d0c Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 11 Jun 2026 16:23:04 +0100 Subject: [PATCH 17/32] Fix UnboundLocalError when collect() is called before fit() By deriving outcome and compevent model lists under a single guard with a None fallback --- pySEQTarget/SEQuential.py | 14 +++++++++----- tests/test_accessor.py | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 3153010..c4e6e8c 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -453,13 +453,17 @@ def collect(self) -> SEQoutput: "collection_time": self._time_collected, } - if self.compevent_colname is not None: - compevent_models = [model["compevent"] for model in self.outcome_model] - else: - compevent_models = None - if self.outcome_model is not None: outcome_models = [model["outcome"] for model in self.outcome_model] + if self.compevent_colname is not None: + compevent_models = [model["compevent"] for model in self.outcome_model] + else: + compevent_models = None + else: + # collect() before fit(): no models to report, but the rest of the + # output (diagnostics, timings) is still valid. + outcome_models = None + compevent_models = None if self.risk_estimates is None: risk_ratio = risk_difference = None diff --git a/tests/test_accessor.py b/tests/test_accessor.py index 1a5ff81..15c6767 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -31,6 +31,29 @@ def test_ITT_collector(): collector.retrieve_data("unique_switches") +def test_collect_before_fit(): + # collect() without fit() must return an SEQoutput with None models rather + # than raising UnboundLocalError. Diagnostics from expand() still come + # through. + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(), + ) + s.expand() + collector = s.collect() + assert collector.outcome_models is None + assert collector.compevent_models is None + assert collector.retrieve_data("unique_outcomes").height > 0 + + def test_censoring_collector_switch_diagnostics(): # Under method="censoring" the switch diagnostics exist and must be # retrievable (regression for dict.has_key). From 7e1e51e4f23ca3e00347b71560894ee023ec795e Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 11 Jun 2026 16:32:29 +0100 Subject: [PATCH 18/32] Skip padded None entries in excused_colnames validation And raise a clear error for missing columns instead of a confusing polars TypeError --- pySEQTarget/error/_data_checker.py | 8 +++++++ tests/test_excused_colnames.py | 38 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 tests/test_excused_colnames.py diff --git a/pySEQTarget/error/_data_checker.py b/pySEQTarget/error/_data_checker.py index 217be1c..ec6bc8a 100644 --- a/pySEQTarget/error/_data_checker.py +++ b/pySEQTarget/error/_data_checker.py @@ -37,6 +37,14 @@ def _data_checker(self): ) for col in self.excused_colnames: + # _param_checker pads the list with None up to len(treatment_level) + # when fewer excused columns are supplied. + if col is None: + continue + if col not in self.data.columns: + raise ValueError( + f"excused_colnames entry '{col}' not found in data columns." + ) violations = ( self.data.sort([self.id_col, self.time_col]) .group_by(self.id_col) diff --git a/tests/test_excused_colnames.py b/tests/test_excused_colnames.py new file mode 100644 index 0000000..47793e5 --- /dev/null +++ b/tests/test_excused_colnames.py @@ -0,0 +1,38 @@ +"""Validation of excused_colnames in _data_checker. + +_param_checker pads excused_colnames with None up to len(treatment_level); +the data checker used to feed that None into pl.col() and crash with a +confusing TypeError before the analysis even started. +""" + +import pytest + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _build(**opts): + return SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts(**opts), + ) + + +def test_excused_colnames_shorter_than_treatment_level(): + # One excused column for two treatment levels: the padded None entry must + # be skipped, not validated. + s = _build(excused=True, excused_colnames=["excusedZero"]) + assert s.excused_colnames == ["excusedZero", None] + + +def test_excused_colnames_missing_column_raises_clearly(): + with pytest.raises(ValueError, match="not found in data columns"): + _build(excused=True, excused_colnames=["nonexistent_col"]) From 98ab567b4afd916c1be57a9716f098e463135abe Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 11 Jun 2026 21:50:04 +0100 Subject: [PATCH 19/32] Make _JaxFit picklable By rebuilding the patsy design info from formula and reference frame on unpickle so glm_package jax works with offload and parallel bootstrap --- pySEQTarget/helpers/_jax_fit.py | 60 ++++++++++++++++++++++++++-- tests/test_jax.py | 70 +++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 4 deletions(-) diff --git a/pySEQTarget/helpers/_jax_fit.py b/pySEQTarget/helpers/_jax_fit.py index 201c503..e7bbf63 100644 --- a/pySEQTarget/helpers/_jax_fit.py +++ b/pySEQTarget/helpers/_jax_fit.py @@ -25,6 +25,16 @@ def __init__( self._design_info = design_info self._feature_names = feature_names self._X_design = X_mat.values + self._nobs = X_mat.shape[0] + # Lazily-filled coefficient covariance cache: lets __getstate__ drop the + # full design matrix while keeping bse/summary working after unpickle. + self._cov_cached = None + # Inputs to rebuild ``design_info`` on unpickle: patsy DesignInfo cannot + # be pickled (patsy #26), so keep the formula plus a tiny reference + # frame (which preserves each categorical column's full, ordered dtype + # categories) and re-parse on __setstate__. Mirrors _GlumFit. + self._formula = formula + self._ref_frame = df_pd.head(2).copy() X_arr = X_mat.drop(columns=["Intercept"], errors="ignore").values y_raw = y_mat.values.ravel() @@ -51,12 +61,46 @@ def __init__( ) # statsmodels 'like' exposure + self._build_model_namespace(design_info, feature_names) + self.exog_names = feature_names + self.params = self._build_params() + + def _build_model_namespace(self, design_info, feature_names): self.model = types.SimpleNamespace( exog_names=feature_names, data=types.SimpleNamespace(design_info=design_info), ) - self.exog_names = feature_names - self.params = self._build_params() + + def __getstate__(self): + # Drop the unpicklable patsy DesignInfo and the SimpleNamespaces that + # reference it; __setstate__ rebuilds them from the formula + ref_frame. + state = self.__dict__.copy() + state.pop("_design_info", None) + state.pop("model", None) + # Replace the full design matrix with the small cached covariance so + # the pickled model stays lightweight. Covariance is only implemented + # for binary fits; multiclass keeps None (bse raises either way). + if state.get("_cov_cached") is None and self._n_classes == 2: + state["_cov_cached"] = self.cov_params() + state["_X_design"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + _, X_mat = patsy.dmatrices( + self._formula, self._ref_frame, return_type="dataframe" + ) + if list(X_mat.columns) != list(self._feature_names): + # The reference frame's categorical ordering must reproduce the + # frozen column structure exactly, or the coefficients would pair + # with the wrong design columns on predict. Fail loudly rather + # than return silently wrong predictions. + raise RuntimeError( + "_JaxFit design columns changed on unpickle: " + f"{list(X_mat.columns)} != {list(self._feature_names)}" + ) + self._design_info = X_mat.design_info + self._build_model_namespace(self._design_info, self._feature_names) def _coef_components(self): W, b = self._jax.params @@ -120,12 +164,20 @@ def cov_params(self): raise NotImplementedError( "Standard errors are only implemented for binary jax fits." ) + if self._cov_cached is not None: + return self._cov_cached X = self._X_design + if X is None: + raise RuntimeError( + "cov_params unavailable: design matrix was dropped on pickle " + "and no covariance was cached." + ) mu = np.asarray(self._jax.predict(X[:, 1:]))[:, 1] w = mu * (1.0 - mu) if self._sample_weight is not None: w = w * self._sample_weight - return np.linalg.pinv(X.T @ (w[:, None] * X)) + self._cov_cached = np.linalg.pinv(X.T @ (w[:, None] * X)) + return self._cov_cached @property def bse(self): @@ -161,7 +213,7 @@ def summary(self): "GLM (jax backend)", "Binomial", "logit", - str(self._X_design.shape[0]), + str(self._nobs), ] }, index=["Model:", "Family:", "Link:", "No. Observations:"], diff --git a/tests/test_jax.py b/tests/test_jax.py index 7461437..68f8c93 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -120,3 +120,73 @@ def test_jax_warm_start_reaches_same_optimum(): start_params=(cold.params.values, list(cold.model.exog_names)), ) assert list(warm.params) == approx(list(cold.params), rel=1e-3, abs=1e-3) + + +def test_jax_fit_pickle_roundtrip(): + # _JaxFit holds a patsy DesignInfo, which cannot be pickled; offload and + # the parallel bootstrap both pickle fitted models. The wrapper must + # rebuild the design info on unpickle (same strategy as _GlumFit) and keep + # predict/bse/summary working. + import pickle + + df = _binary_frame() + m = _JaxFit("y ~ x1 + x2", df) + + m2 = pickle.loads(pickle.dumps(m)) + + assert list(m2.params) == approx(list(m.params), rel=1e-12, abs=1e-12) + assert m2.predict(df) == approx(m.predict(df), rel=1e-10, abs=1e-12) + assert list(m2.bse) == approx(list(m.bse), rel=1e-10, abs=1e-12) + assert str(m2.summary()) + + +def test_jax_offload_bootstrap_survival(): + # End-to-end: offload=True pickles every fitted model to disk via joblib. + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + glm_package="jax", + bootstrap_nboot=2, + seed=7, + km_curves=True, + offload=True, + ), + ) + s.expand() + s.bootstrap() + s.fit() + s.survival() + assert s.km_data.height > 0 + + +def test_jax_parallel_bootstrap(): + # End-to-end: parallel=True pickles the SEQuential object into worker + # processes and the fitted models back. + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + glm_package="jax", bootstrap_nboot=2, seed=7, parallel=True, ncores=2 + ), + ) + s.expand() + s.bootstrap() + s.fit() + assert len(s.outcome_model) == 3 # main + 2 replicates From 2dbf097d06f4e687ca1c476fd9f7ff8ddba83c5d Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 12 Jun 2026 05:49:45 +0100 Subject: [PATCH 20/32] Add jax to dev dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a8dd6b7..191bf0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ Repository = "https://github.com/CausalInference/pySEQTarget" dev = [ "black", "isort", + "jax", "pytest", "myst-parser", "piccolo_theme", From 228ed8a6a2e5d1ea5b3c1b12a0959ca3697e23d5 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 12 Jun 2026 05:54:01 +0100 Subject: [PATCH 21/32] Reject hazard ratio estimation for method dose-response Instead of silently simulating identical arms and returning HR near 1 --- pySEQTarget/SEQuential.py | 8 +++++++ pySEQTarget/error/_param_checker.py | 8 +++++++ tests/test_hazard.py | 33 +++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+) diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index c4e6e8c..757d30e 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -389,6 +389,14 @@ def hazard(self) -> None: """ start = time.perf_counter() + if self.method == "dose-response": + raise NotImplementedError( + "Hazard ratio estimation is not supported for method='dose-response': " + "the counterfactual simulation only sets the baseline treatment, but " + "the dose-response outcome model depends on the cumulative dose, so " + "both arms would simulate identical outcomes (HR ≈ 1)." + ) + if not hasattr(self, "outcome_model") or not self.outcome_model: raise ValueError( "Outcome model not found. Please run the 'fit' method before calculating hazard ratio." diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index 14b0b94..bd03520 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -41,6 +41,14 @@ def _param_checker(self): if self.km_curves and self.hazard_estimate: raise ValueError("km_curves and hazard cannot both be set to True.") + if self.hazard_estimate and self.method == "dose-response": + raise ValueError( + "Hazard ratio estimation is not supported for method='dose-response': " + "the counterfactual simulation only sets the baseline treatment, but " + "the dose-response outcome model depends on the cumulative dose, so " + "both arms would simulate identical outcomes (HR ≈ 1)." + ) + if sum([self.followup_class, self.followup_include, self.followup_spline]) > 1: raise ValueError( "Only one of followup_class or followup_include can be set to True." diff --git a/tests/test_hazard.py b/tests/test_hazard.py index f3f1854..746f119 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -119,3 +119,36 @@ def flaky_outcome_fit(seq_self, *args, **kwargs): assert hr["Hazard ratio"][0] is not None and np.isfinite(hr["Hazard ratio"][0]) assert hr["LCI"][0] is not None assert hr["UCI"][0] is not None + + +def _dose_response_model(**opts): + return SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="dose-response", + parameters=SEQopts(weighted=True, weight_preexpansion=True, **opts), + ) + + +def test_dose_response_hazard_estimate_rejected_at_construction(): + # The counterfactual hazard simulation only sets the baseline treatment, + # but the dose-response outcome model depends on cumulative dose — both + # arms would simulate identical outcomes and the HR would silently be ~1. + with pytest.raises(ValueError, match="dose-response"): + _dose_response_model(hazard_estimate=True) + + +def test_dose_response_hazard_call_rejected(): + # hazard() can be called regardless of the hazard_estimate flag, so the + # method itself must refuse too. + s = _dose_response_model() + s.expand() + s.fit() + with pytest.raises(NotImplementedError, match="dose-response"): + s.hazard() From 0899c1690fb68f115fd1bd1aa0fcd3313f3fd89b Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 12 Jun 2026 05:59:01 +0100 Subject: [PATCH 22/32] Freeze categorical levels into the glum pickle reference frame dtypes So models fit on plain string covariates survive offload and parallel round-trips --- pySEQTarget/helpers/_glum_fit.py | 15 ++++--- tests/test_glum.py | 70 ++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 5 deletions(-) diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py index 74a4b03..4263467 100644 --- a/pySEQTarget/helpers/_glum_fit.py +++ b/pySEQTarget/helpers/_glum_fit.py @@ -255,11 +255,16 @@ def _fit_glum(formula, data, var_weights=None, start_params=None, design_cache=N glm.fit(X_arr, y_arr, **fit_kwargs) # Keep a minimal reference frame so the (unpicklable) design_info can be - # rebuilt on unpickle. Two rows suffice: patsy derives categorical contrasts - # from each column's full dtype categories, not the observed values, and the - # codebase uses only stateless transforms (precomputed squares, explicit-knot - # splines), so no fit-time state needs preserving. - ref_frame = data.head(2).copy() + # rebuilt on unpickle. Two rows suffice ONLY when each categorical factor's + # full, ordered level set lives in the column dtype — patsy derives the + # contrasts from pd.Categorical dtype categories, but for plain string + # columns it falls back to the observed values, and two rows rarely cover + # every level. Freeze the design's levels into the frame's dtypes so the + # re-parse reproduces the frozen column structure regardless of source + # dtype. (The codebase uses only stateless transforms — precomputed + # squares, explicit-knot splines — so no other fit-time state needs + # preserving.) + ref_frame = _align_categories(design_info, data.head(2).copy()) return _GlumFit( glm, design_info, diff --git a/tests/test_glum.py b/tests/test_glum.py index f9a71c5..c42a078 100644 --- a/tests/test_glum.py +++ b/tests/test_glum.py @@ -435,3 +435,73 @@ def test_glum_warm_start_dropped_when_design_columns_mismatch(): assert list(bogus_fit.params.values) == approx( list(ref.params.values), rel=1e-8, abs=1e-12 ) + + +def test_glum_pickle_with_plain_string_covariate(): + # ref_frame is data.head(2): for a plain object/string column patsy derives + # the categorical levels from the OBSERVED values, so two rows cannot cover + # a 4-level factor and the unpickle column check used to fail with + # RuntimeError. The design's levels must be frozen into the reference + # frame's dtypes instead. + import pickle + + import numpy as np + import pandas as pd + + from pySEQTarget.helpers._glum_fit import _fit_glum + + rng = np.random.default_rng(0) + n = 2000 + levels = ["a", "b", "c", "d"] + # First two rows share one level so head(2) observes a strict subset + grp = ["a", "a"] + list(rng.choice(levels, n - 2)) + df = pd.DataFrame( + { + "grp": grp, # plain object dtype, NOT pd.Categorical + "x": rng.standard_normal(n), + "y": (rng.random(n) < 0.4).astype(int), + } + ) + + m = _fit_glum("y ~ grp + x", df) + m2 = pickle.loads(pickle.dumps(m)) + + assert list(m2.params) == approx(list(m.params), rel=1e-12, abs=1e-12) + assert list(m2.predict(df)) == approx(list(m.predict(df)), rel=1e-10, abs=1e-12) + + +def test_glum_offload_with_string_time_varying_covariate(): + # End-to-end: offload=True round-trips the weight models through joblib. + # A plain string time-varying covariate in the denominator formula must + # survive the pickle/unpickle cycle. + import polars as pl + + data = load_data("SEQdata").with_columns( + pl.when(pl.col("P") < 9) + .then(pl.lit("low")) + .when(pl.col("P") < 10) + .then(pl.lit("mid")) + .otherwise(pl.lit("high")) + .alias("P_grp") + ) + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P_grp"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + glm_package="glum", + weighted=True, + weight_preexpansion=True, + offload=True, + seed=42, + ), + ) + s.expand() + s.fit() + assert s.DT["weight"].is_finite().all() From 8f602401a316588c6964a94d9151864e66ec3c43 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 12 Jun 2026 06:07:09 +0100 Subject: [PATCH 23/32] Skip refitting weight models on bootstrap replicates under weight_preexpansion By caching the main fit's predicted weight frame, since the pre-expansion data is never resampled and the refits were bit-identical --- pySEQTarget/SEQuential.py | 70 +++++++++++------- ...test_weight_fit_cached_across_bootstrap.py | 72 +++++++++++++++++++ 2 files changed, 116 insertions(+), 26 deletions(-) create mode 100644 tests/test_weight_fit_cached_across_bootstrap.py diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 757d30e..d979e67 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -274,32 +274,50 @@ def fit(self) -> None: boot_idx = self._current_boot_idx if self.weighted: - WDT_pl = _weight_setup(self) - if not self.weight_preexpansion and not self.excused: - WDT_pl = WDT_pl.filter(pl.col("followup") > 0) - - # The weight-fit helpers (_fit_LTFU etc.) use pandas-style indexing - # and pass pandas frames to glum/statsmodels, so we convert once. - # The fits don't mutate WDT_pd - they store models on `self` - so - # we keep the original polars frame for the downstream steps - # rather than paying a pl.from_pandas() round-trip per replicate. - WDT_pd = WDT_pl.to_pandas() - for col in self.fixed_cols: - if col in WDT_pd.columns: - WDT_pd[col] = WDT_pd[col].astype("category") - - _fit_LTFU(self, WDT_pd) - _fit_visit(self, WDT_pd) - _fit_numerator(self, WDT_pd) - _fit_denominator(self, WDT_pd) - - if self.offload: - _offload_weights(self, boot_idx) - - del WDT_pd - WDT = _weight_predict(self, WDT_pl) - _weight_bind(self, WDT) - self.weight_stats = _weight_stats(self) + # With weight_preexpansion the weight models are fit on the + # un-resampled pre-expansion data, so every bootstrap replicate + # would refit bit-identical models and re-predict identical + # weights. Cache the predicted weight frame from the main fit and + # reuse it on replicates; only the join onto the resampled DT + # (_weight_bind) and the resulting weight stats differ. + cached_WDT = ( + getattr(self, "_main_weight_WDT", None) + if boot_idx is not None and self.weight_preexpansion + else None + ) + if cached_WDT is not None: + _weight_bind(self, cached_WDT) + self.weight_stats = _weight_stats(self) + else: + WDT_pl = _weight_setup(self) + if not self.weight_preexpansion and not self.excused: + WDT_pl = WDT_pl.filter(pl.col("followup") > 0) + + # The weight-fit helpers (_fit_LTFU etc.) use pandas-style + # indexing and pass pandas frames to glum/statsmodels, so we + # convert once. The fits don't mutate WDT_pd - they store + # models on `self` - so we keep the original polars frame for + # the downstream steps rather than paying a pl.from_pandas() + # round-trip per replicate. + WDT_pd = WDT_pl.to_pandas() + for col in self.fixed_cols: + if col in WDT_pd.columns: + WDT_pd[col] = WDT_pd[col].astype("category") + + _fit_LTFU(self, WDT_pd) + _fit_visit(self, WDT_pd) + _fit_numerator(self, WDT_pd) + _fit_denominator(self, WDT_pd) + + if self.offload: + _offload_weights(self, boot_idx) + + del WDT_pd + WDT = _weight_predict(self, WDT_pl) + if self.weight_preexpansion and boot_idx is None: + self._main_weight_WDT = WDT + _weight_bind(self, WDT) + self.weight_stats = _weight_stats(self) is_boot = boot_idx is not None start = getattr(self, "_outcome_start_params", None) if is_boot else None diff --git a/tests/test_weight_fit_cached_across_bootstrap.py b/tests/test_weight_fit_cached_across_bootstrap.py new file mode 100644 index 0000000..22fb53a --- /dev/null +++ b/tests/test_weight_fit_cached_across_bootstrap.py @@ -0,0 +1,72 @@ +"""With weight_preexpansion=True the weight models are fit on the un-resampled +pre-expansion data, so bootstrap replicates would refit bit-identical models +and re-predict identical weights every iteration. The main fit's predicted +weight frame is cached and replicates only redo the join onto their resampled +DT — results must be unchanged, the weight fitters must run exactly once. +""" + +import sys + +import numpy as np + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _run(monkeypatch=None, disable_cache=False): + seq_mod = sys.modules["pySEQTarget.SEQuential"] + fit_calls = [] + + real_fit_denominator = seq_mod._fit_denominator + + def spy_fit_denominator(self, WDT): + fit_calls.append(getattr(self, "_current_boot_idx", None)) + return real_fit_denominator(self, WDT) + + if monkeypatch is not None: + monkeypatch.setattr(seq_mod, "_fit_denominator", spy_fit_denominator) + if disable_cache: + real_bind = seq_mod._weight_bind + + def bind_no_cache(self, WDT): + result = real_bind(self, WDT) + # Drop the cache after every bind so each replicate refits. + self._main_weight_WDT = None + return result + + monkeypatch.setattr(seq_mod, "_weight_bind", bind_no_cache) + + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + bootstrap_nboot=3, + seed=42, + ), + ) + s.expand() + s.bootstrap() + s.fit() + coefs = np.concatenate([np.asarray(m["outcome"].params) for m in s.outcome_model]) + return fit_calls, coefs + + +def test_weight_models_fit_once_across_bootstrap(monkeypatch): + fit_calls, _ = _run(monkeypatch) + assert fit_calls == [None] # main fit only, no replicate refits + + +def test_cached_weights_match_refit_weights(monkeypatch): + _, cached = _run(monkeypatch) + fit_calls, refit = _run(monkeypatch, disable_cache=True) + assert len(fit_calls) == 4 # cache disabled: main + 3 replicate refits + assert np.array_equal(cached, refit) # bit-identical outcome coefficients From 1d90b06981b31b695c7ecbedc66e534e5f3b6560 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 12 Jun 2026 06:14:19 +0100 Subject: [PATCH 24/32] Offload all weight models under their real attribute names And make SEQoutput.summary load offloaded refs and tolerate absent model lists, and drop the duplicate _DT parquet write in the serial bootstrap path --- pySEQTarget/SEQoutput.py | 19 +++++- pySEQTarget/helpers/_bootstrap.py | 8 +-- pySEQTarget/weighting/_weight_offload.py | 40 +++++++----- tests/test_offload.py | 77 ++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 21 deletions(-) diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index 96c02c8..e2d18b2 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -8,7 +8,7 @@ import polars as pl from statsmodels.base.wrapper import ResultsWrapper -from .helpers import _build_md, _build_pdf +from .helpers import Offloader, _build_md, _build_pdf from .SEQopts import SEQopts @@ -90,7 +90,22 @@ def summary( case _: models = self.outcome_models - return [model.summary() for model in models if model is not None] + if models is None: + return [] + + # Under offload=True the stored entries are path refs; load them back. + loader = None + if self.options is not None and self.options.offload: + loader = Offloader(enabled=True, dir=self.options.offload_dir) + + summaries = [] + for model in models: + if model is None: + continue + if loader is not None: + model = loader.load_model(model) + summaries.append(model.summary()) + return summaries def retrieve_data( self, diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 17542b3..f31a7d4 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -185,12 +185,12 @@ def wrapper(self, *args, **kwargs): self._rng = original_rng self.DT = self._offloader.load_dataframe(original_DT_ref) else: - # Keep original data in memory if offloading is disabled to avoid unnecessary I/O + # original_DT_ref already holds the parquet ref (offload on) or + # the frame itself (offload off) from the save above — don't + # write the parquet a second time. With offload on, drop the + # in-memory frame; replicates reload from the ref. if self._offloader.enabled: - original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT") del original_DT - else: - original_DT_ref = original_DT skipped = 0 boot_sample_idx = [] diff --git a/pySEQTarget/weighting/_weight_offload.py b/pySEQTarget/weighting/_weight_offload.py index 04603e8..2879cc3 100644 --- a/pySEQTarget/weighting/_weight_offload.py +++ b/pySEQTarget/weighting/_weight_offload.py @@ -1,19 +1,29 @@ def _offload_weights(self, boot_idx): - """Helper to offload weight models to disk""" - weight_models = [ + """Offload fitted weight models to disk, replacing them with path refs. + + numerator_model/denominator_model are lists with one fit per treatment + level; the cense/visit models are single fits. Entries already offloaded + (str refs) or never fit (None) are left as-is. Consumers go through + Offloader.load_model, which passes non-str values through. + """ + for attr, name in ( ("numerator_model", "numerator"), ("denominator_model", "denominator"), - ("LTFU_model", "LTFU"), - ("visit_model", "visit"), - ] - - for model_attr, model_name in weight_models: - if hasattr(self, model_attr): - model_list = getattr(self, model_attr) - if model_list and isinstance(model_list, list) and len(model_list) > 0: - latest_model = model_list[-1] - if latest_model is not None: - offloaded = self._offloader.save_model( - latest_model, model_name, boot_idx + ): + model_list = getattr(self, attr, None) + if isinstance(model_list, list): + for i, model in enumerate(model_list): + if model is not None and not isinstance(model, str): + model_list[i] = self._offloader.save_model( + model, f"{name}{i}", boot_idx ) - model_list[-1] = offloaded + + for attr, name in ( + ("cense_numerator_model", "cense_numerator"), + ("cense_denominator_model", "cense_denominator"), + ("visit_numerator_model", "visit_numerator"), + ("visit_denominator_model", "visit_denominator"), + ): + model = getattr(self, attr, None) + if model is not None and not isinstance(model, str): + setattr(self, attr, self._offloader.save_model(model, name, boot_idx)) diff --git a/tests/test_offload.py b/tests/test_offload.py index 1e82858..ddffe2f 100644 --- a/tests/test_offload.py +++ b/tests/test_offload.py @@ -40,3 +40,80 @@ def test_compevent_offload(): warnings.filterwarnings("ignore") model.fit() model.survival() + + +def test_weight_models_fully_offloaded(tmp_path): + # _offload_weights used to check nonexistent attributes (LTFU_model, + # visit_model) and only offload the LAST treatment level's model. All + # fitted weight models must end up as path refs, and summaries must load + # them back transparently. + data = load_data("SEQdata_LTFU") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=False, + cense_colname="LTFU", + offload=True, + offload_dir=str(tmp_path), + seed=42, + ), + ) + s.expand() + s.fit() + + for m in s.numerator_model + s.denominator_model: + assert m is None or isinstance(m, str) + assert isinstance(s.cense_numerator_model, str) + assert isinstance(s.cense_denominator_model, str) + + out = s.collect() + for kind in ("numerator", "denominator", "outcome"): + summaries = out.summary(kind) + assert len(summaries) >= 1 + assert all(str(smry) for smry in summaries) + + +def test_serial_bootstrap_offload_writes_DT_once(monkeypatch, tmp_path): + # The serial bootstrap path used to save the _DT parquet twice per fit. + from pySEQTarget.helpers._offloader import Offloader + + writes = [] + real_save = Offloader.save_dataframe + + def spy(self, df, name): + writes.append(name) + return real_save(self, df, name) + + monkeypatch.setattr(Offloader, "save_dataframe", spy) + + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + bootstrap_nboot=2, + seed=42, + offload=True, + offload_dir=str(tmp_path), + ), + ) + s.expand() + s.bootstrap() + s.fit() + + assert writes.count("_DT") == 1 From 1b619971819b8888a195b9a65b5d8d93bae71b3a Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 12 Jun 2026 06:18:36 +0100 Subject: [PATCH 25/32] Ship the SEQuential object and analysis frame to parallel bootstrap workers once per worker This is via a pool initializer instead of pickling them into every task --- pySEQTarget/helpers/_bootstrap.py | 41 ++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index f31a7d4..a80fecb 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -111,6 +111,29 @@ def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs): return result +# Process-pool worker state for the parallel bootstrap fit. Set once per +# worker process by the initializer so each task ships only the replicate +# index — not the (slimmed) SEQuential object or the full analysis frame, +# which previously crossed the process boundary once per task. +_FIT_WORKER_OBJ = None +_FIT_WORKER_DATA = None +_FIT_WORKER_CALL = None + + +def _fit_pool_init(obj, data_ref, method_name, seed, args, kwargs): + global _FIT_WORKER_OBJ, _FIT_WORKER_DATA, _FIT_WORKER_CALL + _FIT_WORKER_OBJ = obj + _FIT_WORKER_DATA = obj._offloader.load_dataframe(data_ref) + _FIT_WORKER_CALL = (method_name, seed, args, kwargs) + + +def _fit_pool_task(i): + method_name, seed, args, kwargs = _FIT_WORKER_CALL + return _bootstrap_worker( + _FIT_WORKER_OBJ, method_name, _FIT_WORKER_DATA, i, seed, args, kwargs + ) + + def bootstrap_loop(method): @wraps(method) def wrapper(self, *args, **kwargs): @@ -150,19 +173,13 @@ def wrapper(self, *args, **kwargs): self._rng = None self.DT = None - with ProcessPoolExecutor(max_workers=ncores) as executor: + with ProcessPoolExecutor( + max_workers=ncores, + initializer=_fit_pool_init, + initargs=(self, original_DT_ref, method_name, seed, args, kwargs), + ) as executor: futures = { - executor.submit( - _bootstrap_worker, - self, - method_name, - original_DT_ref, - i, - seed, - args, - kwargs, - ): i - for i in range(nboot) + executor.submit(_fit_pool_task, i): i for i in range(nboot) } skipped = 0 boot_sample_idx = [] From 234fd44eb9c7f013277dffcea68eedfeea604d67 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 12 Jun 2026 06:38:25 +0100 Subject: [PATCH 26/32] Share the polars-to-pandas conversion across numerator/denominator and cense/visit weight predictions Instead of reconverting the same rows for every model --- pySEQTarget/helpers/_predict_model.py | 18 +++++++++++--- pySEQTarget/weighting/_weight_pred.py | 35 +++++++++++++++++++-------- 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/pySEQTarget/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py index 672f4f3..13c2655 100644 --- a/pySEQTarget/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -41,14 +41,22 @@ def _safe_predict(model, data, clip_probs=True): return probs -def _predict_model(self, model, newdata): - newdata = newdata.to_pandas() +def _prep_predict_frame(self, newdata): + """Convert a polars frame to pandas with fixed_cols cast to category. - # Original behavior - convert fixed_cols to category + Split out from _predict_model so callers predicting with several models on + the same rows (e.g. numerator + denominator in _weight_predict) can pay + the conversion once and share the frame. + """ + newdata = newdata.to_pandas() for col in self.fixed_cols: if col in newdata.columns: newdata[col] = newdata[col].astype("category") + return newdata + +def _predict_model_pd(model, newdata): + """Predict on an already-prepared pandas frame, with category fix retry.""" try: return np.array(model.predict(newdata)) except Exception as e: @@ -57,3 +65,7 @@ def _predict_model(self, model, newdata): return np.array(model.predict(newdata)) else: raise + + +def _predict_model(self, model, newdata): + return _predict_model_pd(model, _prep_predict_frame(self, newdata)) diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index 6a2bca9..56341b5 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -2,6 +2,7 @@ import polars as pl from ..helpers import _predict_model +from ..helpers._predict_model import _predict_model_pd, _prep_predict_frame def _extract_class_probability(p, level_idx, is_binary): @@ -59,17 +60,22 @@ def _weight_predict(self, WDT): denom_model = self._offloader.load_model(self.denominator_model[i]) num_model = self._offloader.load_model(self.numerator_model[i]) - if denom_model is not None and lag_mask.sum() > 0: - subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, denom_model, subset) + if (denom_model is None and num_model is None) or lag_mask.sum() == 0: + continue + + # Numerator and denominator predict on the same rows — pay the + # filter + pandas conversion once and share the frame. + subset_pd = _prep_predict_frame(self, WDT.filter(pl.Series(lag_mask))) + + if denom_model is not None: + p = _predict_model_pd(denom_model, subset_pd) p_class = _extract_class_probability(p, i, is_binary) pred_denom[lag_mask] = np.where( switched_treatment[lag_mask], 1.0 - p_class, p_class ) - if num_model is not None and lag_mask.sum() > 0: - subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, num_model, subset) + if num_model is not None: + p = _predict_model_pd(num_model, subset_pd) p_class = _extract_class_probability(p, i, is_binary) pred_num[lag_mask] = np.where( switched_treatment[lag_mask], 1.0 - p_class, p_class @@ -167,12 +173,17 @@ def _weight_predict(self, WDT): .otherwise(pl.col("numerator")) .alias("numerator") ) + # Full-frame pandas conversion shared by the cense and visit predictions + # (built lazily — only when at least one of those model pairs exists). + WDT_pd = None + if self.cense_colname is not None: cense_num_model = self._offloader.load_model(self.cense_numerator_model) cense_denom_model = self._offloader.load_model(self.cense_denominator_model) if cense_num_model is not None and cense_denom_model is not None: - p_num = _predict_model(self, cense_num_model, WDT).flatten() - p_denom = _predict_model(self, cense_denom_model, WDT).flatten() + WDT_pd = _prep_predict_frame(self, WDT) + p_num = _predict_model_pd(cense_num_model, WDT_pd).flatten() + p_denom = _predict_model_pd(cense_denom_model, WDT_pd).flatten() WDT = ( WDT.with_columns( [ @@ -196,8 +207,12 @@ def _weight_predict(self, WDT): visit_num_model = self._offloader.load_model(self.visit_numerator_model) visit_denom_model = self._offloader.load_model(self.visit_denominator_model) if visit_num_model is not None and visit_denom_model is not None: - p_num = _predict_model(self, visit_num_model, WDT).flatten() - p_denom = _predict_model(self, visit_denom_model, WDT).flatten() + # The visit formulas don't reference the _cense column added above, + # so the frame converted before the cense block is still valid. + if WDT_pd is None: + WDT_pd = _prep_predict_frame(self, WDT) + p_num = _predict_model_pd(visit_num_model, WDT_pd).flatten() + p_denom = _predict_model_pd(visit_denom_model, WDT_pd).flatten() WDT = ( WDT.with_columns( [ From 1434d16162cc8c306e028df05124079f8c5f0029 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 12 Jun 2026 06:59:17 +0100 Subject: [PATCH 27/32] Route _build_md model sections through SEQoutput.summary So reports survive absent weight models and offloaded refs, fix the summary and retrieve_data default annotations, and name all three followup flags in the param checker error. --- pySEQTarget/SEQoutput.py | 9 ++-- pySEQTarget/error/_param_checker.py | 3 +- pySEQTarget/helpers/_output_files.py | 49 ++++++++----------- pyproject.toml | 1 + tests/test_output_files.py | 71 ++++++++++++++++++++++++++++ 5 files changed, 101 insertions(+), 32 deletions(-) create mode 100644 tests/test_output_files.py diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index e2d18b2..8860296 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -72,7 +72,10 @@ def plot(self) -> None: plt.show() def summary( - self, type=Optional[Literal["numerator", "denominator", "outcome", "compevent"]] + self, + type: Optional[ + Literal["numerator", "denominator", "outcome", "compevent"] + ] = None, ) -> List: """ Returns a list of model summaries of either the numerator, denominator, outcome, or competing event models @@ -109,7 +112,7 @@ def summary( def retrieve_data( self, - type=Optional[ + type: Optional[ Literal[ "km_data", "hazard", @@ -124,7 +127,7 @@ def retrieve_data( "unique_switches", "nonunique_switches", ] - ], + ] = None, ) -> pl.DataFrame: """ Getter for data stored within ``SEQoutput`` diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index bd03520..9efa54a 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -51,7 +51,8 @@ def _param_checker(self): if sum([self.followup_class, self.followup_include, self.followup_spline]) > 1: raise ValueError( - "Only one of followup_class or followup_include can be set to True." + "Only one of followup_class, followup_include, or followup_spline " + "can be set to True." ) if self.followup_spline_df < 2: diff --git a/pySEQTarget/helpers/_output_files.py b/pySEQTarget/helpers/_output_files.py index aab4177..4eccd99 100644 --- a/pySEQTarget/helpers/_output_files.py +++ b/pySEQTarget/helpers/_output_files.py @@ -12,6 +12,19 @@ def _build_md(self, img_path: str = None) -> str: lines = [] + def _model_section(title, kind): + # SEQoutput.summary handles absent model lists (e.g. no numerator + # models under weighted ITT) and loads offloaded path refs. + summaries = self.summary(kind) + if not summaries: + return + lines.append(f"### {title}") + lines.append("") + lines.append("```") + lines.append(str(summaries[0])) + lines.append("```") + lines.append("") + lines.append(f"# SEQuential Analysis: {datetime.date.today()}: {self.method}") lines.append("") @@ -19,42 +32,22 @@ def _build_md(self, img_path: str = None) -> str: lines.append("## Weighting") lines.append("") - lines.append("### Numerator Model") - lines.append("") - lines.append("```") - lines.append(str(self.numerator_models[0].summary())) - lines.append("```") - lines.append("") + _model_section("Numerator Model", "numerator") + _model_section("Denominator Model", "denominator") - lines.append("### Denominator Model") - lines.append("") - lines.append("```") - lines.append(str(self.denominator_models[0].summary())) - lines.append("```") - lines.append("") + if self.options.compevent_colname is not None: + _model_section("Competing Event Model", "compevent") - if self.options.compevent_colname is not None and self.compevent_models: - lines.append("### Competing Event Model") + if self.weight_statistics is not None: + lines.append("### Weighting Statistics") lines.append("") - lines.append("```") - lines.append(str(self.compevent_models[0].summary())) - lines.append("```") + lines.append(self.weight_statistics.to_pandas().to_markdown(index=False)) lines.append("") - lines.append("### Weighting Statistics") - lines.append("") - lines.append(self.weight_statistics.to_pandas().to_markdown(index=False)) - lines.append("") - lines.append("## Outcome") lines.append("") - lines.append("### Outcome Model") - lines.append("") - lines.append("```") - lines.append(str(self.outcome_models[0].summary())) - lines.append("```") - lines.append("") + _model_section("Outcome Model", "outcome") if self.options.hazard_estimate and self.hazard is not None: lines.append("### Hazard") diff --git a/pyproject.toml b/pyproject.toml index 191bf0d..08e7c2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ dev = [ "sphinx", "sphinx-copybutton", "sphinx-autodoc-typehints", + "tabulate", ] [tool.setuptools.packages.find] diff --git a/tests/test_output_files.py b/tests/test_output_files.py new file mode 100644 index 0000000..e9f7dcf --- /dev/null +++ b/tests/test_output_files.py @@ -0,0 +1,71 @@ +"""Markdown report generation (SEQoutput.to_md / _build_md). + +_build_md used to index numerator_models[0]/outcome_models[0] directly, which +crashed for weighted ITT analyses (no treatment-weight models exist — the +attribute is None) and for offloaded models (path refs, not fitted objects). +It now routes through SEQoutput.summary, which handles both. +""" + +import pytest + +# pandas.DataFrame.to_markdown needs tabulate (the "output" extra). +pytest.importorskip("tabulate") + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def test_to_md_weighted_ITT_without_numerator_models(tmp_path): + s = SEQuential( + load_data("SEQdata_LTFU"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(weighted=True, cense_colname="LTFU", seed=42), + ) + s.expand() + s.fit() + out = s.collect() + + md_file = tmp_path / "report.md" + out.to_md(str(md_file)) + content = md_file.read_text() + assert "Outcome Model" in content + # No treatment-weight models under ITT: the section is skipped, not a crash. + assert "Numerator Model" not in content + + +def test_to_md_with_offloaded_models(tmp_path): + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + offload=True, + offload_dir=str(tmp_path / "models"), + seed=42, + ), + ) + s.expand() + s.fit() + out = s.collect() + + md_file = tmp_path / "report.md" + out.to_md(str(md_file)) + content = md_file.read_text() + assert "Numerator Model" in content + assert "Denominator Model" in content + assert "Outcome Model" in content From 1868b7ca782fd7996ea04af57bb9470070373421 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Mon, 15 Jun 2026 10:29:57 +0100 Subject: [PATCH 28/32] Remove Linux tests Because the GHA runner has no GPU these tests don't do anything different to the macOS tests. --- .github/workflows/python-app.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index cdf2d43..4fe70a5 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [macos-26, ubuntu-latest] + os: [macos-26] python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: From 752af61128058d4966a42ac1c809e9e0c172275b Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 16 Jun 2026 11:24:20 +0100 Subject: [PATCH 29/32] Clarify unique vs non-unique in diagnostic table docs and report labels --- pySEQTarget/SEQoutput.py | 21 ++++++++++++++++++++- pySEQTarget/helpers/_output_files.py | 18 +++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index 8860296..bcce6a7 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -41,7 +41,13 @@ class SEQoutput: :type risk_difference: pl.DataFrame or None :param time: Timings for every step of the process completed thus far :type time: dict or None - :param diagnostic_tables: Diagnostic tables for unique and nonunique outcome events and treatment switches + :param diagnostic_tables: Diagnostic tables (outcome, follow-up, switch, and + competing-event counts where applicable), each split by baseline treatment + arm. The "unique" tables count distinct subjects; the "nonunique" tables + count rows: total outcome events for the outcome tables, and total + person-time intervals (expanded follow-up rows) for the follow-up tables. + For a one-time (terminal) outcome the unique and nonunique outcome counts + coincide, since each subject contributes at most one event row. :type diagnostic_tables: dict or None """ @@ -132,6 +138,19 @@ def retrieve_data( """ Getter for data stored within ``SEQoutput`` + The diagnostic tables come in "unique" and "nonunique" variants that count + different things, each broken down by baseline treatment arm: + + - ``unique_outcomes`` / ``nonunique_outcomes``: distinct subjects who had + the outcome vs. the total number of outcome events. These coincide for a + one-time (terminal) outcome, since each subject contributes at most one + event row. + - ``unique_followup`` / ``nonunique_followup``: distinct subjects + contributing follow-up vs. the total number of person-time intervals + (expanded rows). The nonunique count is much larger because each subject + contributes one row per follow-up period; it is the denominator that, + with ``nonunique_outcomes``, gives the per-arm event rate. + :param type: Data which you would like to access, ['km_data', 'hazard', 'risk_ratio', 'risk_difference', 'unique_outcomes', 'nonunique_outcomes', 'unique_followup', 'nonunique_followup', diff --git a/pySEQTarget/helpers/_output_files.py b/pySEQTarget/helpers/_output_files.py index 4eccd99..dd4cbb7 100644 --- a/pySEQTarget/helpers/_output_files.py +++ b/pySEQTarget/helpers/_output_files.py @@ -78,10 +78,26 @@ def _model_section(title, kind): lines.append("") if self.diagnostic_tables: + # Clarify what each unique/nonunique table actually counts, so the + # rendered headings are not ambiguous (see SEQoutput.retrieve_data). + diag_descriptions = { + "unique_outcomes": "distinct subjects who had the outcome", + "nonunique_outcomes": "total outcome events", + "unique_followup": "distinct subjects contributing follow-up", + "nonunique_followup": "person-time intervals", + "unique_compevent": "distinct subjects with a competing event", + "nonunique_compevent": "total competing-event intervals", + "unique_switches": "distinct subjects who switched", + "nonunique_switches": "total switch intervals", + } lines.append("## Diagnostic Tables") lines.append("") for name, table in self.diagnostic_tables.items(): - lines.append(f"### {name.replace('_', ' ').title()}") + heading = name.replace("_", " ").title() + description = diag_descriptions.get(name) + if description: + heading = f"{heading} ({description})" + lines.append(f"### {heading}") lines.append("") lines.append(table.to_pandas().to_markdown(index=False)) lines.append("") From 7dfaf82b735f6ca1aae8227922e0efcaa632c53e Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 16 Jun 2026 11:27:23 +0100 Subject: [PATCH 30/32] Remove comment from random selection filter --- pySEQTarget/expansion/_selection.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pySEQTarget/expansion/_selection.py b/pySEQTarget/expansion/_selection.py index 4cc1336..adaedf2 100644 --- a/pySEQTarget/expansion/_selection.py +++ b/pySEQTarget/expansion/_selection.py @@ -37,9 +37,6 @@ def _random_selection(self): ).alias("trialID") ) .filter( - # Parentheses matter: `|` binds tighter than `!=`, so without them - # this parses as `(is_in | col) != level`, which silently drops - # every sampled control trial when treatment_level[0] != 0. pl.col("trialID").is_in(sample) | ( pl.col(f"{self.treatment_col}{self.indicator_baseline}") From b8dff51cb9e0490555992a0450d3e75682ec9f5b Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 16 Jun 2026 12:56:21 +0100 Subject: [PATCH 31/32] Relax cached-weight bootstrap test to numerical precision --- tests/test_weight_fit_cached_across_bootstrap.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_weight_fit_cached_across_bootstrap.py b/tests/test_weight_fit_cached_across_bootstrap.py index 22fb53a..f3a1cbd 100644 --- a/tests/test_weight_fit_cached_across_bootstrap.py +++ b/tests/test_weight_fit_cached_across_bootstrap.py @@ -1,8 +1,9 @@ """With weight_preexpansion=True the weight models are fit on the un-resampled -pre-expansion data, so bootstrap replicates would refit bit-identical models +pre-expansion data, so bootstrap replicates would refit identical models and re-predict identical weights every iteration. The main fit's predicted weight frame is cached and replicates only redo the join onto their resampled -DT — results must be unchanged, the weight fitters must run exactly once. +DT — results must be unchanged (to numerical precision), the weight fitters +must run exactly once. """ import sys @@ -69,4 +70,8 @@ def test_cached_weights_match_refit_weights(monkeypatch): _, cached = _run(monkeypatch) fit_calls, refit = _run(monkeypatch, disable_cache=True) assert len(fit_calls) == 4 # cache disabled: main + 3 replicate refits - assert np.array_equal(cached, refit) # bit-identical outcome coefficients + # Identical to numerical precision: the cached and refit paths assemble the + # GLM input via different code routes, so multi-threaded BLAS can differ in + # the last few ULPs (passes bit-identical on CI, not always locally). A tight + # tolerance still catches any real divergence, which would be far larger. + assert np.allclose(cached, refit, rtol=0, atol=1e-10) From 69cabccaecfb39419bca64955800625e8b8f56ce Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Tue, 16 Jun 2026 13:05:31 +0100 Subject: [PATCH 32/32] Fix flake8 warnings in SEQopts and optional-dependency test imports --- pySEQTarget/SEQopts.py | 3 ++- tests/test_jax.py | 6 +++--- tests/test_output_files.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index ae0fec9..19becbd 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -25,7 +25,8 @@ class SEQopts: :param excused_colnames: Column names (at the same length of treatment_level) specifying excused conditions, default ``[]`` :param expand_only: If True, ``SEQuential.expand()`` returns the expanded dataset and skips weighting, modelling, and survival steps - :param glm_package: Backend for fitting logistic (outcome/competing-event) models ["statsmodels", "glum", or "jax"], default "statsmodels". + :param glm_package: Backend for fitting logistic (outcome/competing-event) + models ["statsmodels", "glum", or "jax"], default "statsmodels". :param followup_class: Boolean to force followup values to be treated as classes :param followup_include: Boolean to force regular followup values into model covariates :param followup_spline: Boolean to force followup values to be fit to cubic spline diff --git a/tests/test_jax.py b/tests/test_jax.py index 68f8c93..ae6bfa2 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -7,9 +7,9 @@ # every platform — skip the whole module rather than erroring at collection. pytest.importorskip("jax") -from pySEQTarget import SEQopts, SEQuential -from pySEQTarget.data import load_data -from pySEQTarget.helpers._jax_fit import _JaxFit +from pySEQTarget import SEQopts, SEQuential # noqa: E402 +from pySEQTarget.data import load_data # noqa: E402 +from pySEQTarget.helpers._jax_fit import _JaxFit # noqa: E402 def _fit(method, glm_package, dataset="SEQdata", **opts): diff --git a/tests/test_output_files.py b/tests/test_output_files.py index e9f7dcf..8ff47ac 100644 --- a/tests/test_output_files.py +++ b/tests/test_output_files.py @@ -11,8 +11,8 @@ # pandas.DataFrame.to_markdown needs tabulate (the "output" extra). pytest.importorskip("tabulate") -from pySEQTarget import SEQopts, SEQuential -from pySEQTarget.data import load_data +from pySEQTarget import SEQopts, SEQuential # noqa: E402 +from pySEQTarget.data import load_data # noqa: E402 def test_to_md_weighted_ITT_without_numerator_models(tmp_path):