From 60fd710da1bc88a871dc696c78bfc76567755501 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Tue, 24 Jan 2023 12:19:28 +0100 Subject: [PATCH] Flatten the implementation of the pipeline() decorator. It seems easier to follow in a single function rather than being spread out in multiple helpers. (It also makes the signature of `pipeline()` explicit without having to duplicate it.) --- slicerator/__init__.py | 125 +++++++++-------------------------------- 1 file changed, 27 insertions(+), 98 deletions(-) diff --git a/slicerator/__init__.py b/slicerator/__init__.py index 4969bf3..6f0e346 100644 --- a/slicerator/__init__.py +++ b/slicerator/__init__.py @@ -3,7 +3,7 @@ import collections.abc import itertools -from functools import wraps +from functools import partial, wraps from copy import copy import inspect @@ -503,7 +503,7 @@ def __setstate__(self, data_as_list): return self.__init__(lambda x: x, data_as_list) -def pipeline(func=None, **kwargs): +def pipeline(func=None, *, retain_doc=False, ancestor_count=1): """Decorator to enable lazy evaluation of a function. When the function is applied to a Slicerator or Pipeline object, it @@ -540,8 +540,8 @@ def pipeline(func=None, **kwargs): Apply the pipeline decorator to your image processing function. >>> @pipeline - ... def color_channel(image, channel): - ... return image[channel, :, :] + ... def color_channel(image, channel): + ... return image[channel, :, :] ... @@ -583,94 +583,19 @@ def pipeline(func=None, **kwargs): ... def sum_offset(img1, img2, offset): ... return img1 + img2 + offset """ - def wrapper(f): - return _pipeline(f, **kwargs) - if func is None: - return wrapper - else: - return wrapper(func) + return partial( + pipeline, retain_doc=retain_doc, ancestor_count=ancestor_count) + if ancestor_count == 'all': + ancestor_count = len( + p for p in inspect.signature(func).parameters + if p.kind.name in ["POSITIONAL_ONLY", "POSITIONAL_OR_KEYWORD"]) -def _pipeline(func_or_class, **kwargs): try: - is_class = issubclass(func_or_class, Pipeline) + is_class = issubclass(func, Pipeline) except TypeError: is_class = False - if is_class: - return _pipeline_fromclass(func_or_class, **kwargs) - else: - return _pipeline_fromfunc(func_or_class, **kwargs) - - -def _pipeline_fromclass(cls, retain_doc=False, ancestor_count=1): - """Actual `pipeline` implementation - - Parameters - ---------- - func : class - Class for lazy evaluation - retain_doc : bool - If True, don't modify `func`'s doc string to say that it has been - made lazy - ancestor_count : int or 'all', optional - Number of inputs to the pipeline. Defaults to 1. - - Returns - ------- - Pipeline - Lazy function evaluation :py:class:`Pipeline` for `func`. - """ - if ancestor_count == 'all': - # subtract 1 for `self` - ancestor_count = len(inspect.getfullargspec(cls).args) - 1 - - @wraps(cls) - def process(*args, **kwargs): - ancestors = args[:ancestor_count] - args = args[ancestor_count:] - all_pipe = all(hasattr(a, '_slicerator_flag') or - isinstance(a, Slicerator) or - isinstance(a, Pipeline) for a in ancestors) - if all_pipe: - return cls(*(ancestors + args), **kwargs) - else: - # Fall back on normal behavior of func, interpreting input - # as a single image. - return cls(*(tuple([a] for a in ancestors) + args), **kwargs)[0] - - if not retain_doc: - if process.__doc__ is None: - process.__doc__ = '' - process.__doc__ = ("This function has been made lazy. When passed\n" - "a Slicerator, it will return a \n" - "Pipeline of the results. When passed \n" - "any other objects, its behavior is " - "unchanged.\n\n") + process.__doc__ - process.__name__ = cls.__name__ - return process - - -def _pipeline_fromfunc(func, retain_doc=False, ancestor_count=1): - """Actual `pipeline` implementation - - Parameters - ---------- - func : callable - Function for lazy evaluation - retain_doc : bool - If True, don't modify `func`'s doc string to say that it has been - made lazy - ancestor_count : int or 'all', optional - Number of inputs to the pipeline. Defaults to 1. - - Returns - ------- - Pipeline - Lazy function evaluation :py:class:`Pipeline` for `func`. - """ - if ancestor_count == 'all': - ancestor_count = len(inspect.getfullargspec(func).args) @wraps(func) def process(*args, **kwargs): @@ -679,24 +604,28 @@ def process(*args, **kwargs): all_pipe = all(hasattr(a, '_slicerator_flag') or isinstance(a, Slicerator) or isinstance(a, Pipeline) for a in ancestors) - if all_pipe: - def proc_func(*x): - return func(*(x + args), **kwargs) - return Pipeline(proc_func, *ancestors) + if is_class: + return (func(*ancestors, *args, **kwargs) + if all_pipe else + # Fall back on normal behavior of func, interpreting input + # as a single image. + func(*[[a] for a in ancestors], *args, **kwargs)[0]) + else: - # Fall back on normal behavior of func, interpreting input - # as a single image. - return func(*(ancestors + args), **kwargs) + return (Pipeline(lambda *x: func(*x, *args, **kwargs), *ancestors) + if all_pipe else + # Fall back on normal behavior of func, interpreting input + # as a single image. + func(*ancestors, *args, **kwargs)) if not retain_doc: if process.__doc__ is None: process.__doc__ = '' - process.__doc__ = ("This function has been made lazy. When passed\n" - "a Slicerator, it will return a \n" - "Pipeline of the results. When passed \n" - "any other objects, its behavior is " - "unchanged.\n\n") + process.__doc__ + process.__doc__ = ( + "This function has been made lazy. When passed a Slicerator, it \n" + "will return a Pipeline of the results. When passed any other \n" + "objects, its behavior is unchanged.\n\n" + process.__doc__) process.__name__ = func.__name__ return process