# Source code for mosaic.runtime.monitor


import os
import time
import psutil
import asyncio
import datetime
import subprocess as cmd_subprocess

import mosaic
from .runtime import Runtime, RuntimeProxy
from .utils import MonitoredResource, MonitoredObject
from .strategies import RoundRobin
from ..file_manipulation import h5
from ..utils import subprocess, at_exit
from ..utils.utils import memory_limit, cpu_count
from ..utils.logger import LoggerManager, _stdout, _stderr
from ..profile import profiler, global_profiler


__all__ = [
    'Monitor',
    'monitor_strategies',
]


# Registry of monitor strategies: ``Monitor`` looks up the value of its
# ``monitor_strategy`` option here and instantiates the resulting class.
monitor_strategies = dict()
monitor_strategies['round-robin'] = RoundRobin


def _cpu_mask(num_workers, worker_index, num_threads):
    # Work out the first core ID for this subjob
    startid = (worker_index - 1) * num_threads

    # This is the process CPU ID
    valsum = {}
    for j in range(0, num_threads):
        # Thread CPU ID
        threadid = startid + j
        # Convert to bitmask components
        pos = int(threadid / 4)
        offset = threadid - pos * 4
        val = 2 ** offset
        # This is a fat bitmask so add up the thread values in the right position
        valsum[pos] = valsum.get(pos, 0) + val

    valmask = ''
    # Generate the hex repreesntation of the fat bitmask
    for j in range(max(valsum.keys()), -1, -1):
        valmask = f'{valmask}{valsum.get(j, 0):X}'

    # Append to the list of masks in the appropriate way for this subjob
    mask = '0x' + f'{valmask}'

    return mask


