Skip to content

Commit a25c543

Browse files
committed
Restrict bootstrap warm-start to exact column-name matches and retry on failure
1 parent ca5db83 commit a25c543

2 files changed

Lines changed: 26 additions & 8 deletions

File tree

pySEQTarget/SEQuential.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,10 @@ def fit(self) -> None:
279279
models_list = _subgroup_fit(self, start_params=start)
280280
if not is_boot:
281281
self._outcome_start_params = {
282-
val: {key: m.params.values for key, m in sg.items()}
282+
val: {
283+
key: (m.params.values, list(m.model.exog_names))
284+
for key, m in sg.items()
285+
}
283286
for val, sg in zip(self._unique_subgroups, models_list)
284287
}
285288
return models_list
@@ -310,7 +313,10 @@ def fit(self) -> None:
310313
)
311314

312315
if not is_boot:
313-
self._outcome_start_params = {k: m.params.values for k, m in models.items()}
316+
self._outcome_start_params = {
317+
k: (m.params.values, list(m.model.exog_names))
318+
for k, m in models.items()
319+
}
314320

315321
if self.offload:
316322
offloaded_models = {}

pySEQTarget/analysis/_outcome_fit.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,22 @@ def _outcome_fit(
104104

105105
model = smf.glm(**glm_kwargs)
106106

107-
# Drop warm-start coefs if the design matrix column count doesn't match
108-
# (e.g. categorical level dropped or added in a bootstrap resample).
109-
if start_params is not None and len(start_params) != model.exog.shape[1]:
110-
start_params = None
111-
112-
model_fit = model.fit(start_params=start_params)
107+
# Drop warm-start coefs unless the design matrix columns match exactly
108+
# by name — bootstrap resamples can shift categorical reference levels or
109+
# column ordering, in which case the cached coefs are meaningless and
110+
# IRLS can diverge into NaN/Inf and crash LAPACK.
111+
if start_params is not None:
112+
sp_values, sp_names = start_params
113+
if list(model.exog_names) != list(sp_names):
114+
start_params = None
115+
else:
116+
start_params = sp_values
117+
118+
try:
119+
model_fit = model.fit(start_params=start_params)
120+
except Exception:
121+
if start_params is not None:
122+
model_fit = model.fit()
123+
else:
124+
raise
113125
return model_fit

0 commit comments

Comments
 (0)