Source code for stride.utils.operators


import numpy as np

from ..core import Operator
from ..problem import StructuredData


# Names exported by ``from <this module> import *``.
__all__ = ['Add', 'Mul', 'Concatenate']


class Add(Operator):
    """
    Operator whose forward pass adds its two inputs.

    """

    def forward(self, a, b, **kwargs):
        """
        Return ``a + b``.

        Parameters
        ----------
        a
            Left operand.
        b
            Right operand.
        kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        object
            Whatever ``a + b`` evaluates to.

        """
        total = a + b
        return total

    def adjoint(self, d_sum, a, b, **kwargs):
        """
        Hand the incoming gradient, unchanged, to both inputs.

        Parameters
        ----------
        d_sum
            Gradient with respect to the output of ``forward``.
        a
            Left operand of the forward pass (not used here).
        b
            Right operand of the forward pass (not used here).
        kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        tuple
            ``(d_sum, d_sum)`` - the same object in both positions.

        """
        grad_a = d_sum
        grad_b = d_sum
        return grad_a, grad_b
class Mul(Operator):
    """
    Operator whose forward pass multiplies its two inputs.

    """

    def forward(self, a, b, **kwargs):
        """
        Return ``a * b``.

        Parameters
        ----------
        a
            Left operand.
        b
            Right operand.
        kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        object
            Whatever ``a * b`` evaluates to.

        """
        product = a * b
        return product

    def adjoint(self, d_mul, a, b, **kwargs):
        """
        Scale the incoming gradient by the opposite operand for each input.

        Parameters
        ----------
        d_mul
            Gradient with respect to the output of ``forward``.
        a
            Left operand of the forward pass.
        b
            Right operand of the forward pass.
        kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        tuple
            ``(d_mul * b, a * d_mul)`` - gradients for ``a`` and ``b``
            respectively.

        """
        # Operand order is kept as-is so that any operator overloading on
        # the inputs dispatches exactly as before.
        return d_mul * b, a * d_mul
class Concatenate(Operator):
    """
    Operator that joins the ``data`` of several inputs into a single
    ``StructuredData`` by stacking them with ``np.stack``.

    """

    def forward(self, *args, **kwargs):
        """
        Stack the ``data`` of every input along ``axis``.

        Parameters
        ----------
        args
            Inputs to join. Each one must expose ``.data``; the first one
            must also expose ``.grid``.
            NOTE(review): presumably StructuredData instances - confirm
            against callers.
        kwargs
            ``axis`` (default ``0``) is popped and forwarded to
            ``np.stack``. Any other keyword is ignored.

        Returns
        -------
        StructuredData
            Object named ``'concat'`` holding the stacked array and the
            grid of the first input.

        """
        axis = kwargs.pop('axis', 0)

        arrays = []
        for item in args:
            arrays.append(np.array(item.data))

        stacked = np.stack(arrays, axis=axis)

        return StructuredData(name='concat',
                              data=stacked,
                              grid=args[0].grid)

    def adjoint(self, d_concat, *args, **kwargs):
        """
        Split the incoming gradient back into one gradient per input.

        Parameters
        ----------
        d_concat
            Gradient with respect to the output of ``forward``; its
            ``.data`` is indexed along ``axis``.
        args
            The inputs given to ``forward``. Each one must expose
            ``.name`` and ``.alike(name=..., data=...)``.
        kwargs
            ``axis`` (default ``0``) is popped and selects the axis that
            is indexed. Any other keyword is ignored.

        Returns
        -------
        tuple
            One object per input, in input order, built with
            ``alike`` on that input and named ``'grad_<input name>'``.

        """
        axis = kwargs.pop('axis', 0)

        grads = []
        for position, source in enumerate(args):
            # Entry ``position`` along ``axis`` is the slot this input
            # occupied in the stacked output.
            piece = np.take(d_concat.data, position, axis=axis)
            grads.append(source.alike(name='grad_%s' % source.name,
                                      data=piece))

        return tuple(grads)