class Monitor(Runtime):
    """
    The monitor takes care of keeping track of the state of the network
    and collects statistics about it.

    It also handles the allocation of tesserae to certain workers.

    """

    is_monitor = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self._memory_limit = memory_limit()

        # Allocation strategy, looked up by name in ``monitor_strategies``
        self.strategy_name = kwargs.get('monitor_strategy', 'round-robin')
        self._monitor_strategy = monitor_strategies[self.strategy_name](self)

        # Monitored nodes, tesserae and tasks, keyed by their UID
        self._monitored_nodes = dict()
        self._monitored_tessera = dict()
        self._monitored_tasks = dict()

        # UIDs touched since the last flush in ``append_description``
        self._dirty_tessera = set()
        self._dirty_tasks = set()

        now = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
        self._profile_filename = '%s.profile.h5' % now
        self._init_filename = None

        self._start_t = time.time()
        self._end_t = None

    async def init(self, **kwargs):
        """
        Asynchronous counterpart of ``__init__``.

        Parameters
        ----------
        kwargs
            Runtime options, forwarded to ``Runtime.init``,
            ``init_warehouse`` and ``init_local``/``init_cluster``. If
            ``dump_init`` is truthy, the monitor key file is written through
            ``init_file``.

        Returns
        -------

        """
        await super().init(**kwargs)

        # Maybe dump file
        if kwargs.get('dump_init', False):
            self.init_file({})

        # Start local cluster
        await self.init_warehouse(**kwargs)

        if self.mode in ['local', 'interactive']:
            await self.init_local(**kwargs)
        else:
            await self.init_cluster(**kwargs)

    async def _wait_for_node(self, node_proxy):
        """
        Wait until ``node_proxy`` is reachable through comms and has appeared
        in ``_monitored_nodes`` (populated by ``update_node``).

        Returns
        -------
        RuntimeProxy
            The same ``node_proxy``, so this can be used with
            ``asyncio.as_completed``.

        """
        await self._comms.wait_for(node_proxy.uid)

        while node_proxy.uid not in self._monitored_nodes:
            await asyncio.sleep(0.1)

        return node_proxy

    async def init_local(self, **kwargs):
        """
        Init nodes in local mode.

        A single node (index 0) is started in a local subprocess, and the
        monitor waits for it to report before starting the heartbeat.

        Parameters
        ----------
        kwargs
            Runtime options; ``num_workers`` (default 1) is used for logging
            and all options are forwarded to ``mosaic.init`` for the node.

        Returns
        -------

        """
        num_workers = kwargs.get('num_workers', 1)

        def start_node(*args, **extra_kwargs):
            # Build a fresh dict instead of mutating the captured ``kwargs``
            node_kwargs = dict(kwargs)
            node_kwargs.update(extra_kwargs)
            node_kwargs['runtime_indices'] = 0

            mosaic.init('node', *args, **node_kwargs, wait=True)

        node_proxy = RuntimeProxy(name='node', indices=0)
        node_subprocess = subprocess(start_node)(name=node_proxy.uid, daemon=False)
        node_subprocess.start_process()
        node_proxy.subprocess = node_subprocess

        self._nodes[node_proxy.uid] = node_proxy
        await self._wait_for_node(node_proxy)

        self._comms.start_heartbeat(node_proxy.uid)

        self.logger.info('Listening at <NODE:0 | WORKER:0-%d>' % num_workers)

    async def init_cluster(self, **kwargs):
        """
        Init nodes in cluster mode.

        One node is launched per entry of ``node_list``, through ``srun`` when
        ``SLURM_NODELIST`` is set in the environment and through ``ssh``
        otherwise.

        Parameters
        ----------
        kwargs
            Runtime options; reads ``node_list`` (required), ``num_workers``,
            ``num_threads`` and ``log_level``.

        Returns
        -------

        Raises
        ------
        ValueError
            If no ``node_list`` is provided.

        """
        node_list = kwargs.get('node_list', None)
        if node_list is None:
            raise ValueError('No node_list was provided to initialise mosaic in cluster mode')

        num_cpus = cpu_count()
        num_nodes = len(node_list)
        num_workers = kwargs.get('num_workers', 1)
        num_threads = kwargs.get('num_threads', None) or num_cpus // num_workers
        log_level = kwargs.get('log_level', 'info')

        runtime_address = self.address
        runtime_port = self.port
        pubsub_port = self.pubsub_port

        ssh_flags = os.environ.get('SSH_FLAGS', '')
        ssh_commands = os.environ.get('SSH_COMMANDS', None)
        ssh_commands = ssh_commands + ';' if ssh_commands else ''

        reuse_head = '--reuse-head' if self.reuse_head else '--free-head'

        in_slurm = 'SLURM_NODELIST' in os.environ
        # The mask does not depend on the node, so compute it once
        cpu_mask = _cpu_mask(1, 1, num_cpus) if in_slurm else None

        tasks = []
        for node_index, node_address in enumerate(node_list[:num_nodes]):
            node_proxy = RuntimeProxy(name='node', indices=node_index)

            remote_cmd = (f'{ssh_commands} '
                          f'mrun --node -i {node_index} '
                          f'--monitor-address {runtime_address} --monitor-port {runtime_port} '
                          f'--pubsub-port {pubsub_port} '
                          f'-n {num_nodes} -nw {num_workers} -nth {num_threads} '
                          f'--cluster --{log_level} {reuse_head}')

            # NOTE(review): node addresses and env-provided flags are
            # interpolated into a shell command (``shell=True``); they are
            # assumed to come from trusted configuration.
            if in_slurm:
                cmd = (f'srun {ssh_flags} --nodes=1 --ntasks=1 --tasks-per-node={num_cpus} '
                       f'--cpu-bind=mask_cpu:{cpu_mask} --mem-bind=local '
                       f'--oversubscribe '
                       f'--distribution=block:block '
                       f'--hint=nomultithread --no-kill '
                       f'--nodelist={node_address} '
                       f'{remote_cmd}')
            else:
                cmd = (f'ssh {ssh_flags} {node_address} '
                       f'"{remote_cmd}"')

            node_subprocess = cmd_subprocess.Popen(cmd,
                                                   shell=True,
                                                   stdout=_stdout,
                                                   stderr=_stderr)
            node_proxy.subprocess = node_subprocess

            self._nodes[node_proxy.uid] = node_proxy

            tasks.append(self._wait_for_node(node_proxy))

        for started in asyncio.as_completed(tasks):
            node_proxy = await started
            self.logger.debug('Started node %s' % node_proxy.uid)

        for node_uid in self._nodes:
            self._comms.start_heartbeat(node_uid)
            self.logger.debug('Started heartbeat with node %s' % node_uid)

        self.logger.info('Listening at <NODE:%d-%d | '
                         'WORKER:0-%d address=%s>' % (0, num_nodes, num_workers,
                                                      ', '.join(node_list)))

    def init_file(self, runtime_config):
        """
        Write the monitor key file ``mosaic-workspace/monitor.key`` in the
        current working directory, and register its removal at exit.

        The file contains an ``[ADDRESS]`` section (UID, address, port and
        pubsub port of this runtime) and an ``[ARGS]`` section with one
        ``key=value`` line per entry of ``runtime_config``; string values are
        double-quoted and connection-specific keys are skipped.

        Parameters
        ----------
        runtime_config : dict
            Options to store in the ``[ARGS]`` section.

        Returns
        -------

        """
        runtime_id = self.uid
        runtime_address = self.address
        runtime_port = self.port
        pubsub_port = self.pubsub_port

        # Store runtime ID, address and port in a tmp file for the
        # head to use
        path = os.path.join(os.getcwd(), 'mosaic-workspace')
        os.makedirs(path, exist_ok=True)

        self._init_filename = os.path.join(path, 'monitor.key')
        with open(self._init_filename, 'w') as file:
            file.write('[ADDRESS]\n')
            file.write('UID=%s\n' % runtime_id)
            file.write('ADD=%s\n' % runtime_address)
            file.write('PRT=%s\n' % runtime_port)
            file.write('PUB=%s\n' % pubsub_port)

            file.write('[ARGS]\n')
            skipped_keys = ['runtime_indices', 'address', 'port',
                            'monitor_address', 'monitor_port', 'node_list']
            for key, value in runtime_config.items():
                if key in skipped_keys:
                    continue

                if isinstance(value, str):
                    file.write('%s="%s"\n' % (key, value))
                else:
                    file.write('%s=%s\n' % (key, value))

        def _rm_dirs():
            # ``stop`` may already have deleted the file; that is not an error
            try:
                os.remove(self._init_filename)
            except FileNotFoundError:
                pass

        at_exit.add(_rm_dirs)

    def set_logger(self):
        """
        Set up logging.

        In ``interactive`` mode logs are sent remotely to the head; otherwise
        they are handled locally.

        Returns
        -------

        """
        self.logger = LoggerManager()

        if self.mode == 'interactive':
            self.logger.set_remote(runtime_id='head', format=self.mode)
        else:
            self.logger.set_local(format=self.mode)

    def set_profiler(self):
        """
        Set up profiling.

        Enables the local global profiler and schedules ``append_description``
        on the event loop with ``interval=10``.

        Returns
        -------

        """
        global_profiler.set_local()
        self._loop.interval(self.append_description, interval=10)

    def update_node(self, sender_id, update, sub_resources):
        """
        Record a state update sent by node ``sender_id`` and notify the
        monitor strategy.

        Parameters
        ----------
        sender_id : str
            UID of the reporting node.
        update
            Update payload, passed to ``MonitoredResource.update``.
        sub_resources : dict
            Forwarded as keyword arguments to ``MonitoredResource.update``.

        """
        if sender_id not in self._monitored_nodes:
            self._monitored_nodes[sender_id] = MonitoredResource(sender_id)

        node = self._monitored_nodes[sender_id]
        node.update(update, **sub_resources)

        self._monitor_strategy.update_node(node)

    def add_tessera_event(self, sender_id, msgs):
        """Record one tessera event dict, or a list of them, from ``sender_id``."""
        msgs = msgs if isinstance(msgs, list) else [msgs]

        for msg in msgs:
            self._add_tessera_event(sender_id, **msg)

    def add_task_event(self, sender_id, msgs):
        """Record one task event dict, or a list of them, from ``sender_id``."""
        msgs = msgs if isinstance(msgs, list) else [msgs]

        for msg in msgs:
            self._add_task_event(sender_id, **msg)

    def add_tessera_profile(self, sender_id, msgs):
        """Record one tessera profile dict, or a list of them, from ``sender_id``."""
        msgs = msgs if isinstance(msgs, list) else [msgs]

        for msg in msgs:
            self._add_tessera_profile(sender_id, **msg)

    def add_task_profile(self, sender_id, msgs):
        """Record one task profile dict, or a list of them, from ``sender_id``."""
        msgs = msgs if isinstance(msgs, list) else [msgs]

        for msg in msgs:
            self._add_task_profile(sender_id, **msg)

    def _get_tessera(self, runtime_id, uid):
        """Return the monitored tessera ``uid``, creating it on first sight."""
        if uid not in self._monitored_tessera:
            self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid)

        return self._monitored_tessera[uid]

    def _get_task(self, runtime_id, uid, tessera_id):
        """Return the monitored task ``uid``, creating it on first sight."""
        if uid not in self._monitored_tasks:
            self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid,
                                                         tessera_id=tessera_id)

        return self._monitored_tasks[uid]

    def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs):
        obj = self._get_tessera(runtime_id, uid)
        obj.add_event(sender_id, **kwargs)

        self._monitor_strategy.update_tessera(obj)
        self._dirty_tessera.add(uid)

    def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs):
        obj = self._get_task(runtime_id, uid, tessera_id)
        obj.add_event(sender_id, **kwargs)

        self._monitor_strategy.update_task(obj)
        self._dirty_tasks.add(uid)

    def _add_tessera_profile(self, sender_id, runtime_id, uid, profile):
        obj = self._get_tessera(runtime_id, uid)
        obj.add_profile(sender_id, profile)

        self._dirty_tessera.add(uid)

    def _add_task_profile(self, sender_id, runtime_id, uid, tessera_id, profile):
        obj = self._get_task(runtime_id, uid, tessera_id)
        obj.add_profile(sender_id, profile)

        self._dirty_tasks.add(uid)

    def append_description(self):
        """
        Flush the tesserae and tasks updated since the last call to the
        profile file, then clear the dirty sets.

        Does nothing when the profiler is not tracing or nothing changed.

        """
        if not profiler.tracing:
            return

        if not self._dirty_tessera and not self._dirty_tasks:
            return

        description = {
            'monitored_tessera': {},
            'monitored_tasks': {},
        }

        for uid in self._dirty_tessera:
            tessera = self._monitored_tessera[uid]
            description['monitored_tessera'][uid] = tessera.append()

        for uid in self._dirty_tasks:
            task = self._monitored_tasks[uid]
            description['monitored_tasks'][uid] = task.append()

        self._append_description(description)

        self._dirty_tessera = set()
        self._dirty_tasks = set()

    def _append_description(self, description):
        # The first write creates the file (recording ``start_t``); later
        # writes append to it.
        if not h5.file_exists(filename=self._profile_filename):
            description['start_t'] = self._start_t

            with h5.HDF5(filename=self._profile_filename, mode='w') as file:
                file.dump(description)

        else:
            with h5.HDF5(filename=self._profile_filename, mode='a') as file:
                file.append(description)

    @staticmethod
    def _kill_process_tree(pid):
        """
        Kill process ``pid`` and all of its descendants, ignoring processes
        that have already exited.

        """
        try:
            ps_process = psutil.Process(pid)
            children = ps_process.children(recursive=True)
        except psutil.NoSuchProcess:
            return

        for child in children:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass

        try:
            ps_process.kill()
        except psutil.NoSuchProcess:
            pass

    async def stop(self, sender_id=None):
        """
        Stop runtime.

        Removes the key file, writes final profile data if the profiler is
        tracing, stops the warehouse and every node (killing shell-launched
        node process trees), and finally stops the base runtime.

        Parameters
        ----------
        sender_id : str

        Returns
        -------

        """
        # Delete files
        if self._init_filename is not None:
            try:
                os.remove(self._init_filename)
            except OSError:
                pass

        # Get final profile updates before closing
        if profiler.tracing:
            profiler.stop()

            self._end_t = time.time()

            description = {
                'end_t': self._end_t,
                'monitored_tessera': {},
                'monitored_tasks': {},
            }

            for tessera in self._monitored_tessera.values():
                tessera.collect()
                description['monitored_tessera'][tessera.uid] = tessera.append()

            for task in self._monitored_tasks.values():
                task.collect()
                description['monitored_tasks'][task.uid] = task.append()

            self._append_description(description)

        # Close warehouse
        await self._local_warehouse.stop()
        self._local_warehouse.subprocess.join_process()

        # Close nodes
        for node in self._nodes.values():
            await node.stop()

            # Check for the method we actually call
            if hasattr(node.subprocess, 'join_process'):
                node.subprocess.join_process()

            # Nodes launched through ssh/srun: kill the whole shell process
            # tree; a process that already exited must not abort shutdown of
            # the remaining nodes.
            if isinstance(node.subprocess, cmd_subprocess.Popen):
                self._kill_process_tree(node.subprocess.pid)

        await super().stop(sender_id)

    async def select_worker(self, sender_id):
        """
        Select appropriate worker to allocate a tessera.

        Waits until at least one node has reported before delegating to the
        monitor strategy.

        Parameters
        ----------
        sender_id : str

        Returns
        -------
        str
            UID of selected worker.

        """
        while not self._monitored_nodes:
            await asyncio.sleep(0.1)

        return self._monitor_strategy.select_worker(sender_id)

    async def barrier(self, sender_id, timeout=None):
        """
        Wait until all pending tasks are done.

        A task counts as finished once its state is ``done``, ``failed`` or
        ``collected``; the states are polled every 0.5 s. If no timeout is
        provided, the barrier will wait indefinitely.

        Parameters
        ----------
        sender_id : str
        timeout : float, optional
            Maximum time to wait, in seconds.

        Returns
        -------

        """
        finished_states = ('done', 'failed', 'collected')

        pending_tasks = [task for task in self._monitored_tasks.values()
                         if task.state not in finished_states]

        self.logger.info('Pending barrier tasks %d' % len(pending_tasks))

        tic = time.time()
        while pending_tasks:
            await asyncio.sleep(0.5)

            # Rebuild the list instead of removing while iterating, which
            # would skip the element following each removed task.
            pending_tasks = [task for task in pending_tasks
                             if task.state not in finished_states]

            if timeout is not None and (time.time() - tic) > timeout:
                break