Source code for stride.optimisation.optimisers.gradient_descent



from .optimiser import LocalOptimiser


# Public API of this module: only the gradient-descent optimiser is exported.
__all__ = ["GradientDescent"]


class GradientDescent(LocalOptimiser):
    """
    Implementation of a gradient descent update.

    Each call to ``update_variable`` moves the variable's data against the
    given direction, i.e. ``data <- data - step_size * direction.data``.

    Parameters
    ----------
    variable : Variable
        Variable to which the optimiser refers.
    step_size : float, optional
        Step size for the update, defaults to 1. This class does not consume
        it directly; when supplied as a keyword it is forwarded to the parent
        class through ``kwargs``.
    kwargs
        Extra parameters to be used by the class.

    """

    def __init__(self, variable, **kwargs):
        # Gradient descent keeps no state of its own: all configuration is
        # handed to LocalOptimiser unchanged.
        super().__init__(variable, **kwargs)

    async def pre_process(self, grad=None, processed_grad=None, **kwargs):
        """
        Pre-process the gradient before the variable is updated.

        Gradient descent adds no processing of its own, so this coroutine
        defers entirely to the parent implementation.

        Parameters
        ----------
        grad : optional
            Raw gradient, forwarded unchanged to the parent.
        processed_grad : optional
            Already-processed gradient, forwarded unchanged to the parent.
        kwargs
            Extra keyword arguments, forwarded unchanged to the parent.

        Returns
        -------
        object
            Whatever ``LocalOptimiser.pre_process`` returns.

        """
        return await super().pre_process(
            grad=grad,
            processed_grad=processed_grad,
            **kwargs,
        )

    async def update_variable(self, step_size, direction):
        """
        Apply one gradient-descent step to the optimised variable in place.

        Parameters
        ----------
        step_size : float
            Scale applied to ``direction.data`` before subtracting it.
        direction
            Object exposing a ``data`` attribute with the update direction.
            NOTE(review): presumably an array compatible in shape with
            ``self.variable.data`` (e.g. NumPy) -- confirm against callers.

        Returns
        -------
        Variable
            ``self.variable``, after its data has been modified in place.

        """
        # Slice assignment writes into the existing buffer rather than
        # rebinding ``variable.data``. This assumes ``data`` is a mutable array
        # supporting ``[:]`` item assignment (e.g. a NumPy array).
        data = self.variable.data
        data[:] -= step_size * direction.data
        return self.variable