Source code for stride.optimisation.optimisers.optimiser


import numpy as np
from abc import ABC, abstractmethod

import mosaic

from ..step_length import LineSearch
from ..pipelines import ProcessGlobalGradient, ProcessModelIteration


__all__ = ['LocalOptimiser']


class LocalOptimiser(ABC):
    """
    Base class for a local optimiser. It takes the value of the gradient
    and applies it to the variable.

    Parameters
    ----------
    variable : Variable
        Variable to which the optimiser refers.
    step_size : float or LineSearch, optional
        Step size for the update, defaults to constant 1.
    process_grad : callable, optional
        Optional processing function to apply on the gradient prior to applying it.
    process_model : callable, optional
        Optional processing function to apply on the model after updating it.
    kwargs
        Extra parameters to be used by the class.

    """

    def __init__(self, variable, **kwargs):
        if not hasattr(variable, 'needs_grad') or not variable.needs_grad:
            raise ValueError('To be optimised, a variable needs to be set with "needs_grad=True"')

        self.variable = variable
        self.step_size = kwargs.pop('step_size', 1.)
        self.test_step_size = kwargs.pop('test_step_size', 1.)
        self.dump_grad = kwargs.pop('dump_grad', False)
        self.dump_prec = kwargs.pop('dump_prec', False)

        self._process_grad = kwargs.pop('process_grad', ProcessGlobalGradient(**kwargs))
        self._process_model = kwargs.pop('process_model', ProcessModelIteration(**kwargs))

        self.reset_block = kwargs.pop('reset_block', False)
        self.reset_iteration = kwargs.pop('reset_iteration', False)
    def clear_grad(self):
        """
        Clear the internal gradient buffers of the variable.

        Returns
        -------

        """
        self.variable.clear_grad()
    async def pre_process(self, grad=None, processed_grad=None, **kwargs):
        """
        Pre-process the variable gradient before using it to take the step.

        Parameters
        ----------
        grad : Data, optional
            Gradient to use for the step, defaults to variable gradient.
        processed_grad : Data, optional
            Processed gradient to use for the step, defaults to processed variable gradient.
        kwargs
            Extra parameters to be used by the method.

        Returns
        -------
        Data
            Processed gradient, scaled by the test step size.

        """
        logger = mosaic.logger()
        logger.perf('Updating variable %s,' % self.variable.name)

        problem = kwargs.pop('problem', None)
        iteration = kwargs.pop('iteration', None)

        if processed_grad is None:
            if grad is None:
                if hasattr(self.variable, 'is_proxy') and self.variable.is_proxy:
                    await self.variable.pull(attr='grad')

                dump_grad = kwargs.pop('dump_grad', self.dump_grad)
                dump_prec = kwargs.pop('dump_prec', self.dump_prec)

                if dump_grad and problem is not None:
                    self.variable.grad.dump(path=problem.output_folder,
                                            project_name=problem.name,
                                            parameter='raw_%s' % self.variable.grad.name,
                                            version=iteration.abs_id+1)

                if dump_prec and self.variable.grad.prec is not None and problem is not None:
                    self.variable.grad.prec.dump(path=problem.output_folder,
                                                 project_name=problem.name,
                                                 parameter='raw_%s' % self.variable.grad.prec.name,
                                                 version=iteration.abs_id+1)

                grad = self.variable.process_grad(**kwargs)

                if dump_grad and problem is not None:
                    grad.dump(path=problem.output_folder,
                              project_name=problem.name,
                              version=iteration.abs_id+1)

                if dump_prec and self.variable.grad.prec is not None and problem is not None:
                    self.variable.grad.prec.dump(path=problem.output_folder,
                                                 project_name=problem.name,
                                                 version=iteration.abs_id+1)

            min_dir = np.min(grad.data)
            max_dir = np.max(grad.data)

            logger.perf('\t grad before processing in range [%e, %e]' % (min_dir, max_dir))

            processed_grad = await self._process_grad(grad, variable=self.variable, **kwargs)

            dump_processed_grad = kwargs.pop('dump_processed_grad', self.dump_grad)
            if dump_processed_grad and problem is not None:
                processed_grad.dump(path=problem.output_folder,
                                    project_name=problem.name,
                                    parameter='processed_%s' % self.variable.grad.name,
                                    version=iteration.abs_id + 1)

        test_step_size = kwargs.pop('test_step_size', self.test_step_size)
        processed_grad.data[:] *= test_step_size

        min_dir = np.min(processed_grad.data)
        max_dir = np.max(processed_grad.data)

        min_var = np.min(self.variable.data)
        max_var = np.max(self.variable.data)

        logger.perf('\t grad after processing in range [%e, %e]' % (min_dir, max_dir))
        logger.perf('\t variable range before update [%e, %e]' % (min_var, max_var))

        return processed_grad
    async def step(self, step_size=None, grad=None, processed_grad=None, **kwargs):
        """
        Apply the optimiser.

        Parameters
        ----------
        step_size : float, optional
            Step size to use for this application, defaults to instance step.
        grad : Data, optional
            Gradient to use for the step, defaults to variable gradient.
        processed_grad : Data, optional
            Processed gradient to use for the step, defaults to processed variable gradient.
        kwargs
            Extra parameters to be used by the method.

        Returns
        -------
        Variable
            Updated variable.

        """
        logger = mosaic.logger()

        # make a copy of the variable before the update
        variable_before = self.variable.copy()

        # pre-process the gradient to obtain the update direction
        direction = await self.pre_process(grad=grad, processed_grad=processed_grad, **kwargs)

        # select the step size
        step_size = self.step_size if step_size is None else step_size
        step_loop = kwargs.pop('step_loop', None)
        if isinstance(step_size, LineSearch):
            step_size.init_search(
                variable=self.variable,
                direction=direction,
                **kwargs
            )

        # search for the optimal step size
        while True:
            # find the next candidate step
            if isinstance(step_size, LineSearch):
                if step_loop is None:
                    next_step = 1.
                    done_search = True
                else:
                    next_step, done_search = step_size.next_step(
                        variable=self.variable,
                        direction=direction,
                        **kwargs
                    )
            else:
                next_step = step_size
                done_search = True

            if done_search:
                # cap the step if needed
                max_step = kwargs.pop('max_step', None)
                max_step = np.inf if not isinstance(max_step, (int, float)) else max_step
                unclipped_step = next_step
                if next_step > -0.2:
                    # if only slightly negative, still assume the gradient points in the right direction
                    next_step = max(0.1, min(next_step, max_step))
                elif max_step < np.inf and next_step < -max_step * 0.75:
                    # in general, prevent negative steps
                    next_step = -max_step * 0.75
                elif next_step < -0.2:
                    next_step = next_step * 0.25

                logger.perf('\t taking final update step of %e [unclipped step of %e]' % (next_step, unclipped_step))
            else:
                logger.perf('\t taking test step of %e in line search' % next_step)

            # restore the variable
            self.variable.data[:] = variable_before.data.copy()

            # update the variable
            if self.variable.transform is not None:
                variable = self.variable.transform(self.variable)
            else:
                variable = self.variable

            upd_variable = self.update_variable(next_step, variable, direction)
            if self.variable.transform is not None:
                upd_variable = self.variable.transform(upd_variable)
            self.variable.data[:] = upd_variable.data.copy()

            # post-process the variable after the update
            await self.post_process(**kwargs)

            # if the search is done, stop
            if done_search:
                break

            # otherwise, evaluate the loss change without recomputing gradients
            self.variable.needs_grad = False
            if hasattr(self.variable, 'push'):
                await self.variable.push(attr='needs_grad')

            try:
                await step_loop()
            finally:
                self.variable.needs_grad = True
                if hasattr(self.variable, 'push'):
                    await self.variable.push(attr='needs_grad')

        return self.variable
    async def post_process(self, **kwargs):
        """
        Perform any necessary post-processing of the variable.

        Parameters
        ----------

        Returns
        -------

        """
        processed_variable = await self._process_model(self.variable, **kwargs)
        self.variable.extended_data[:] = processed_variable.extended_data[:]

        min_var = np.min(self.variable.extended_data)
        max_var = np.max(self.variable.extended_data)

        logger = mosaic.logger()
        logger.perf('\t variable range after update [%e, %e]' % (min_var, max_var))

        self.variable.release_grad()
    @abstractmethod
    def update_variable(self, step_size, variable, direction):
        """
        Update the variable in the given direction using the given step size.

        Parameters
        ----------
        step_size : float
            Step size to use for updating the variable.
        variable : Data
            Variable to update.
        direction : Data
            Direction in which to update the variable.

        Returns
        -------
        Variable
            Updated variable.

        """
        pass
    def reset(self, **kwargs):
        """
        Reset optimiser state along with any stored buffers.

        Parameters
        ----------
        kwargs
            Extra parameters to be used by the method.

        Returns
        -------

        """
        pass
    def dump(self, *args, **kwargs):
        """
        Dump the latest version of the optimiser.

        Parameters
        ----------
        kwargs
            Extra parameters to be used by the method.

        Returns
        -------

        """
        self.variable.dump(*args, **kwargs)
    def load(self, *args, **kwargs):
        """
        Load the latest version of the optimiser.

        Parameters
        ----------
        kwargs
            Extra parameters to be used by the method.

        Returns
        -------

        """
        self.variable.load(*args, **kwargs)
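
# ---------------------------------------------------------------------------
# Usage sketch (not part of the module above): a minimal concrete subclass of
# LocalOptimiser and one way it might be driven. The names `SimpleDescent`,
# `run_iteration`, `model`, `problem` and `iteration` are hypothetical and
# used only for illustration; they assume `model` was created with
# needs_grad=True and that its gradient has already been computed. The only
# method a subclass is required to implement is `update_variable`.
# ---------------------------------------------------------------------------

class SimpleDescent(LocalOptimiser):
    """Plain descent step: variable <- variable - step_size * direction."""

    def update_variable(self, step_size, variable, direction):
        # move the variable against the (processed) gradient direction
        variable.data[:] -= step_size * direction.data
        return variable


async def run_iteration(model, problem, iteration):
    # fixed step size of 0.1; a LineSearch instance could be passed instead
    optimiser = SimpleDescent(model, step_size=0.1)

    # apply one update and clear the gradient buffers afterwards
    await optimiser.step(problem=problem, iteration=iteration)
    optimiser.clear_grad()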