Source code for stride.optimisation.pipelines.pipeline


from mosaic.utils import camel_case

from mosaic import tessera

from . import steps as steps_module
from ...core import Operator, no_grad


__all__ = ['Pipeline']


@tessera
class Pipeline(Operator):
    """
    A pipeline represents a series of processing steps that will be applied
    in order to a series of inputs.

    Pipelines encode pre-processing or post-processing steps such as filtering
    time traces or smoothing a gradient.

    Parameters
    ----------
    steps : list, optional
        List of steps that form the pipeline. Steps can be callable or strings
        pointing to a default, pre-defined step.
    no_grad : bool, optional
        Whether each step of ``forward`` is run inside a ``no_grad`` context,
        defaults to False.
    kwargs
        Forwarded to the parent constructor. Except for ``no_grad``, they are
        also used to instantiate steps given as strings and are merged into
        the keyword arguments of every step call.

    """

    def __init__(self, steps=None, **kwargs):
        # The parent constructor receives the full set of keyword arguments,
        # ``no_grad`` included; it is only removed afterwards.
        super().__init__(**kwargs)

        self._no_grad = kwargs.pop('no_grad', False)
        self._kwargs = kwargs

        self._steps = []
        for step in steps or []:
            if isinstance(step, str):
                # A string names a sub-module of ``steps``; the step class
                # inside it is the camel-cased version of that same name.
                step_module = getattr(steps_module, step)
                step_class = getattr(step_module, camel_case(step))
                step = step_class(**kwargs)

            self._steps.append(step)

    async def forward(self, *args, **kwargs):
        """
        Apply all steps in the pipeline in order.

        Parameters
        ----------
        args
            Inputs handed to the first step; the result of every step becomes
            the positional input of the next one.
        kwargs
            Extra keyword arguments for every step. They take precedence over
            the keyword arguments given at construction.

        Returns
        -------
        object or tuple
            Result of the last step. When exactly one input was given, a
            single value is returned instead of a tuple.

        """
        # With a single input, every step is taken to return a single value,
        # which is re-wrapped so that it can be star-unpacked for the next step.
        single_input = len(args) == 1
        next_args = args

        for step in self._steps:
            step_kwargs = {**self._kwargs, **kwargs}

            if self._no_grad:
                with no_grad(*next_args, **kwargs):
                    next_args = await step(*next_args, **step_kwargs)
            else:
                next_args = await step(*next_args, **step_kwargs)

            if single_input:
                next_args = (next_args,)

        return next_args[0] if single_input else next_args

    async def adjoint(self, *args, **kwargs):
        """
        Apply the adjoint of every step in the pipeline.

        Parameters
        ----------
        args
            The first ``self.num_outputs`` entries are used as the incoming
            values that are propagated through the step adjoints.
        kwargs
            Extra keyword arguments for every step adjoint. They take
            precedence over the keyword arguments given at construction.

        Returns
        -------
        object or tuple
            Result of the last step adjoint; a single value is returned
            instead of a one-element sequence.

        """
        import inspect

        # ``self.inputs`` and ``self.num_outputs`` are provided by the parent
        # class; only the positional inputs are needed here.
        input_args, _ = self.inputs
        outputs = args[:self.num_outputs]

        # NOTE(review): steps are visited in the same order as in ``forward``.
        # The adjoint of a composition is normally applied in reverse order --
        # confirm whether this ordering is intentional.
        for step in self._steps:
            result = step.adjoint(*outputs, *input_args,
                                  **{**self._kwargs, **kwargs})

            # Step adjoints may be coroutines (``forward`` awaits its steps);
            # without awaiting, the coroutine object itself would be unpacked.
            if inspect.isawaitable(result):
                result = await result

            # Keep ``outputs`` a sequence so that it can be star-unpacked for
            # the next step and measured with ``len`` below.
            outputs = result if isinstance(result, (tuple, list)) else (result,)

        if len(outputs) == 1:
            return outputs[0]

        return outputs