"""Implements the Sequence node for the sequencer"""
import logging
import pdb
from datetime import datetime
import asyncio
import contextvars as cv
import attr
import networkx as nx
from zope.interface import implementer
from .state import T_STATE
from .interface import (
INode,
_BaseNode,
)
from .action import (
uniqueId,
StartNode,
EndNode,
make_node,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
@implementer(INode)
@attr.s
class Sequence(_BaseNode):
    """Basic sequencer node.

    Use :meth:`create` to build properly initialized :class:`Sequence` objects.

    Attributes:
        state: node state

    ============ ===============================
    Context Variables
    --------------------------------------------
    Name         Desc
    ============ ===============================
    current_tpl  The parent of the current node
    root         Top level DAG's root
    ============ ===============================

    Examples:
        >>> s = Sequence.create(a, b, name="my sequence")
        # execute Sequence
        >>> await s.start()
        # Get running Sequence from inside function a():
        >>> def a():
        ...     current_tpl = Sequence.current_tpl.get()
        ...     # now current_tpl is the node s
        ...     assert current_tpl == s
    """

    _seq = attr.ib(init=False, default=attr.Factory(list), repr=False)
    graph = attr.ib(init=False, default=attr.Factory(nx.DiGraph), repr=False)
    _start_node = attr.ib(init=False, default=None, repr=False)
    _end_node = attr.ib(init=False, default=None, repr=False)
    _context = attr.ib(init=False, default=attr.Factory(dict), repr=False)
    debug = attr.ib(init=False, default=False, repr=False)

    # Sequence's context variables (plain class attributes, not attrs fields)
    current_tpl = cv.ContextVar("current_tpl", default=None)
    root = cv.ContextVar("root", default=None)  # set by OB runner

    # Strong references to in-flight publish_state() tasks: the event loop
    # only keeps weak references, so an unreferenced task may be garbage
    # collected before it finishes. Not an attrs field.
    _publish_tasks = set()

    def __attrs_post_init__(self):
        """Assigns default node name and id, and resets the result dict."""
        if self.name is None:
            self.name = "Sequence"
        if self.id is None:
            self.id = "%s_%s" % (self.name, uniqueId(id(self)))
        self._result = {}

    @property
    def context(self):
        """Get the context dictionary, preferably from the root node.

        If a root sequence is registered in :attr:`root`, this sequence
        adopts (shares) the root's context dictionary. Otherwise this
        sequence registers itself as root and keeps its own dictionary.
        """
        root = Sequence.root.get()
        if root is not None:
            self._context = root._context
        else:
            Sequence.root.set(self)
        return self._context

    @context.setter
    def context(self, ctx):
        self._context = ctx

    @property
    def seq(self):
        """Retrieves the sequence list.

        :meta private:
        """
        return self._seq

    @seq.setter
    def seq(self, s):
        """List used to create the Sequence's graph.

        Parameters
        -----------
        s (Iterable): Sequence list.
            Members of 's' represent sequence nodes. Each member goes through
            ``make_node``; coroutines are automagically converted to
            :class:`Action` nodes.

        Raises:
            TypeError: ``s`` is not iterable.
        """
        from collections.abc import Iterable

        if not isinstance(s, Iterable):
            raise TypeError("Sequence members must be an iterable, got %r" % (s,))
        self._seq = list(map(make_node, s))

    @property
    def G(self):
        """Returns the graph object."""
        return self.graph

    async def start_step(self):
        """Standard sequence's start step.

        Sets the sequence's state to RUNNING (and drops into ``pdb`` when
        :attr:`debug` is set).
        """
        logger.info(
            "This is the <start> of a sequence: %s", Sequence.current_tpl.get()
        )
        self.state = T_STATE.RUNNING
        if self.debug:
            pdb.set_trace()

    async def end_step(self):
        """Standard sequence's end step.

        Sets the sequence's state to FINISHED and flags the sequence as
        ``in_error`` when any node of its graph is in error. Node results are
        not collected here.
        """
        logger.info("this is the <end> of the sequence: %s", self.name)
        self.state = T_STATE.FINISHED
        # grab each node's error flag and propagate any ERROR to the sequence
        if any(node.in_error for _, node in self.graph.nodes(data="node")):
            self.in_error = True

    @property
    def start_node(self):
        """Returns the start node, creating it on first access."""
        if self._start_node is None:
            self._start_node = StartNode(
                self.start_step, name="start", id="start_" + self.id
            )
        return self._start_node

    @property
    def end_node(self):
        """Returns the end node, creating it on first access."""
        if self._end_node is None:
            self._end_node = EndNode(
                self.end_step, name="end", id="end_" + self.id
            )
        return self._end_node

    def append(self, s):
        """Appends a node to the Sequence.

        ``s`` goes through ``make_node`` exactly like the members assigned to
        :attr:`seq`, so coroutines are converted to nodes here as well.
        """
        self.seq.append(make_node(s))

    def make_sequence(self):
        """Builds this sequence execution graph.

        Each member of :attr:`seq` becomes a graph node. A member with
        ``deps`` is linked from each of its dependencies (which must already
        be in the graph); otherwise it is linked from the previous member
        (the start node for the first one). Finally every node without
        successors is linked to the end node; for an empty sequence this
        links the start node directly to the end node.

        The method can be called again on an already-built graph: the end
        node is never considered a leaf, so no end->end self-loop is created.

        Raises:
            ValueError: a member depends on a node that is not (yet) in the
                graph.
        """
        logger.debug("create sequence: %s", self.name)
        G = self.graph
        start_id = self.start_node.id
        end_id = self.end_node.id
        G.add_node(start_id, node=self.start_node)
        G.add_node(end_id, node=self.end_node)
        parent = self.start_node
        for el in self.seq:
            logger.debug("add node: %s", el)
            el.make_sequence()
            G.add_node(el.id, node=el)
            if el.deps:
                for d in el.deps:
                    if d.id not in G:
                        raise ValueError(
                            "node %s depends on %s, which is not in the graph"
                            % (el.id, d.id)
                        )
                    G.add_edge(d.id, el.id)
            else:
                G.add_edge(parent.id, el.id)
            parent = el
        leaf_nodes = [
            node
            for node in G.nodes()
            if node != end_id
            and G.out_degree(node) == 0
            # members always have a predecessor; the start node only needs
            # linking when the sequence is empty
            and (G.in_degree(node) != 0 or node == start_id)
        ]
        logger.debug("leaf nodes: %s", leaf_nodes)
        for leaf in leaf_nodes:
            G.add_edge(leaf, end_id)

    def create_node_tasks(self, resume=False):
        """Creates the Task object associated to each node of the graph.

        Nodes are visited in topological order, so the order in which they
        were added to the graph does not matter. Each task is built from the
        graph-node dicts of the node's predecessors, stored under the
        ``"task"`` key of the node's graph dict, and the node is marked
        SCHEDULED.

        Parameters:
            resume (bool): when True, nodes already FINISHED are skipped.
        """
        G = self.graph
        logger.info(
            "Instance node tasks -- %s (resume = %s)", self.name, resume
        )
        logger.debug("G: %s", G.nodes)
        g_nodes = [G.nodes[key] for key in nx.topological_sort(G)]
        if resume:
            g_nodes = [
                node for node in g_nodes if node["node"].state != T_STATE.FINISHED
            ]
        for node in g_nodes:
            key = node["node"].id
            in_edges = [u for u, _ in G.in_edges(key)]
            logger.debug("%s -> IN edges: %s", self._name(node), in_edges)
            input_list = [G.nodes[pred] for pred in in_edges]
            node["task"] = node["node"].make_task(node, input_list, resume)
            node["node"].state = T_STATE.SCHEDULED

    def reschedule_node(self, node_id):
        """Reschedule a node for execution.

        Marks the node SCHEDULED and resets its task with ``resume=True``.
        """
        g_node = self.graph.nodes[node_id]  # graph's node dict
        node = g_node["node"]
        task = g_node["task"]
        node.state = T_STATE.SCHEDULED
        task.reset(resume=True)

    def start(self, make_sequence=True, resume=False):
        """This is the entry point for Sequence execution.

        Records the start time and name in the context, builds the graph
        (unless ``make_sequence`` is False or ``resume`` is True) and
        schedules :meth:`execute` on the running event loop.

        Returns:
            The ``asyncio.Task`` that executes the sequence. Any exception
            raised while executing is re-raised when the task is awaited.
        """
        if resume:
            make_sequence = False
        context = self.context
        context["started"] = datetime.now()
        context["name"] = self.name
        if make_sequence:
            self.make_sequence()
        return asyncio.create_task(self.execute(resume))

    async def run(self):
        """Runs the node -- this awaits the end node's task.

        Returns:
            The end node's task result, or None when the node is skipped.

        Raises:
            Exception: any exception from the task is recorded on the node
                (``in_error``, ``exception``, state CANCELLED) and re-raised.
        """
        result = None
        self.t_start = datetime.now()
        if self.skip:
            self.state = T_STATE.FINISHED
            self.t_end = datetime.now()
            return result
        try:
            result = await self.main_task().run()
        except Exception as e:
            self.t_end = datetime.now()
            self.in_error = True
            self.state = T_STATE.CANCELLED
            self.exception = e
            raise
        self.t_end = datetime.now()
        return result

    async def execute(self, resume=False):
        """Executes the sequence in the current task.

        Sets :attr:`current_tpl` to this sequence, creates the node tasks and
        runs them.

        Returns:
            The result of :meth:`run`.
        """
        logger.debug("execute: %s", self)
        Sequence.current_tpl.set(self)
        self.create_node_tasks(resume=resume)
        return await self.run()

    async def resume(self):
        """Resume node execution and return the execution result."""
        return await self.start(resume=True)

    def main_task(self):
        """Returns the objective task of the sequence -- the end node's task."""
        return self.graph.nodes[self.end_node.id]["task"]

    def _name(self, node):
        """Name of the node stored in a graph-node dict."""
        return node["node"].name

    async def __call__(self, resume=False):
        return await self.execute(resume)

    def abort(self):
        """Aborts the sequence.

        Goes through the full graph and aborts the tasks associated to nodes
        (if any). Nodes whose graph dict has no task yet are skipped.
        """
        from .utils import Visitor

        # Visitor presumably yields (owning sequence, node_id) pairs for the
        # whole node tree -- TODO confirm against .utils
        for s, node_id in Visitor(self):
            try:
                task = s.get_task(node_id)
            except KeyError:
                # no task was created for this node yet
                continue
            logger.info("abort: %s", task)
            if task:
                task.abort()

    def nodes(self):
        """Return nodes from Graph."""
        return self.G.nodes

    def get_node(self, node_id):
        """Get node by id."""
        return self.G.nodes[node_id]["node"]

    def get_task(self, node_id):
        """Get task by node_id."""
        return self.G.nodes[node_id]["task"]

    @property
    def state(self):
        """Node state (stored by the base class)."""
        return super().state

    @state.setter
    def state(self, value):
        """Sets the node state and schedules a notification of the change."""
        super(Sequence, self.__class__).state.fset(self, value)
        task = asyncio.create_task(self.publish_state())
        # keep a strong reference until the task is done
        Sequence._publish_tasks.add(task)
        task.add_done_callback(Sequence._publish_tasks.discard)

    async def publish_state(self):
        """Wakes up one waiter on the node's condition variable."""
        async with self._cond:
            self._cond.notify(1)

    @staticmethod
    def get_context():
        """Returns the context dictionary of the currently running sequence.

        Raises:
            RuntimeError: no sequence is set in :attr:`current_tpl`.
        """
        current = Sequence.current_tpl.get()
        if current is None:
            raise RuntimeError("get_context() called outside of a running Sequence")
        return current.context

    @staticmethod
    def create(*args, **kw):
        """Sequence node constructor.

        Args:
            *args: Variable length list of nodes or coroutines that compose
                the sequence.

        Keyword Args:
            id: Node id
            name: node name
        """
        tpl = Sequence(**kw)
        tpl.seq = list(args)
        return tpl