Source code for stride.optimisation.pipelines.pipeline


import inspect

from mosaic import tessera

from .steps import steps_registry
from ...core import Operator, no_grad


__all__ = ['Pipeline']


@tessera
class Pipeline(Operator):
    """
    A pipeline represents a series of processing steps that will be applied
    in order to a series of inputs. Pipelines encode pre-processing or
    post-processing steps such as filtering time traces or smoothing a gradient.

    Parameters
    ----------
    steps : list, optional
        List of steps that form the pipeline. Steps can be callable or strings pointing
        to a default, pre-defined step. A step can also be given as a
        ``(step, do_raise)`` tuple; when ``do_raise`` is false, a string that is not
        found in the registry is skipped silently instead of raising.
    no_grad : bool, optional
        Whether each step should be run inside a ``no_grad`` context during
        ``forward``, defaults to False.
    kwargs
        All keyword arguments (including ``no_grad``) are forwarded to the parent
        constructor. Without ``no_grad``, they are also used to instantiate the
        registry steps and are passed to every step call, where call-time keyword
        arguments take precedence.

    Raises
    ------
    ValueError
        If a step given as a string is not in the registry and ``do_raise`` is true.

    """

    def __init__(self, steps=None, **kwargs):
        super().__init__(**kwargs)

        # ``no_grad`` is consumed here so that it is not handed to the steps.
        self._no_grad = kwargs.pop('no_grad', False)
        self._kwargs = kwargs

        self._steps = []
        for step in steps or []:
            do_raise = True
            if isinstance(step, tuple):
                step, do_raise = step

            if not isinstance(step, str):
                # Already a callable step, use it as given.
                self._steps.append(step)
                continue

            step_cls = steps_registry.get(step, None)
            if step_cls is not None:
                self._steps.append(step_cls(**kwargs))
            elif do_raise:
                raise ValueError('Pipeline step %s does not exist in the registry' % step)

    async def forward(self, *args, **kwargs):
        """
        Apply all steps in the pipeline in order.

        Parameters
        ----------
        args
            Inputs to the first step; the result of each step is the input of the next.
        kwargs
            Extra keyword arguments for every step, overriding those given
            at construction.

        Returns
        -------
        object or tuple
            The result of the last step. When exactly one positional input is given,
            a single value is returned; otherwise whatever the last step returned
            (or ``args`` unchanged if the pipeline has no steps).

        """
        step_kwargs = {**self._kwargs, **kwargs}
        single_input = len(args) == 1

        next_args = args
        for step in self._steps:
            if self._no_grad:
                with no_grad(*next_args, **kwargs):
                    next_args = await step(*next_args, **step_kwargs)
            else:
                next_args = await step(*next_args, **step_kwargs)

            # A single-input step returns a bare value: re-wrap it so that it
            # can be unpacked into the next step.
            if single_input:
                next_args = (next_args,)

        return next_args[0] if single_input else next_args

    async def adjoint(self, *args, **kwargs):
        """
        Apply the adjoint of every step in the pipeline.

        Parameters
        ----------
        args
            The first ``num_outputs`` entries are taken as the values to propagate
            through the step adjoints.
        kwargs
            Extra keyword arguments for every step adjoint, overriding those given
            at construction.

        Returns
        -------
        object or tuple
            A single value if the last step produced exactly one output,
            otherwise the sequence of outputs.

        """
        input_args, _ = self.inputs
        outputs = args[:self.num_outputs]
        step_kwargs = {**self._kwargs, **kwargs}

        # NOTE(review): steps are visited in forward order and all receive the
        # pipeline-level inputs; the adjoint of a composition would normally visit
        # them in reverse with per-step inputs -- confirm the intended behaviour.
        for step in self._steps:
            result = step.adjoint(*outputs, *input_args, **step_kwargs)

            # Step adjoints may be coroutines (as this class' own adjoint is),
            # in which case they have to be awaited to obtain the actual result.
            if inspect.isawaitable(result):
                result = await result

            # A step with a single output returns a bare value: re-wrap it so that
            # it can be unpacked into the next step and measured below.
            outputs = result if isinstance(result, (tuple, list)) else (result,)

        if len(outputs) == 1:
            return outputs[0]

        return outputs