# Source code for stride.optimisation.optimisers.optimiser


import numpy as np
from abc import ABC, abstractmethod

import mosaic

from ..pipelines import ProcessGlobalGradient, ProcessModelIteration


# Names exported by a star-import of this module.
__all__ = ['LocalOptimiser']


class LocalOptimiser(ABC):
    """
    Base class for a local optimiser. It takes the value of the gradient and
    applies it to the variable.

    Parameters
    ----------
    variable : Variable
        Variable to which the optimiser refers. It must expose a truthy
        ``needs_grad`` attribute.
    process_grad : callable, optional
        Optional processing function to apply on the gradient prior to applying it.
        When not given, ``ProcessGlobalGradient(**kwargs)`` is built.
    process_model : callable, optional
        Optional processing function to apply on the model after updating it.
        When not given, ``ProcessModelIteration(**kwargs)`` is built.
    dump_grad : bool, optional
        Default for the ``dump_grad`` option of ``pre_process``, defaults to False.
    dump_prec : bool, optional
        Default for the ``dump_prec`` option of ``pre_process``, defaults to False.
    reset_block : bool, optional
        Stored on the instance, not used by this base class. Defaults to False.
    reset_iteration : bool, optional
        Stored on the instance, not used by this base class. Defaults to False.
    kwargs
        Extra parameters to be used by the class.

    Raises
    ------
    ValueError
        If ``variable`` has no ``needs_grad`` attribute or it is falsy.

    """

    def __init__(self, variable, **kwargs):
        if not getattr(variable, 'needs_grad', False):
            raise ValueError('To be optimised, a variable needs to be set with "needs_grad=True"')

        self.variable = variable

        self.dump_grad = kwargs.pop('dump_grad', False)
        self.dump_prec = kwargs.pop('dump_prec', False)

        # Build the default pipelines only when the caller did not supply one:
        # ``kwargs.pop(key, Default(**kwargs))`` would evaluate the default
        # eagerly and construct a pipeline that is immediately discarded.
        # The keyword arguments seen by each default are the same as before.
        if 'process_grad' in kwargs:
            self._process_grad = kwargs.pop('process_grad')
        else:
            self._process_grad = ProcessGlobalGradient(**kwargs)

        if 'process_model' in kwargs:
            self._process_model = kwargs.pop('process_model')
        else:
            self._process_model = ProcessModelIteration(**kwargs)

        self.reset_block = kwargs.pop('reset_block', False)
        self.reset_iteration = kwargs.pop('reset_iteration', False)

    def clear_grad(self):
        """
        Clear the internal gradient buffers of the variable.

        Returns
        -------

        """
        self.variable.clear_grad()

    @abstractmethod
    def step(self, **kwargs):
        """
        Apply the optimiser.

        Parameters
        ----------
        kwargs
            Extra parameters to be used by the method.

        Returns
        -------
        Variable
            Updated variable.

        """
        pass

    @abstractmethod
    def reset(self, **kwargs):
        """
        Reset optimiser state along with any stored buffers.

        Parameters
        ----------
        kwargs
            Extra parameters to be used by the method.

        Returns
        -------

        """
        pass

    def dump(self, *args, **kwargs):
        """
        Dump latest version of the optimiser. In this base class, this
        forwards every argument to ``variable.dump``.

        Parameters
        ----------
        args
            Positional arguments forwarded to ``variable.dump``.
        kwargs
            Extra parameters to be used by the method

        Returns
        -------

        """
        self.variable.dump(*args, **kwargs)

    def load(self, *args, **kwargs):
        """
        Load latest version of the optimiser. In this base class, this
        forwards every argument to ``variable.load``.

        Parameters
        ----------
        args
            Positional arguments forwarded to ``variable.load``.
        kwargs
            Extra parameters to be used by the method

        Returns
        -------

        """
        self.variable.load(*args, **kwargs)

    async def pre_process(self, grad=None, processed_grad=None, **kwargs):
        """
        Pre-process the variable gradient before using it to take the step.

        If ``processed_grad`` is given it is returned untouched (only its range
        is logged). Otherwise ``grad`` is passed through the gradient-processing
        callable; if ``grad`` is not given either, it is obtained from
        ``variable.process_grad``.

        Parameters
        ----------
        grad : Data, optional
            Gradient to use for the step, defaults to variable gradient.
        processed_grad : Data, optional
            Processed gradient to use for the step, defaults to processed variable gradient.
        problem : optional
            Taken from ``kwargs``. Dumps are only written when it is not None; its
            ``output_folder`` and ``name`` are used for them.
        iteration : optional
            Taken from ``kwargs``. Its ``abs_id + 1`` is used as the dump version,
            so it must be provided whenever a dump is actually written.
        dump_grad : bool, optional
            Taken from ``kwargs``. Whether to dump the raw and processed gradient,
            defaults to the value given at construction.
        dump_prec : bool, optional
            Taken from ``kwargs``. Whether to dump the gradient preconditioner,
            defaults to the value given at construction.
        kwargs
            Extra parameters to be used by the method.

        Returns
        -------
        Data
            Processed gradient.

        """
        logger = mosaic.logger()
        logger.perf('Updating variable %s,' % self.variable.name)

        # NOTE(review): the nesting below was reconstructed from a
        # whitespace-collapsed listing -- confirm against upstream.
        if processed_grad is None:
            if grad is None:
                # Proxy variables need their gradient fetched before use.
                if hasattr(self.variable, 'is_proxy') and self.variable.is_proxy:
                    await self.variable.pull(attr='grad')

                # These options are consumed here so that they are not
                # forwarded to the processing calls below.
                problem = kwargs.pop('problem', None)
                iteration = kwargs.pop('iteration', None)

                dump_grad = kwargs.pop('dump_grad', self.dump_grad)
                dump_prec = kwargs.pop('dump_prec', self.dump_prec)

                if dump_grad and problem is not None:
                    self.variable.grad.dump(path=problem.output_folder,
                                            project_name=problem.name,
                                            parameter='raw_%s' % self.variable.grad.name,
                                            version=iteration.abs_id+1)

                if dump_prec and self.variable.grad.prec is not None and problem is not None:
                    self.variable.grad.prec.dump(path=problem.output_folder,
                                                 project_name=problem.name,
                                                 version=iteration.abs_id+1)

                grad = self.variable.process_grad(**kwargs)

                if dump_grad and problem is not None:
                    grad.dump(path=problem.output_folder,
                              project_name=problem.name,
                              version=iteration.abs_id+1)

            min_dir = np.min(grad.data)
            max_dir = np.max(grad.data)

            logger.perf('\t grad before processing in range [%e, %e]' %
                        (min_dir, max_dir))

            processed_grad = await self._process_grad(grad, variable=self.variable, **kwargs)

        min_dir = np.min(processed_grad.data)
        max_dir = np.max(processed_grad.data)

        min_var = np.min(self.variable.data)
        max_var = np.max(self.variable.data)

        logger.perf('\t grad after processing in range [%e, %e]' %
                    (min_dir, max_dir))
        logger.perf('\t variable range before update [%e, %e]' %
                    (min_var, max_var))

        return processed_grad

    async def post_process(self, **kwargs):
        """
        Perform any necessary post-processing of the variable.

        The variable is passed through the model-processing callable, the result
        is copied into ``variable.extended_data`` in place, the new range is
        logged and the gradient of the variable is released.

        Parameters
        ----------
        kwargs
            Extra parameters forwarded to the model-processing callable.

        Returns
        -------

        """
        processed_variable = await self._process_model(self.variable, **kwargs)
        # Slice assignment writes into the existing buffer instead of rebinding it.
        self.variable.extended_data[:] = processed_variable.extended_data[:]

        min_var = np.min(self.variable.extended_data)
        max_var = np.max(self.variable.extended_data)

        logger = mosaic.logger()
        logger.perf('\t variable range after update [%e, %e]' %
                    (min_var, max_var))

        self.variable.release_grad()