From ecd9139be54843d9f5e69c02ebf4285070f4c096 Mon Sep 17 00:00:00 2001 From: danielward27 Date: Thu, 4 Sep 2025 14:55:09 +0100 Subject: [PATCH] Avoid square qq (use axline) --- pymgcv/plot.py | 17 +++++++---------- tests/test_plot.py | 1 - 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/pymgcv/plot.py b/pymgcv/plot.py index 721e146..d45e630 100644 --- a/pymgcv/plot.py +++ b/pymgcv/plot.py @@ -756,7 +756,7 @@ def qq( qq_fun: Callable[[AbstractGAM], QQResult] = qq_simulate, scatter_kwargs: dict | None = None, fill_between_kwargs: dict | None = None, - plot_kwargs: dict | None = None, + axline_kwargs: dict | None = None, ax: Axes | None = None, ) -> Axes: """A Q-Q plot of deviance residuals. @@ -773,7 +773,7 @@ def qq( scatter_kwargs: Key word arguments passed to `matplotlib.pyplot.scatter`. fill_between_kwargs: Key word arguments passed to `matplotlib.pyplot.fill_between`, for plotting the confidence interval. - plot_kwargs: Key word arguments passed to `matplotlib.pyplot.plot` for + axline_kwargs: Key word arguments passed to `matplotlib.pyplot.axline` for plotting the reference line. Pass {"disable": True} to avoid plotting. ax: Matplotlib axes to use for the plot. @@ -842,11 +842,11 @@ def qq( if "c" not in scatter_kwargs and "color" not in scatter_kwargs: scatter_kwargs["color"] = "black" - plot_kwargs = {} if plot_kwargs is None else plot_kwargs + axline_kwargs = {} if axline_kwargs is None else axline_kwargs - if "c" not in plot_kwargs and "color" not in plot_kwargs: - plot_kwargs["color"] = "gray" - plot_kwargs.setdefault("linestyle", "--") + if "c" not in axline_kwargs and "color" not in axline_kwargs: + axline_kwargs["color"] = "gray" + axline_kwargs.setdefault("linestyle", "--") ax = plt.gca() if ax is None else ax qq_data = qq_fun(gam) @@ -859,10 +859,7 @@ def qq( ax.scatter(qq_data.theoretical, qq_data.residuals, **scatter_kwargs) ax.set_xlabel("Theoretical Quantiles") ax.set_ylabel("Residuals") - - min_val = min(ax.get_xlim()[0], ax.get_ylim()[0]) - max_val = max(ax.get_xlim()[1], ax.get_ylim()[1]) - _with_disable(ax.plot)([min_val, max_val], [min_val, max_val], **plot_kwargs) + _with_disable(ax.axline)((0, 0), slope=1, **axline_kwargs) return ax diff --git a/tests/test_plot.py b/tests/test_plot.py index f0f008d..ad6dbde 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -136,7 +136,6 @@ def test_qq(test_case: tc.GAMTestCase): pass else: raise - plt.close("all")