Source code for seq.lib.nodes.sequence

"""Implements the Sequence node for the sequencer"""
# pylint: disable= invalid-name, too-many-instance-attributes, too-many-public-methods
import logging
from datetime import datetime
import asyncio
import contextvars as cv
import attr

import networkx as nx
from zope.interface import implementer
from .. import ob
from ..counter import Counter
from .state import T_STATE
from .interface import (
    INode,
    _BaseNode,
)

from .action import (
    StartNode,
    EndNode,
    make_node,
)

from .utils import uniqueId

# Module logger for internal/debug messages (e.g. parameter updates).
logger = logging.getLogger(__name__)
# Logger for user-facing progress messages (sequence start/finish/skip).
user_logger = logging.getLogger("seq.user")


@implementer(INode)
@attr.s
class Sequence(_BaseNode):
    """Basic sequencer node.

    Use :meth:`create` to properly build :class:`Sequence` objects.

    Attributes:
        state: node state

    ============ ===============================
    Context Variables
    ---------------------------------------------
    Name         Desc
    ============ ===============================
    current_seq  The parent of the current node
    root         Top level DAG's root
    ============ ===============================

    Examples:

        >>> s = Sequence.create(a, b, name="my sequence")
        # execute Sequence
        >>> await s.start()
        # Get running Sequence from inside function a():
        >>> def a():
        ...     current_seq = Sequence.current_seq.get()
        ...     # now current_seq is the node s
        ...     assert current_seq == s
    """

    _seq = attr.ib(init=False, default=attr.Factory(list), repr=False)
    graph = attr.ib(init=False, default=attr.Factory(nx.DiGraph), repr=False)
    _start_node = attr.ib(init=False, default=None, repr=False)
    _end_node = attr.ib(init=False, default=None, repr=False)
    _context = attr.ib(init=False, default=attr.Factory(dict), repr=False)
    debug = attr.ib(init=False, default=False, repr=False)
    _parameters = attr.ib(init=False, repr=False, default=attr.Factory(dict))

    # Context variables (plain class attributes, not attrs fields).
    current_seq = cv.ContextVar("current_seq", default=None)
    root = cv.ContextVar("root", default=None)  # set by OB runner

    def __attrs_post_init__(self):
        """Assign the node's default name and id and allow it to run."""
        if self.name is None:
            self.name = "Sequence"
        if self.id is None:
            self.id = "%s_%s" % (self.name, uniqueId(id(self)))
        self._result = {}
        self.running_checkpoint.set()  # running is allowed

    def __iter__(self):
        """Iterate over the node ids of the execution graph."""
        return iter(self.graph)

    @property
    def context(self):
        """Get the context dictionary, preferably from the root node.

        If a root node is registered in the current context, its context
        dictionary is shared with this node. Otherwise this node registers
        itself as the root.
        """
        root = Sequence.root.get()
        if root is not None:
            self._context = root._context
        else:
            Sequence.root.set(self)
        return self._context

    @context.setter
    def context(self, ctx):
        self._context = ctx

    @property
    def seq(self):
        """Retrieves the sequence list.

        :meta private:
        """
        return self._seq

    @seq.setter
    def seq(self, s):
        """List used to create the Sequence's graph.

        Parameters
        -----------
        s (Iterable): Sequence list. Members of ``s`` represent sequence
            nodes. Each member is passed through ``make_node``, so
            coroutines are automagically converted to :class:`Action` nodes.
        """
        from collections.abc import Iterable

        assert isinstance(s, Iterable)
        self._seq = list(map(make_node, s))

    @property
    def G(self):
        """Returns the graph object."""
        return self.graph

    async def start_step(self):
        """Standard sequence's start step.

        Sets the sequence's state to RUNNING. When ``debug`` is set, drops
        into the debugger.
        """
        user_logger.info("Sequence %s starting", self.name)
        self.state = T_STATE.RUNNING
        if self.debug:
            breakpoint()

    async def end_step(self):
        """Standard sequence's end step.

        Evaluates the sequence's final state: the sequence is flagged
        ``in_error`` if any node of its graph is in error. Then sets the
        state to FINISHED, which publishes it.
        """
        user_logger.info("Sequence %s finish", self.name)
        G = self.graph
        # grab node's state and check for ERROR
        self.in_error = any(
            G.nodes[key]["node"].in_error for key in nx.topological_sort(G)
        )
        self.state = T_STATE.FINISHED  # This publishes node's state.

    @property
    def start_node(self):
        """Returns the start node. If it does not exist, it creates it."""
        if self._start_node is None:
            self._start_node = StartNode(
                self.start_step, name="begin", id="begin_" + self.id
            )
        return self._start_node

    @property
    def end_node(self):
        """Returns the end node. If it does not exist, it creates it."""
        if self._end_node is None:
            self._end_node = EndNode(
                self.end_step, name="end", id="end_" + self.id
            )
        return self._end_node

    def append(self, s):
        """Appends a node to the Sequence.

        ``s`` is passed through ``make_node``, exactly like the members
        given to the :attr:`seq` setter. Coroutines are therefore wrapped
        into nodes before :meth:`make_sequence` uses them.
        """
        self.seq.append(make_node(s))

    def make_sequence(self, parent_tpl=None):
        """Builds this sequence execution graph.

        Chains ``begin -> el1 -> ... -> elN -> end``. Each element builds
        its own sub-graph first and is reset to NOT_STARTED. For an empty
        sequence the start node is linked directly to the end node, so the
        graph is never disconnected.

        Parameters:
            parent_tpl: passed down to every element's ``make_sequence``.
        """
        ctrl = ob.OB.controller.get()
        if ctrl:
            self.runtime_flags = ctrl.runtime_flags.get(self.serial_number, 0)

        self.graph = G = nx.DiGraph()
        G.add_node(self.start_node.id, node=self.start_node)

        parent = self.start_node
        for el in self.seq:
            el.make_sequence(parent_tpl)  # pass down link to Tpl if any
            el.state = T_STATE.NOT_STARTED
            G.add_node(el.id, node=el)
            G.add_edge(parent.id, el.id)
            parent = el

        # Leaves: reachable nodes without successors (computed before the
        # end node is added).
        leaf_nodes = [
            node
            for node in G.nodes()
            if G.in_degree(node) != 0 and G.out_degree(node) == 0
        ]
        if not leaf_nodes:
            # Empty sequence: the lone start node has no in-edges, so it is
            # not detected as a leaf. Connect it to the end node anyway.
            leaf_nodes = [self.start_node.id]

        G.add_node(self.end_node.id, node=self.end_node)
        self.start_node.state = self.end_node.state = T_STATE.NOT_STARTED
        for leaf in leaf_nodes:
            G.add_edge(leaf, self.end_node.id)

    def create_node_tasks(self, resume=False):
        """Creates the Task object associated to each node of the graph.

        Nodes are processed in topological order. Each task receives the
        graph dicts of its predecessor nodes as inputs. When ``resume`` is
        true, nodes already FINISHED are left untouched.
        """
        G = self.graph
        g_nodes = [G.nodes[key] for key in nx.topological_sort(G)]
        if resume:
            g_nodes = [
                g_node
                for g_node in g_nodes
                if g_node["node"].state != T_STATE.FINISHED
            ]

        for g_node in g_nodes:
            node_id = g_node["node"].id
            input_list = [G.nodes[pred] for pred, _ in G.in_edges(node_id)]
            g_node["task"] = g_node["node"].make_task(g_node, input_list, resume)
            g_node["node"].state = T_STATE.SCHEDULED

    def reschedule_node(self, node_id):
        """Reschedule a node for execution.

        Sets the node's state back to SCHEDULED and resets its task with
        ``resume=True``.
        """
        g_node = self.graph.nodes[node_id]  # get Graph's node dict
        node = g_node["node"]
        task = g_node["task"]
        node.state = T_STATE.SCHEDULED
        task.reset(resume=True)

    def start(self, make_sequence=True, resume=False):
        """This is the entry point for Sequence execution.

        Records start time and name in the context. Unless resuming, it
        (re)builds the graph. It then schedules :meth:`execute` as an
        asyncio task.

        Returns:
            The :class:`asyncio.Task` running :meth:`execute`. Any exception
            raised during execution propagates when the task is awaited.
        """
        if resume:
            make_sequence = False
        context = self.context
        context["started"] = datetime.now()
        context["name"] = self.name
        if make_sequence:
            self.make_sequence()
        return asyncio.create_task(self.execute(resume))

    async def run(self):
        """Runs the node -- this executes the end node's task.

        Honours the PAUSE and SKIP runtime flags and waits on
        ``running_checkpoint`` before executing.

        Returns:
            The end task's result, or ``None`` if the node was skipped.

        Raises:
            asyncio.CancelledError: re-raised after setting state ABORTED.
            Exception: re-raised after setting ``in_error``, state CANCELLED
                and ``exception``.
        """
        from ..ob import RTFLAG

        result = None
        if self.runtime_flags & RTFLAG.PAUSE:
            self.state = T_STATE.PAUSED
            self.pause()

        # check permission to execute
        await self.running_checkpoint.wait()
        # now we are running ...
        if self.running_checkpoint.is_set():
            self.t_start = datetime.now()

        if self.runtime_flags & RTFLAG.SKIP:
            self.skip = True
            self.state = T_STATE.FINISHED
            user_logger.info("Sequence %s skipped", self.name)
            return result

        try:
            task = self.main_task()
            result = await task.run()
        except asyncio.CancelledError:
            self.state = T_STATE.ABORTED
            raise
        except Exception as e:
            self.in_error = True
            self.state = T_STATE.CANCELLED
            self.exception = e
            raise
        return result

    async def execute(self, resume=False, propagate=False):
        """Executes the node.

        Registers this node as the current sequence, creates the node tasks
        and runs them. The end time is always recorded, even on error.

        Parameters:
            resume: skip nodes already FINISHED when creating tasks.
            propagate: currently unused; kept for interface compatibility.
        """
        Sequence.current_seq.set(self)
        self.create_node_tasks(resume=resume)
        try:
            return await self.run()
        finally:
            self.t_end = datetime.now()
            self.context["finished"] = self.t_end

    async def resume(self):
        """Resume node execution.

        Unless the node is still PAUSED after the base-class resume,
        execution is restarted with ``resume=True`` and awaited.
        """
        super().resume()
        if self.state != T_STATE.PAUSED:
            await self.start(resume=True)

    def main_task(self):
        """Returns the objective task of the sequence -- the end node's."""
        return self.graph.nodes[self.end_node.id]["task"]

    def _name(self, node):
        return node["node"].name

    async def __call__(self, resume=False):
        return await self.execute(resume)

    def abort(self):
        """Aborts the sequence.

        Walks the full graph (via ``Visitor``) and aborts the task
        associated to each node, if any. Nodes without a task are skipped.
        """
        from .utils import Visitor

        for s, node_id in Visitor(self):
            try:
                task = s.get_task(node_id)
            except KeyError:
                # no task assigned to this node (e.g. tasks not created yet)
                continue
            if task:
                task.abort()

    def nodes(self):
        """Return the node ids of the graph in topological order."""
        return list(nx.topological_sort(self.G))

    def get_node(self, node_id):
        """Get node by id."""
        return self.G.nodes[node_id]["node"]

    def get_task(self, node_id):
        """Get task by node_id. Raises KeyError if there is none."""
        return self.G.nodes[node_id]["task"]

    @property
    def parameters(self):
        """Return parameters."""
        return self._parameters

    def par(self, k):
        """Get a parameter value."""
        return self._parameters[k]["value"]

    def set(self, p):
        """Sets the value of a parameter.

        Raises:
            KeyError: if ``p.name`` is not a known parameter.
        """
        if p.name not in self._parameters:
            raise KeyError(p.name)
        logger.debug("Set var: %s -> %s", p.name, p.value)
        self._parameters[p.name]["value"] = p.value

    def _publish_state(self):
        """Notify the current controller (if any) of a state change."""
        c = ob.OB.controller.get()
        if c:
            c.notify_state_change(self, datetime.now().isoformat())

    @property
    def state(self):
        """Node state (delegates to the base class)."""
        return super().state

    @state.setter
    def state(self, value):
        """Sets the node state and publishes the change."""
        super(Sequence, self.__class__).state.fset(self, value)
        self._publish_state()

    @property
    def full_state(self):
        return super().full_state

    @staticmethod
    def get_context():
        """Return the context dictionary of the currently running sequence."""
        current = Sequence.current_seq.get()
        assert current is not None, "no Sequence is currently running"
        return current.context

    @staticmethod
    def create(*args, **kw):
        """Sequence node constructor.

        Args:
            *args: Variable length list of nodes or coroutines that compose
                the sequence.

        Keyword Args:
            id: Node id
            name: node name
        """
        tpl = Sequence(**kw)
        tpl.seq = list(args)
        return tpl