Source code for mosaic.core.task


import sys
import uuid
import tblib
import asyncio
import weakref
import operator
from cached_property import cached_property

from .. import types
from .base import Base, RemoteBase, ProxyBase, RuntimeDisconnectedError
from ..types import WarehouseObject
from ..utils import Future, MultiError, sizeof, remote_sizeof


__all__ = ['Task', 'TaskProxy', 'TaskOutputGenerator', 'TaskOutput', 'TaskDone']


class Task(RemoteBase):
    """
    When we call a method on a remote tessera, two things will happen:

    - a Task will be generated on the remote tessera and queued to be executed by it;
    - and a TaskProxy is generated on the calling side as a pointer to that remote task.

    We can use the task proxy to wait for the completion of the task (``await task_proxy``),
    as an argument to other tessera method calls, or to retrieve the result of the task
    (``await task_proxy.result()``).

    It is also possible to access references to the individual outputs of the task by using
    ``task_proxy.outputs``. Outputs can be accessed through their position:
    ``task_proxy.outputs[0]`` will reference the first output of the task.

    A reference to the termination of the task is also available through
    ``task_proxy.outputs.done``, which can be used to create explicit dependencies between
    tasks, thus controlling the order of execution.

    Tasks on a particular tessera are guaranteed to be executed in the order in which they
    were called, but no such guarantees exist for tasks on different tesserae.

    A completed task is kept in memory at the worker for as long as there are proxy
    references to it. If none exist, it will be made available for garbage collection.

    Objects of class Task should not be instantiated directly by the user.

    Parameters
    ----------
    uid : str
        UID of the task.
    sender_id : str
        UID of the caller.
    tessera : Tessera
        Tessera on which the task is to be executed.
    method : callable
        Method associated with the task.
    args : tuple, optional
        Arguments to pass to the method.
    kwargs : optional
        Keyword arguments to pass to the method.

    """

    type = 'task'
    is_remote = True

    def __init__(self, uid, sender_id, tessera, method, *args, **kwargs):
        super().__init__(uid, *args, **kwargs)

        self._sender_id = sender_id
        self._tessera = weakref.proxy(tessera)

        kwargs = self._fill_config(**kwargs)

        self.method = method
        self.args = args
        self.kwargs = kwargs

        self._tic = 0
        self._elapsed = None
        self._arg_size = 0

        self._args_pending = set()
        self._kwargs_pending = set()

        self._args_value = dict()
        self._kwargs_value = dict()

        self._args_state = dict()
        self._kwargs_state = dict()

        self._ready_future = Future()
        self._result = None
        self._exception = None
        self._state = None

        self.runtime.register(self)
        self.state_changed('pending')
        self.register_proxy(self._sender_id)

    @property
    def sender_id(self):
        """
        Caller UID.

        """
        return self._sender_id

    @cached_property
    def tessera_id(self):
        """
        Tessera UID.

        """
        try:
            return self._tessera.uid
        except ReferenceError:
            return None

    @cached_property
    def remote_runtime(self):
        """
        Proxies that have references to this task.

        """
        return {self.proxy(each) for each in list(self._proxies)}

    @property
    def collectable(self):
        """
        Whether the object is ready for collection.

        """
        return self._state in ['failed', 'done']

    @classmethod
    def remote_cls(cls):
        """
        Class of the remote.

        """
        return TaskProxy

    def _fill_config(self, **kwargs):
        kwargs['max_retries'] = kwargs.get('max_retries', self._tessera.max_retries)

        return super()._fill_config(**kwargs)

    def args_value(self):
        """
        Processed value of the args of the task.

        Returns
        -------
        tuple

        """
        args = [value for key, value in sorted(self._args_value.items(),
                                               key=operator.itemgetter(0))]
        return tuple(args)

    def kwargs_value(self):
        """
        Processed value of the kwargs of the task.

        Returns
        -------
        dict

        """
        return self._kwargs_value

    async def set_result(self, result):
        """
        Set task result.

        Parameters
        ----------
        result

        Returns
        -------

        """
        if not isinstance(result, (tuple, dict)):
            result = (result,)

        min_size = 1024**1

        if isinstance(result, tuple):
            async def store(_value):
                return await self.runtime.put(_value, reply=True)

            async def noop(_value):
                return _value

            tasks = []
            for value in result:
                obj_size = sizeof(value)

                if obj_size > min_size:
                    tasks.append(store(value))
                else:
                    tasks.append(noop(value))

            stored_result = await asyncio.gather(*tasks)
            stored_result = tuple(stored_result)

        elif isinstance(result, dict):
            async def store(_key, _value):
                return _key, await self.runtime.put(_value, reply=True)

            async def noop(_key, _value):
                return _key, _value

            tasks = []
            for key, value in result.items():
                obj_size = sizeof(value)

                if obj_size > min_size:
                    tasks.append(store(key, value))
                else:
                    tasks.append(noop(key, value))

            stored_result = {}
            tasks = await asyncio.gather(*tasks)
            for key, value in tasks:
                stored_result[key] = value

        else:
            assert False

        await self.cmd_async(method='set_result', result=stored_result)

        self._result = stored_result
        await self.set_done()

    def check_result(self):
        """
        Check if the result is present.

        Returns
        -------
        str
            State of the task.
        Exception or None
            Exception if task has failed, None otherwise.

        """
        if self._state == 'failed':
            return 'failed', self._exception
        else:
            return self._state, self._result

    def _cleanup(self):
        self.args = None
        self.kwargs = None

        self._args_pending = weakref.WeakSet()
        self._kwargs_pending = weakref.WeakSet()

        self._args_value = dict()
        self._kwargs_value = dict()

        self._args_state = dict()
        self._kwargs_state = dict()

    def add_event(self, event_name, **kwargs):
        kwargs['tessera_id'] = self.tessera_id
        return super().add_event(event_name, **kwargs)

    def add_profile(self, profile, **kwargs):
        kwargs['tessera_id'] = self.tessera_id
        return super().add_profile(profile, **kwargs)

    async def prepare_args(self):
        """
        Prepare the arguments of the task for execution.

        Returns
        -------
        Future

        """
        tasks = []

        async def await_size(_arg):
            self._arg_size += await _arg.size(pending=True)

        for index in range(len(self.args)):
            arg = self.args[index]

            if type(arg) in types.awaitable_types:
                self._args_state[index] = arg.state

                if not isinstance(arg, TaskDone):
                    self._args_value[index] = None

                if arg.state != 'done':
                    self._args_pending.add(arg)

                    def callback(_index, _arg):
                        def _callback(fut):
                            self.loop.run(self._set_arg_done, fut, _index, _arg)

                        return _callback

                    arg.add_done_callback(callback(index, arg))

                else:
                    tasks.append(await_size(arg))

            else:
                self._args_state[index] = 'ready'
                self._args_value[index] = arg

        for key, value in self.kwargs.items():
            if type(value) in types.awaitable_types:
                self._kwargs_state[key] = value.state

                if not isinstance(value, TaskDone):
                    self._kwargs_value[key] = None

                if value.state != 'done':
                    self._kwargs_pending.add(value)

                    def callback(_key, _arg):
                        def _callback(fut):
                            self.loop.run(self._set_kwarg_done, fut, _key, _arg)

                        return _callback

                    value.add_done_callback(callback(key, value))

                else:
                    tasks.append(await_size(value))

            else:
                self._kwargs_state[key] = 'ready'
                self._kwargs_value[key] = value

        await asyncio.gather(*tasks)
        await self._check_ready()

        return self._ready_future

    async def set_exception(self, exc):
        """
        Set task exception.

        Parameters
        ----------
        exc : Exception

        Returns
        -------

        """
        self.state_changed('failed')
        self._exception = exc

        await self.cmd_async(method='set_exception', exc=exc)

        # Once done release local copy of the arguments
        self._cleanup()

    async def set_done(self):
        """
        Set task as done.

        Returns
        -------

        """
        self.state_changed('done')

        # Once done release local copy of the arguments
        self._cleanup()

    async def _set_arg_done(self, fut, index, arg):
        if not (await self._check_exception(fut, arg)):
            return

        self._arg_size += await arg.size(pending=True)
        self._args_state[index] = 'done'

        try:
            self._args_pending.remove(arg)
        except KeyError:
            pass

        await self._check_ready()

    async def _set_kwarg_done(self, fut, index, arg):
        if not (await self._check_exception(fut, arg)):
            return

        self._arg_size += await arg.size(pending=True)
        self._kwargs_state[index] = 'done'

        try:
            self._kwargs_pending.remove(arg)
        except KeyError:
            pass

        await self._check_ready()

    async def _check_exception(self, fut, arg):
        try:
            exc = fut.exception()
        except asyncio.CancelledError:
            exc = None

        if exc is not None:
            exc = MultiError(exc)

            try:
                raise RuntimeError('Task failed due to failed argument: %s' % arg)
            except Exception as fail:
                exc.add(fail)

            try:
                raise exc
            except MultiError:
                et, ev, tb = sys.exc_info()
                tb = tblib.Traceback(tb)

                await self.set_exception((et, ev, tb))

                try:
                    self._ready_future.set_result(True)
                except asyncio.InvalidStateError:
                    pass

            return False

        else:
            return True

    async def _check_ready(self):
        if len(self._args_pending) or len(self._kwargs_pending):
            return

        # make sure there's enough memory to pull the arguments
        wait = 1
        while not self.runtime.fits_in_memory(self._arg_size):
            if self.runtime._running_tasks <= 0:
                try:
                    raise MemoryOverflowError('Not enough memory to allocate %d bytes '
                                              'for task %s' % (self._arg_size, self))
                except MemoryOverflowError:
                    et, ev, tb = sys.exc_info()
                    tb = tblib.Traceback(tb)

                    await self.set_exception((et, ev, tb))
                    self._ready_future.set_result(True)
                    return

            await asyncio.sleep(wait)
            wait *= 1.2

        self.runtime.inc_committed_mem(self._arg_size)  # reserve memory to pull args
        self.runtime.dec_pending_tasks()
        self.runtime.inc_running_tasks()

        # pull all arguments
        awaitable_args = []

        for index in range(len(self.args)):
            arg = self.args[index]

            if type(arg) in types.awaitable_types:
                self._args_state[index] = 'ready'

                async def _await_arg(_index, _arg):
                    _result = await _arg.result()
                    _attr = self._args_value if not isinstance(_arg, TaskDone) else None
                    return _attr, _index, _result

                awaitable_args.append(_await_arg(index, arg))

        for key, value in self.kwargs.items():
            if type(value) in types.awaitable_types:
                self._kwargs_state[key] = 'ready'

                async def _await_kwarg(_key, _arg):
                    _result = await _arg.result()
                    _attr = self._kwargs_value if not isinstance(_arg, TaskDone) else None
                    return _attr, _key, _result

                awaitable_args.append(_await_kwarg(key, value))

        for task in asyncio.as_completed(awaitable_args):
            attr, key, result = await task

            if attr is not None:
                attr[key] = result

        self.runtime.dec_committed_mem(self._arg_size)  # return reserved memory

        # set task ready
        if not self._ready_future.done():
            self._ready_future.set_result(True)

    def __del__(self):
        result = self._result

        if isinstance(result, tuple):
            for value in result:
                if isinstance(value, WarehouseObject):
                    self.loop.run(self.runtime.drop, value.uid)

        elif isinstance(result, dict):
            for value in result.values():
                if isinstance(value, WarehouseObject):
                    self.loop.run(self.runtime.drop, value.uid)

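# Illustrative sketch (not part of the original module): how a result set by
# Task.set_result reaches the caller. Values larger than ``min_size`` are
# stored in the warehouse via ``runtime.put`` and replaced with
# ``WarehouseObject`` references, so only small values travel inline; the
# caller's ``TaskProxy.result`` later resolves those references with
# ``runtime.get``. The names ``tessera_proxy`` and ``compute`` below are
# hypothetical.
#
#     task = tessera_proxy.compute()            # returns a TaskProxy immediately
#     big_array, flag = await task.result()     # warehoused values are fetched here
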
class TaskProxy(ProxyBase):
    """
    Proxy pointing to a remote task that has been or will be executed.

    """

    type = 'task_proxy'

    def __init__(self, proxy, method, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._uid = '%s-%s-%s-%s' % ('task',
                                     proxy._cls.cls.__name__.lower(),
                                     method,
                                     uuid.uuid4().hex)

        self._tessera_proxy = proxy

        self._fill_config(**kwargs)

        self.method = method
        self.args = args
        self.kwargs = kwargs

        self._state = None
        self._result = None
        self._done_future = Future()
        self._outputs = None

        self.state_changed('pending')

    async def init(self):
        """
        Asynchronous correlate of ``__init__``.

        Returns
        -------

        """
        self.runtime.register(self)

        task = {
            'tessera_id': self._tessera_proxy.uid,
            'method': self.method,
            'args': self.args,
            'kwargs': self.kwargs,
        }

        await self.remote_runtime.init_task(task=task, uid=self._uid, reply=True)

        if self._state == 'pending':
            self.state_changed('queued')

    def deregister_runtime(self, uid):
        if uid != self.runtime_id:
            return

        super().deregister_runtime(uid)

        self.state_changed('failed')

        try:
            self._done_future.set_exception(
                RuntimeDisconnectedError('Remote runtime %s became disconnected' % uid)
            )
        except asyncio.InvalidStateError:
            pass
        else:
            # Once done release local copy of the arguments
            self._cleanup()

    @cached_property
    def runtime_id(self):
        """
        UID of the runtime where the task lives.

        """
        try:
            return self._tessera_proxy.runtime_id
        except ReferenceError:
            return None

    @cached_property
    def tessera_id(self):
        """
        Tessera UID.

        """
        try:
            return self._tessera_proxy.uid
        except ReferenceError:
            return None

    @cached_property
    def remote_runtime(self):
        """
        Proxy to the runtime where the task lives.

        """
        return self._tessera_proxy.remote_runtime

    @property
    def init_future(self):
        """
        Future that will be completed when the remote task is initiated remotely.

        """
        return self._init_future

    @property
    def done_future(self):
        """
        Future that will be completed when the remote task is done.

        """
        return self._done_future

    @classmethod
    def remote_cls(cls):
        """
        Class of the remote.

        """
        return Task

    @property
    def done(self):
        """
        Access to TaskDone of this task.

        """
        return self.outputs.done

    @property
    def outputs(self):
        """
        Access individual outputs of the task.

        """
        if self._outputs is None or self._outputs() is None:
            outputs = TaskOutputGenerator(self)
            self._outputs = weakref.ref(outputs)
        else:
            outputs = self._outputs()

        return outputs

    @property
    def collectable(self):
        """
        Whether the object is ready for collection.

        """
        return self._state in ['failed', 'done']

    def set_done(self):
        """
        Set task as done.

        Returns
        -------

        """
        self.state_changed('done')

        try:
            self._done_future.set_result(True)
        except asyncio.InvalidStateError:
            pass

        # Once done release local copy of the arguments
        self._cleanup()

    def set_result(self, result):
        """
        Set task result.

        Parameters
        ----------
        result

        Returns
        -------

        """
        self._result = result
        self.set_done()

    def set_exception(self, exc):
        """
        Set exception during task execution.

        Parameters
        ----------
        exc : Exception

        Returns
        -------

        """
        self.state_changed('failed')

        exc = exc[1].with_traceback(exc[2].as_traceback())

        try:
            self._done_future.set_exception(exc)
        except asyncio.InvalidStateError:
            pass
        else:
            # Once done release local copy of the arguments
            self._cleanup()

    def wait(self):
        """
        Wait on the task to be completed.

        Returns
        -------

        """
        return self._done_future.result()

    def add_done_callback(self, fun):
        """
        Add done callback.

        Parameters
        ----------
        fun : callable

        Returns
        -------

        """
        self._done_future.add_done_callback(fun)

    def _cleanup(self):
        self.args = None
        self.kwargs = None

        # Release the strong reference to the tessera proxy once the task is complete
        # so that it can be garbage collected if necessary
        try:
            self._tessera_proxy = weakref.proxy(self._tessera_proxy)
        except TypeError:
            pass

    def add_event(self, event_name, **kwargs):
        kwargs['tessera_id'] = self.tessera_id
        return super().add_event(event_name, **kwargs)

    def add_profile(self, profile, **kwargs):
        kwargs['tessera_id'] = self.tessera_id
        return super().add_profile(profile, **kwargs)

    async def size(self, pending=False):
        """
        Size of the task result in bytes.

        Returns
        -------

        """
        await self

        if hasattr(self, '_retrieved'):
            if pending:
                return 0
            else:
                return sizeof(self._retrieved)

        return await remote_sizeof(self._result, pending=pending)

    async def result(self):
        """
        Gather remote result from the task.

        Returns
        -------
        Task result

        """
        await self

        if hasattr(self, '_retrieved'):
            return self._retrieved

        result = self._result

        if isinstance(result, tuple):
            async def retrieve(_value):
                return await self.runtime.get(_value)

            async def noop(_value):
                return _value

            tasks = []
            for value in result:
                if isinstance(value, WarehouseObject):
                    tasks.append(retrieve(value))
                else:
                    tasks.append(noop(value))

            retrieved = await asyncio.gather(*tasks)
            retrieved = tuple(retrieved)

            if len(retrieved) == 1:
                retrieved = retrieved[0]

        elif isinstance(result, dict):
            async def retrieve(_key, _value):
                return _key, await self.runtime.get(_value)

            async def noop(_key, _value):
                return _key, _value

            tasks = []
            for key, value in result.items():
                if isinstance(value, WarehouseObject):
                    tasks.append(retrieve(key, value))
                else:
                    tasks.append(noop(key, value))

            tasks = await asyncio.gather(*tasks)

            retrieved = {}
            for key, value in tasks:
                retrieved[key] = value

        else:
            assert False

        self._result = None
        setattr(self, '_retrieved', retrieved)

        return retrieved

    async def check_result(self):
        """
        Check the remote result.

        Returns
        -------

        """
        if self._state != 'done' and self._state != 'failed':
            state, result = await self.cmd_recv_async(method='check_result')

            if state == 'done':
                self.set_result(result)
            elif state == 'failed':
                self.set_exception(result)

    def __await__(self):
        yield from self._done_future.__await__()
        return self

    _serialisation_attrs = ProxyBase._serialisation_attrs + ['_tessera_proxy', 'method']

    @classmethod
    def _deserialisation_helper(cls, state):
        instance = super()._deserialisation_helper(state)

        if not hasattr(instance, 'args'):
            instance.args = None
            instance.kwargs = None

        if not hasattr(instance, '_result'):
            instance._result = None
            instance._done_future = Future()

        if instance.state == 'done':
            instance.set_done()

        # Synchronise the task state, in case something has happened between
        # the moment when it was pickled and the moment it was re-registered
        # on this side
        instance.loop.run(instance.check_result)

        return instance

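# Illustrative sketch (not part of the original module) of the caller-side
# lifecycle of a TaskProxy, assuming a hypothetical tessera proxy ``tess``
# with methods ``f`` and ``g``:
#
#     a = tess.f(x)               # TaskProxy, queued on the remote tessera
#     b = tess.g(a)               # passing a TaskProxy creates a dependency:
#                                 # Task.prepare_args keeps ``b`` pending until
#                                 # ``a`` is done, then pulls its result
#     await b                     # wait for completion only
#     value = await b.result()    # or gather the actual result
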
class TaskOutputGenerator:
    """
    Class that generates pointers to specific outputs of a remote task.

    """

    def __init__(self, task_proxy):
        self._task_proxy = task_proxy
        self._generated_outputs = weakref.WeakValueDictionary()

    def __repr__(self):
        runtime_id = self._task_proxy.runtime_id

        return "<%s object at %s, uid=%s, runtime=%s, state=%s>" % \
               (self.__class__.__name__, id(self),
                self._task_proxy.uid, runtime_id, self._task_proxy.state)

    def __getattribute__(self, item):
        try:
            return super().__getattribute__(item)
        except AttributeError:
            if item not in self._generated_outputs:
                if item == 'done':
                    generated_output = TaskDone(self._task_proxy)
                else:
                    generated_output = TaskOutput(item, self._task_proxy)

                self._generated_outputs[item] = generated_output

            return self._generated_outputs[item]

    def __getitem__(self, item):
        if item not in self._generated_outputs:
            if item == 'done':
                generated_output = TaskDone(self._task_proxy)
            else:
                generated_output = TaskOutput(item, self._task_proxy)

            self._generated_outputs[item] = generated_output

        return self._generated_outputs[item]

class TaskOutputBase(Base):
    """
    Base class for outputs of a task.

    """

    def __init__(self, task_proxy):
        self._task_proxy = task_proxy
        self._result = None

    @property
    def uid(self):
        return self._task_proxy.uid

    @property
    def state(self):
        return self._task_proxy.state

    @cached_property
    def runtime_id(self):
        return self._task_proxy.runtime_id

    @cached_property
    def remote_runtime(self):
        return self._task_proxy.remote_runtime

    @property
    def init_future(self):
        return self._task_proxy.init_future

    @property
    def done_future(self):
        return self._task_proxy.done_future

    def wait(self):
        return self._task_proxy.wait()

    async def result(self):
        pass

    async def size(self, pending=False):
        pass

    def add_done_callback(self, fun):
        self._task_proxy.add_done_callback(fun)

    def __await__(self):
        return (yield from self._task_proxy.__await__())

class TaskOutput(TaskOutputBase):
    """
    Pointer to a specific remote output of a task.

    """

    def __init__(self, key, task_proxy):
        super().__init__(task_proxy)
        self._key = key

    def __repr__(self):
        runtime_id = self.runtime_id

        return "<%s object [%s] at %s, uid=%s, runtime=%s, state=%s>" % \
               (self.__class__.__name__, self._key, id(self),
                self.uid, runtime_id, self.state)

    def _select_result(self, result):
        if not isinstance(result, tuple) and not isinstance(result, dict):
            result = (result,)

        return result[self._key]

    async def size(self, pending=False):
        """
        Size of the task result in bytes.

        Returns
        -------

        """
        await self

        if self._result is not None:
            if pending:
                return 0
            else:
                return sizeof(self._result)

        elif hasattr(self._task_proxy, '_retrieved'):
            if pending:
                return 0
            else:
                result = self._select_result(self._task_proxy._retrieved)
                return sizeof(result)

        result = self._select_result(self._task_proxy._result)
        return await remote_sizeof(result, pending=pending)

    async def result(self):
        """
        Gather output from the remote task.

        Returns
        -------
        Output

        """
        await self

        if self._result is None:
            result = await self._task_proxy.result()
            self._result = self._select_result(result)

        return self._result

class TaskDone(TaskOutputBase):
    """
    Reference to the termination of a remote task.

    """

    def __repr__(self):
        runtime_id = self.runtime_id

        return "<%s object at %s, uid=%s, runtime=%s, state=%s>" % \
               (self.__class__.__name__, id(self),
                self.uid, runtime_id, self.state)

    async def size(self, pending=False):
        return 0

    async def result(self):
        """
        Wait for task termination.

        Returns
        -------

        """
        await self

        self._result = True
        return self._result

class MemoryOverflowError(Exception):
    pass


types.awaitable_types += (TaskProxy, TaskOutput, TaskDone)
types.remote_types += (Task,)
types.proxy_types += (TaskProxy,)
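
# Illustrative sketch (not part of the original module): because TaskProxy,
# TaskOutput and TaskDone are registered as awaitable types above, individual
# outputs and the ``done`` marker can also be passed as arguments to other
# tessera calls. ``tess_a`` and ``tess_b`` below are hypothetical tessera
# proxies:
#
#     task = tess_a.produce()
#     first = task.outputs[0]                      # reference to the first output only
#     other = tess_b.consume(first)                # transfers just that output
#     ordered = tess_b.later(task.outputs.done)    # ordering dependency only;
#                                                  # TaskDone carries no value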