Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/linear_functional.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
"\n",
"gplt.continuous_1d(\n",
" term=gam.predictors[\"deaths\"][0],\n",
" term=0,\n",
" gam=gam,\n",
" ax=axes[0],\n",
" )\n",
Expand Down
26 changes: 18 additions & 8 deletions pymgcv/gam.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def covariance(
@overload
def partial_effect(
self,
term: AbstractTerm,
term: AbstractTerm | int,
target: str | None = None,
data: pd.DataFrame | Mapping[str, pd.Series | np.ndarray] | None = None,
*,
Expand All @@ -461,7 +461,7 @@ def partial_effect(
@overload
def partial_effect(
self,
term: AbstractTerm,
term: AbstractTerm | int,
target: str | None = None,
data: pd.DataFrame | Mapping[str, pd.Series | np.ndarray] | None = None,
*,
Expand All @@ -470,7 +470,7 @@ def partial_effect(

def partial_effect(
self,
term: AbstractTerm,
term: AbstractTerm | int,
target: str | None = None,
data: pd.DataFrame | Mapping[str, pd.Series | np.ndarray] | None = None,
*,
Expand All @@ -483,7 +483,8 @@ def partial_effect(

Args:
term: The specific term to evaluate (must match a term used in the
original model specification)
original model specification) or an integer index representing
the position of the term in the target's predictor list
target: Name of the target variable (response variable or family
parameter name from the model specification). If set to None, an error
is raised when multiple predictors are present; otherwise, the sole
Expand All @@ -496,8 +497,6 @@ def partial_effect(
raise ValueError(
"Cannot compute partial effect before fitting the model.",
)
if data is not None:
self._check_data(data, requires=term)

if target is None:
if len(self.all_predictors) > 1:
Expand All @@ -506,6 +505,12 @@ def partial_effect(
)
target = list(self.all_predictors.keys())[0]

if isinstance(term, int):
term = self.all_predictors[target][term]

if data is not None:
self._check_data(data, requires=term)

data = data if data is not None else self.fit_state.data

formula_idx = list(self.all_predictors.keys()).index(target)
Expand Down Expand Up @@ -547,7 +552,7 @@ def penalty_edf(self):

def partial_residuals(
self,
term: AbstractTerm,
term: AbstractTerm | int,
target: str | None = None,
data: pd.DataFrame | Mapping[str, pd.Series | np.ndarray] | None = None,
*,
Expand All @@ -561,7 +566,9 @@ def partial_residuals(
different functional form might be more appropriate.

Args:
term: The model term to compute partial residuals for.
term: The model term to compute partial residuals for. If an integer,
it is interpreted as the index of the term in the predictor of
``target``.
target: Name of the target variable (response variable or family
parameter name from the model specification). If set to None, an error
is raised when multiple predictors are present; otherwise, the sole
Expand All @@ -588,6 +595,9 @@ def partial_residuals(
)
target = list(self.all_predictors.keys())[0]

if isinstance(term, int):
term = self.all_predictors[target][term]

link_fit = self.predict(data)[target] # _check_data called within predict
data = data if data is not None else self.fit_state.data
data = deepcopy(data)
Expand Down
34 changes: 29 additions & 5 deletions pymgcv/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _plot_wrapper(axes: Iterable[Axes], **kwargs: Any):

def continuous_1d(
*,
term: AbstractTerm,
term: AbstractTerm | int,
gam: AbstractGAM,
target: str | None = None,
data: pd.DataFrame | Mapping[str, pd.Series | np.ndarray] | None = None,
Expand All @@ -273,6 +273,8 @@ def continuous_1d(

Args:
term: The model term to plot. Must be a univariate term (single variable).
If an integer is provided, it is assumed to be the index of the term
in the predictor of ``target``.
gam: GAM model containing the term to plot.
target: Name of the target variable (response variable or family
parameter name from the model specification). If set to None, an error
Expand Down Expand Up @@ -310,6 +312,9 @@ def continuous_1d(
)
target = list(gam.all_predictors.keys())[0]

if isinstance(term, int):
term = gam.all_predictors[target][term]

data = data if data is not None else gam.fit_state.data
data = deepcopy(data)
term = _FactorSmoothToByInterface(term) if _is_random_wiggly(term) else term
Expand Down Expand Up @@ -390,7 +395,7 @@ def continuous_1d(

def continuous_2d(
*,
term: AbstractTerm,
term: AbstractTerm | int,
gam: AbstractGAM,
target: str | None = None,
data: pd.DataFrame | Mapping[str, np.ndarray | pd.Series] | None = None,
Expand All @@ -408,7 +413,9 @@ def continuous_2d(

Args:
term: The bivariate term to plot. Must have exactly two variables.
Can be S('x1', 'x2') or T('x1', 'x2').
Can be S('x1', 'x2') or T('x1', 'x2'). If an integer is provided,
it is interpreted as the index of the term the list of predictors
for ``target``.
gam: GAM model containing the term to plot.
target: Name of the target variable (response variable or family
parameter name from the model specification). If set to None, an error
Expand Down Expand Up @@ -446,6 +453,9 @@ def continuous_2d(
)
target = list(gam.all_predictors.keys())[0]

if isinstance(term, int):
term = gam.all_predictors[target][term]

data = data if data is not None else gam.fit_state.data
data = deepcopy(data)
term = _FactorSmoothToByInterface(term) if _is_random_wiggly(term) else term
Expand Down Expand Up @@ -530,7 +540,7 @@ def continuous_2d(

def categorical(
*,
term: L,
term: L | int,
gam: AbstractGAM,
target: str | None = None,
data: pd.DataFrame | Mapping[str, pd.Series | np.ndarray] | None = None,
Expand Down Expand Up @@ -575,6 +585,12 @@ def categorical(
)
target = list(gam.all_predictors.keys())[0]

if isinstance(term, int):
term = gam.all_predictors[target][term] # type: ignore - checked below

if not isinstance(term, L):
raise TypeError("The term must be a linear term.")

data = gam.fit_state.data if data is None else data

errorbar_kwargs = {} if errorbar_kwargs is None else errorbar_kwargs
Expand Down Expand Up @@ -632,7 +648,7 @@ def categorical(

def random_effect(
*,
term: S,
term: S | int,
gam: AbstractGAM,
target: str | None = None,
confidence_interval_level: float = 0.95,
Expand All @@ -650,6 +666,8 @@ def random_effect(
Args:
term: The random effect term to plot. Must be a smooth term with a
[`RandomEffect`][pymgcv.basis_functions.RandomEffect] basis function.
If an integer is provided, it is assumed to be the index of the term
in the predictors for ``target``.
gam: The fitted GAM model containing the random effect.
target: The target variable to plot when multiple predictors are present.
If None and only one predictor exists, that predictor is used.
Expand Down Expand Up @@ -686,6 +704,12 @@ def random_effect(
)
target = list(gam.all_predictors.keys())[0]

if isinstance(term, int):
term = gam.all_predictors[target][term] # type: ignore - checked below

if not isinstance(term, S):
raise TypeError("Term is not a smooth term.")

scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs
scatter_kwargs.setdefault("s", 0.05 * rcParams["lines.markersize"] ** 2)

Expand Down