Source code for seqlib.nodes.sequence

"""Implements the Sequence node for the sequencer"""
import logging
import pdb
from datetime import datetime
import asyncio
import contextvars as cv
import attr

import networkx as nx
from zope.interface import implementer

from .state import T_STATE
from .interface import (
    INode,
    _BaseNode,
)

from .action import (
    uniqueId,
    StartNode,
    EndNode,
    make_node,
)

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


@implementer(INode)
@attr.s
class Sequence(_BaseNode):
    """Basic sequencer node.

    Use :meth:`create` to build properly initialized :class:`Sequence`
    objects.

    Attributes:
        state: node state

    ============ ===============================
    Context Variables
    --------------------------------------------
    Name         Desc
    ============ ===============================
    current_tpl  The parent of the current node
    root         Top level DAG's root
    ============ ===============================

    Examples:

        >>> s = Sequence.create(a, b, name="my sequence")
        >>> # execute Sequence
        >>> await s.start()

        Get the running Sequence from inside function ``a()``:

        >>> def a():
        ...     current_tpl = Sequence.current_tpl.get()
        ...     # now current_tpl is the node s
        ...     assert current_tpl == s
    """

    _seq = attr.ib(init=False, default=attr.Factory(list), repr=False)
    graph = attr.ib(init=False, default=attr.Factory(nx.DiGraph), repr=False)
    _start_node = attr.ib(init=False, default=None, repr=False)
    _end_node = attr.ib(init=False, default=None, repr=False)
    _context = attr.ib(init=False, default=attr.Factory(dict), repr=False)
    debug = attr.ib(init=False, default=False, repr=False)

    # Loop's contextvar
    current_tpl = cv.ContextVar("current_tpl", default=None)
    root = cv.ContextVar("root", default=None)  # set by OB runner

    # Strong references to in-flight ``publish_state`` tasks. The event loop
    # only keeps weak references to tasks, so a task created and dropped
    # immediately can be garbage collected before it completes.
    _pending_publish_tasks = set()

    def __attrs_post_init__(self):
        """Assign a default name and id, and reset the result dictionary."""
        if self.name is None:
            self.name = "Sequence"
        if self.id is None:
            self.id = "%s_%s" % (self.name, uniqueId(id(self)))
        self._result = {}

    @property
    def context(self):
        """Shared context dictionary, taken from the root node when set.

        If no root Sequence is registered in the current context, this node
        registers itself as ``Sequence.root`` (side effect) and returns its
        own context dictionary.
        """
        root = Sequence.root.get()
        if root is not None:
            self._context = root._context
        else:
            Sequence.root.set(self)
        return self._context

    @context.setter
    def context(self, ctx):
        self._context = ctx

    @property
    def seq(self):
        """Retrieves the sequence list.

        :meta private:
        """
        return self._seq

    @seq.setter
    def seq(self, s):
        """Set the list used to create the Sequence's graph.

        Args:
            s (Iterable): Sequence list. Members of ``s`` represent sequence
                nodes. Each member goes through ``make_node``; per the
                original design, coroutines are converted to
                :class:`Action` nodes.

        Raises:
            TypeError: if ``s`` is not iterable.
        """
        from collections.abc import Iterable

        # Explicit check instead of ``assert``: asserts vanish under ``-O``.
        if not isinstance(s, Iterable):
            raise TypeError(
                "Sequence members must be given as an iterable, got %r"
                % type(s).__name__
            )
        self._seq = list(map(make_node, s))

    @property
    def G(self):
        """Return the graph object (alias of :attr:`graph`)."""
        return self.graph

    async def start_step(self):
        """Standard sequence start step.

        Sets the sequence's state to RUNNING. When ``debug`` is true, drops
        into ``pdb``.
        """
        logger.info(
            "This is the <start> of a sequence: %s", Sequence.current_tpl.get()
        )
        self.state = T_STATE.RUNNING
        if self.debug:
            pdb.set_trace()

    async def end_step(self):
        """Standard sequence end step.

        Sets the sequence's state to FINISHED. Flags the sequence as in
        error if any node of its graph is in error.
        """
        logger.info("this is the <end> of the sequence: %s", self.name)
        self.state = T_STATE.FINISHED
        G = self.graph
        if any(G.nodes[key]["node"].in_error for key in nx.topological_sort(G)):
            self.in_error = True

    @property
    def start_node(self):
        """Return the start node, creating it on first access."""
        if self._start_node is None:
            self._start_node = StartNode(
                self.start_step, name="start", id="start_" + self.id
            )
        return self._start_node

    @property
    def end_node(self):
        """Return the end node, creating it on first access."""
        if self._end_node is None:
            self._end_node = EndNode(
                self.end_step, name="end", id="end_" + self.id
            )
        return self._end_node

    def append(self, s):
        """Append a node to the Sequence.

        ``s`` goes through ``make_node``, exactly as members passed to the
        :attr:`seq` setter do. This means coroutines can be appended as
        well as nodes.
        """
        self.seq.append(make_node(s))

    def make_sequence(self):
        """Build this sequence's execution graph.

        Nodes are chained start -> el1 -> el2 -> ... unless an element
        declares ``deps``. Such an element is attached to its dependencies
        instead, and those must already be in the graph. Every dangling node
        is finally connected to the end node.
        """
        logger.debug("create sequence: %s", self.name)
        G = self.graph
        start_id = self.start_node.id
        end_id = self.end_node.id
        G.add_node(start_id, node=self.start_node)
        G.add_node(end_id, node=self.end_node)

        parent = self.start_node
        for el in self.seq:
            logger.debug("add node: %s", el)
            el.make_sequence()
            G.add_node(el.id, node=el)
            if el.deps:
                for d in el.deps:
                    assert d.id in G, (
                        "dependency %s must be added before %s" % (d.id, el.id)
                    )
                    G.add_edge(d.id, el.id)
            else:
                G.add_edge(parent.id, el.id)
            parent = el

        # Exclude the end node explicitly. Otherwise a second call (graph
        # already wired) would see it as a leaf and add an end->end
        # self-loop, making the graph cyclic.
        leaf_nodes = [
            node
            for node in G.nodes()
            if node != end_id
            and G.in_degree(node) != 0
            and G.out_degree(node) == 0
        ]
        logger.debug("leaf nodes: %s", leaf_nodes)
        for leaf in leaf_nodes:
            G.add_edge(leaf, end_id)

        # Empty sequence: nothing links start to end, so wire them directly.
        # This makes the end step run after the start step.
        if G.out_degree(start_id) == 0:
            G.add_edge(start_id, end_id)

    def create_node_tasks(self, resume=False):
        """Create the task object associated to each node of the graph.

        Args:
            resume: when true, only nodes not in FINISHED state get a new
                task. ``resume`` is also forwarded to each node's
                ``make_task``.
        """
        G = self.graph
        logger.info(
            "Instance node tasks -- %s (resume = %s)", self.name, resume
        )
        logger.debug("G: %s", G.nodes)
        # Topologically sort the graph so the insertion order of nodes does
        # not matter.
        g_nodes = [G.nodes[key] for key in nx.topological_sort(G)]
        if resume:
            g_nodes = [
                node
                for node in g_nodes
                if node["node"].state != T_STATE.FINISHED
            ]

        for node in g_nodes:
            key = node["node"].id
            in_edges = list(G.predecessors(key))
            logger.debug("%s -> IN edges: %s", self._name(node), in_edges)
            input_list = [G.nodes[pred] for pred in in_edges]
            node["task"] = node["node"].make_task(node, input_list, resume)
            node["node"].state = T_STATE.SCHEDULED

    def reschedule_node(self, node_id):
        """Reschedule a node for execution.

        Sets the node to SCHEDULED and resets its task with ``resume=True``.
        """
        g_node = self.graph.nodes[node_id]  # graph's per-node attribute dict
        g_node["node"].state = T_STATE.SCHEDULED
        g_node["task"].reset(resume=True)

    def start(self, make_sequence=True, resume=False):
        """Entry point for Sequence execution.

        Must be called with a running event loop (uses
        :func:`asyncio.create_task`).

        Args:
            make_sequence: build the execution graph before running. It is
                forced to ``False`` when ``resume`` is true, because the
                graph already exists.
            resume: only create tasks for nodes that are not FINISHED yet.

        Returns:
            The :class:`asyncio.Task` running :meth:`execute`. Awaiting it
            yields the sequence result. An exception raised while running
            is re-raised from it, and the sequence is marked CANCELLED.
        """
        if resume:
            make_sequence = False

        context = self.context
        context["started"] = datetime.now()
        context["name"] = self.name

        if make_sequence:
            self.make_sequence()
        return asyncio.create_task(self.execute(resume))

    async def run(self):
        """Run the node by awaiting the end node's task.

        Returns:
            The end node task's result, or ``None`` when the node is skipped.

        Raises:
            Exception: any exception from the task is re-raised after the
                node is flagged in error and set to CANCELLED.
        """
        result = None
        self.t_start = datetime.now()
        if self.skip:
            self.state = T_STATE.FINISHED
            self.t_end = datetime.now()
        else:
            try:
                task = self.main_task()
                result = await task.run()
                self.t_end = datetime.now()
            except Exception as e:
                self.in_error = True
                self.state = T_STATE.CANCELLED
                self.exception = e
                raise
        return result

    async def execute(self, resume=False):
        """Execute the node and return its result.

        Registers this node as ``current_tpl`` and creates the node tasks,
        then runs them.
        """
        logger.debug("execute: %s", self)
        Sequence.current_tpl.set(self)
        self.create_node_tasks(resume=resume)
        return await self.run()

    async def resume(self):
        """Resume node execution and return the sequence result."""
        return await self.start(resume=True)

    def main_task(self):
        """Return the objective task of the sequence: the end node's task."""
        return self.graph.nodes[self.end_node.id]["task"]

    def _name(self, node):
        """Return the name of the node stored in a graph node dict."""
        return node["node"].name

    async def __call__(self, resume=False):
        """Execute the sequence (see :meth:`execute`) and return its result."""
        return await self.execute(resume)

    def abort(self):
        """Abort the sequence.

        Walks the full graph, via ``Visitor``, and aborts the task attached
        to each node. Nodes without a task entry are skipped.
        """
        from .utils import Visitor

        for s, node_id in Visitor(self):
            try:
                task = s.get_task(node_id)
            except KeyError:
                # Node has no task yet (e.g. tasks never created).
                continue
            logger.info("abort: %s", task)
            if task:
                task.abort()

    def nodes(self):
        """Return nodes from Graph."""
        return self.G.nodes

    def get_node(self, node_id):
        """Get node by id."""
        return self.G.nodes[node_id]["node"]

    def get_task(self, node_id):
        """Get task by node_id."""
        return self.G.nodes[node_id]["task"]

    @property
    def state(self):
        """Node state, as stored by the base class."""
        return super().state

    async def publish_state(self):
        """Notify one waiter on ``self._cond`` that the state changed.

        NOTE(review): ``_cond`` is assumed to be an asyncio.Condition
        provided by the base class -- confirm.
        """
        async with self._cond:
            self._cond.notify(1)

    @state.setter
    def state(self, value):
        """Set the node state and schedule a state notification.

        Requires a running event loop (uses :func:`asyncio.create_task`).
        """
        super(Sequence, self.__class__).state.fset(self, value)
        task = asyncio.create_task(self.publish_state())
        # Keep a strong reference until the task completes.
        Sequence._pending_publish_tasks.add(task)
        task.add_done_callback(Sequence._pending_publish_tasks.discard)

    @staticmethod
    def get_context():
        """Return the context dictionary of the currently running Sequence.

        Raises:
            RuntimeError: if no Sequence is running in the current context.
        """
        current = Sequence.current_tpl.get()
        if current is None:
            raise RuntimeError(
                "Sequence.get_context() called outside a running Sequence"
            )
        return current.context

    @staticmethod
    def create(*args, **kw):
        """Sequence node constructor.

        Args:
            *args: Variable length list of nodes or coroutines that compose
                the sequence.

        Keyword Args:
            id: Node id
            name: node name

        Returns:
            A new :class:`Sequence` whose :attr:`seq` is set from ``args``.
        """
        tpl = Sequence(**kw)
        tpl.seq = list(args)
        return tpl