Source code for stride.utils.operators


import numpy as np
import mosaic

from ..core import Operator


# Names exported by a star-import of this module.
__all__ = ['Add', 'Mul', 'Concatenate']


class Add(Operator):
    """
    Operator that sums its two inputs.

    """

    def forward(self, a, b, **kwargs):
        """
        Parameters
        ----------
        a
            First operand.
        b
            Second operand.

        Returns
        -------
        The result of ``a + b``.

        """
        total = a + b
        return total

    def adjoint(self, d_sum, a, b, **kwargs):
        """
        Parameters
        ----------
        d_sum
            Gradient arriving at the output of the sum.
        a
            First operand given to ``forward`` (not used here).
        b
            Second operand given to ``forward`` (not used here).

        Returns
        -------
        tuple
            ``d_sum`` repeated twice, one entry per input.

        """
        grads = (d_sum, d_sum)
        return grads
class Mul(Operator):
    """
    Operator that multiplies its two inputs.

    """

    def forward(self, a, b, **kwargs):
        """
        Parameters
        ----------
        a
            First operand.
        b
            Second operand.

        Returns
        -------
        The result of ``a * b``.

        """
        product = a * b
        return product

    def adjoint(self, d_mul, a, b, **kwargs):
        """
        Parameters
        ----------
        d_mul
            Gradient arriving at the output of the product.
        a
            First operand given to ``forward``.
        b
            Second operand given to ``forward``.

        Returns
        -------
        tuple
            ``(d_mul * b, a * d_mul)``: the gradients for ``a`` and ``b``.

        """
        # Operand order is kept exactly as written in case ``*`` is not
        # commutative for the operand types.
        return d_mul * b, a * d_mul
@mosaic.tessera
class Concatenate(Operator):
    """
    Concatenate multiple StructuredData objects.

    When ``new_axis`` is False the inputs are joined along the existing axis
    ``axis``. Inputs with fewer than two dimensions are first promoted with
    ``np.atleast_2d``, so the default ``axis=0`` case gives the same result
    as ``np.vstack``. When ``new_axis`` is True the inputs are stacked along
    a new axis inserted at position ``axis``.

    Parameters
    ----------
    start_end: 2d-array, optional
        Array containing start and end indices pointing to each objects' position
        in the concatenated object. Shape should be (num_objects, 2)
    new_axis : bool, optional
        Whether to concatenate on a new axis. Defaults to False.
    axis : int, optional
        Axis to perform the concatenation. Defaults to 0.

    """

    def __init__(self, *args, **kwargs):
        self._start_end = kwargs.pop('start_end', None)
        self.new_axis = kwargs.pop('new_axis', False)
        self.axis = kwargs.pop('axis', 0)
        super().__init__(*args, **kwargs)

    async def forward(self, *args, **kwargs):
        """
        Parameters
        ----------
        args : Sequence[StructuredData]
            Sequence of StructuredData objects to be combined.
        start_end : 2d-array, optional
            Array containing start and end indices pointing to each objects' position
            in the concatenated object. Shape should be (num_objects, 2)
        axis : int, optional
            The axis to concatenate on. Defaults to 0.
        new_axis : bool, optional
            Whether to create a new axis when concatenating, or maintain the dimensions.
            Defaults to False.

        Returns
        -------
        StructuredData
            Concatenated data as a single StructuredData object.

        Raises
        ------
        TypeError
            If no objects are provided.

        """
        # check that at least one StructuredData object has been provided
        if not args:
            raise TypeError('StructuredData objects missing. Please provide at least one.')

        new_axis = kwargs.pop('new_axis', None)
        axis = kwargs.pop('axis', None)
        start_end = kwargs.pop('start_end', None)

        # per-call options overwrite the ones stored on the operator
        if new_axis is not None:
            self.new_axis = new_axis

        if axis is not None:
            self.axis = axis

        # the axis has to be settled before the indices are derived from it
        if start_end is not None:
            # update if required
            self._start_end = start_end
        elif self._start_end is None:
            # build start_end from the arguments
            self.build_start_end(args)

        concat_data = [each.data for each in args]

        if self.new_axis:
            concat_data = np.stack(concat_data, axis=self.axis)
        else:
            # Same as np.vstack for axis=0, but honours ``self.axis`` so that
            # the adjoint, which slices along ``self.axis``, stays consistent.
            concat_data = np.concatenate([np.atleast_2d(each) for each in concat_data],
                                         axis=self.axis)

        concat = args[0].alike(name='concat_%s' % args[0].name,
                               data=concat_data,
                               shape=None,
                               extended_shape=None,
                               inner=None)

        return concat

    async def adjoint(self, d_concat, *args, **kwargs):
        """
        Split the gradient of the concatenated object back into one gradient
        per input.

        Parameters
        ----------
        d_concat : StructuredData
            Gradient with respect to the concatenated object.
        args : Sequence[StructuredData]
            The objects that were given to ``forward``.

        Returns
        -------
        StructuredData or tuple of StructuredData
            A single gradient when there is one input, otherwise a tuple with
            one gradient per input, in the same order as ``args``.

        """
        d_args = []

        for arg_i, arg in enumerate(args):
            if self.new_axis:
                d_arg_i_data = np.take(d_concat.data, arg_i, axis=self.axis)
            else:
                # get start and end points of original arguments
                start, end = self._start_end[arg_i]
                indices = np.arange(int(start), int(end))

                # An index array keeps the sliced axis in the result, and not
                # passing ``out`` avoids assuming 2-D data on axis 0.
                d_arg_i_data = np.take(d_concat.data, indices, axis=self.axis)

            # insert into stride object
            d_arg_i = arg.alike(name='grad_%s' % arg.name,
                                data=d_arg_i_data,
                                shape=None,
                                extended_shape=None,
                                inner=None)
            d_args.append(d_arg_i)

        if len(d_args) > 1:
            return tuple(d_args)

        return d_args[0]

    def build_start_end(self, args):
        """
        Build start and end indices from the Sequence of StructuredData objects.

        The extent of each object is measured along the concatenation axis,
        after the same ``np.atleast_2d`` promotion that ``forward`` applies.
        The result is stored in ``self._start_end``.

        Parameters
        ----------
        args : Sequence[StructuredData]
            Sequence of StructuredData objects to be combined.

        """
        # With ``new_axis`` the indices are not used by ``adjoint``; fall back
        # to the first axis, which always exists after promotion.
        axis = 0 if self.new_axis else self.axis

        start_end = np.zeros((len(args), 2), dtype=np.uint32)

        start = 0
        for idx, each in enumerate(args):
            end = start + np.atleast_2d(each.data).shape[axis]
            start_end[idx] = [start, end]
            start = end

        self._start_end = start_end