"""Implements the Sequence node for the sequencer"""
# pylint: disable= invalid-name, too-many-instance-attributes, too-many-public-methods
import logging
from datetime import datetime
import asyncio
import contextvars as cv
import attr
import networkx as nx
from zope.interface import implementer
from .. import ob
from ..counter import Counter
from .state import T_STATE
from .interface import (
INode,
_BaseNode,
)
from .action import (
StartNode,
EndNode,
make_node,
)
from .utils import uniqueId
# Internal diagnostics for this module.
logger = logging.getLogger(name=__name__)
# User-facing progress messages (sequence start / finish / skip).
user_logger = logging.getLogger(name="seq.user")
@implementer(INode)
@attr.s
class Sequence(_BaseNode):
    """Basic sequencer node.

    Use :meth:`create` to properly build :class:`Sequence` objects.

    Attributes:
        state: node state

    Context variables:

    ============ ===============================
    Name         Desc
    ============ ===============================
    current_seq  The parent of the current node
    root         Top level DAG's root
    ============ ===============================

    Examples:
        .. code-block:: python

            s = Sequence.create(a, b, name="my sequence")
            # execute Sequence
            await s.start()

            # Get the running Sequence from inside function a():
            def a():
                current_seq = Sequence.current_seq.get()
                # now current_seq is the node s
                assert current_seq == s
    """

    # Child nodes, already converted with make_node (see the ``seq`` setter).
    _seq = attr.ib(init=False, default=attr.Factory(list), repr=False)
    # Execution graph, rebuilt by make_sequence().
    graph = attr.ib(init=False, default=attr.Factory(nx.DiGraph), repr=False)
    # Lazily created begin/end nodes (see start_node / end_node).
    _start_node = attr.ib(init=False, default=None, repr=False)
    _end_node = attr.ib(init=False, default=None, repr=False)
    # Context dict; replaced by the root's dict in the ``context`` property.
    _context = attr.ib(init=False, default=attr.Factory(dict), repr=False)
    # When True, start_step() drops into the debugger.
    debug = attr.ib(init=False, default=False, repr=False)
    # Parameter name -> dict with at least a "value" key (see par() / set()).
    _parameters = attr.ib(init=False, repr=False, default=attr.Factory(dict))

    # Context variables (plain class attributes, not attrs fields).
    current_seq = cv.ContextVar("current_seq", default=None)
    root = cv.ContextVar("root", default=None)  # set by OB runner

    def __attrs_post_init__(self):
        """Assign a default name and id, and allow the node to run."""
        if self.name is None:
            self.name = "Sequence"
        if self.id is None:
            self.id = "%s_%s" % (self.name, uniqueId(id(self)))
        self._result = {}
        self.running_checkpoint.set()  # running is allowed

    def __iter__(self):
        """Iterate over the node ids of the execution graph."""
        return iter(self.graph)

    @property
    def context(self):
        """Return the context dictionary, preferably the root node's one.

        If a root sequence is registered in :attr:`root`, this node adopts
        the root's context dict. Otherwise this node registers itself as
        the root and returns its own dict.
        """
        root = Sequence.root.get()
        if root is not None:
            self._context = root._context
        else:
            Sequence.root.set(self)
        return self._context

    @context.setter
    def context(self, ctx):
        self._context = ctx

    @property
    def seq(self):
        """Retrieve the sequence list.

        :meta private:
        """
        return self._seq

    @seq.setter
    def seq(self, s):
        """Set the list used to create the Sequence's graph.

        Parameters
        -----------
        s (Iterable): Sequence list.
            Members of ``s`` represent sequence nodes. If they are coroutines
            they are automatically converted to :class:`Action` nodes
            (via :func:`make_node`).

        Raises:
            TypeError: if ``s`` is not iterable.
        """
        from collections.abc import Iterable

        if not isinstance(s, Iterable):
            raise TypeError(
                "Sequence.seq expects an iterable of nodes, got %s"
                % type(s).__name__
            )
        self._seq = list(map(make_node, s))

    @property
    def G(self):
        """Return the graph object."""
        return self.graph

    async def start_step(self):
        """Standard sequence start step.

        Logs the start, sets the sequence state to RUNNING and, when
        :attr:`debug` is set, enters the debugger.
        """
        user_logger.info("Sequence %s starting", self.name)
        self.state = T_STATE.RUNNING
        if self.debug:
            breakpoint()

    async def end_step(self):
        """Standard sequence end step.

        Sets :attr:`in_error` to True if any node of the graph is in error,
        then sets the state to FINISHED (which publishes it).
        """
        user_logger.info("Sequence %s finish", self.name)
        G = self.graph
        self.in_error = any(G.nodes[key]["node"].in_error for key in G)
        self.state = T_STATE.FINISHED  # This publishes node's state.

    @property
    def start_node(self):
        """Return the start node, creating it on first access."""
        if self._start_node is None:
            self._start_node = StartNode(
                self.start_step, name="begin", id="begin_" + self.id
            )
        return self._start_node

    @property
    def end_node(self):
        """Return the end node, creating it on first access."""
        if self._end_node is None:
            self._end_node = EndNode(
                self.end_step, name="end", id="end_" + self.id
            )
        return self._end_node

    def append(self, s):
        """Append a node to the Sequence.

        ``s`` is converted with :func:`make_node`, exactly like the elements
        given to :meth:`create` or assigned to :attr:`seq`, so a coroutine
        can be appended directly.
        """
        self._seq.append(make_node(s))

    def make_sequence(self, parent_tpl=None):
        """Build this sequence's execution graph.

        The graph is the linear chain ``begin -> el_1 -> ... -> el_n -> end``.
        Each child builds its own graph first, and every node's state is
        reset to NOT_STARTED.

        Args:
            parent_tpl: optional link to an enclosing template, forwarded
                unchanged to every child's ``make_sequence``.
        """
        ctrl = ob.OB.controller.get()
        if ctrl:
            self.runtime_flags = ctrl.runtime_flags.get(self.serial_number, 0)
        self.graph = G = nx.DiGraph()
        G.add_node(self.start_node.id, node=self.start_node)
        parent = self.start_node
        for el in self.seq:
            el.make_sequence(parent_tpl)  # pass down link to Tpl if any
            el.state = T_STATE.NOT_STARTED
            G.add_node(el.id, node=el)
            G.add_edge(parent.id, el.id)
            parent = el
        # The end node is created only here (after the children), keeping
        # the original node-creation order.
        G.add_node(self.end_node.id, node=self.end_node)
        self.start_node.state = self.end_node.state = T_STATE.NOT_STARTED
        # Link the tail of the chain to the end node. For an empty sequence
        # the tail is the start node, so begin -> end is still ordered.
        G.add_edge(parent.id, self.end_node.id)

    def create_node_tasks(self, resume=False):
        """Create the Task object associated to every node of the graph.

        Nodes are processed in topological order; each task receives the
        graph dicts of its predecessor nodes as inputs.

        Args:
            resume: when True, nodes already FINISHED get no new task.
        """
        G = self.graph
        g_nodes = [G.nodes[key] for key in nx.topological_sort(G)]
        if resume:
            g_nodes = [
                g_node
                for g_node in g_nodes
                if g_node["node"].state != T_STATE.FINISHED
            ]
        for g_node in g_nodes:
            node = g_node["node"]
            input_list = [G.nodes[pred] for pred in G.predecessors(node.id)]
            g_node["task"] = node.make_task(g_node, input_list, resume)
            node.state = T_STATE.SCHEDULED

    def reschedule_node(self, node_id):
        """Reschedule a node for execution.

        Sets the node's state back to SCHEDULED and resets its task in
        resume mode.

        Raises:
            KeyError: if the node is unknown or has no task yet.
        """
        g_node = self.graph.nodes[node_id]  # get Graph's node dict
        node = g_node["node"]
        task = g_node["task"]
        # reset state
        node.state = T_STATE.SCHEDULED
        task.reset(resume=True)

    def start(self, make_sequence=True, resume=False):
        """Entry point for Sequence execution.

        Records the start time and name in the shared context, builds the
        execution graph (unless resuming) and schedules :meth:`execute`.

        Args:
            make_sequence: rebuild the execution graph before running.
            resume: resume a previous run; implies ``make_sequence=False``.

        Returns:
            The :class:`asyncio.Task` running :meth:`execute`. Exceptions
            raised during execution surface when the task is awaited.
        """
        if resume:
            make_sequence = False
        context = self.context
        context["started"] = datetime.now()
        context["name"] = self.name
        if make_sequence:
            self.make_sequence()
        return asyncio.create_task(self.execute(resume))

    async def run(self):
        """Run the node: honor runtime flags, then await the end node's task.

        A PAUSE flag pauses the node before waiting on
        :attr:`running_checkpoint`. A SKIP flag marks the node as skipped
        and FINISHED without running its graph.

        Raises:
            asyncio.CancelledError: re-raised after setting state ABORTED.
            Exception: re-raised after setting ``in_error``, state CANCELLED
                and ``exception``.
        """
        from ..ob import RTFLAG

        result = None
        if self.runtime_flags & RTFLAG.PAUSE:
            self.state = T_STATE.PAUSED
            self.pause()
        # check permission to execute
        await self.running_checkpoint.wait()
        # now we are running ...
        if self.running_checkpoint.is_set():
            self.t_start = datetime.now()
            if self.runtime_flags & RTFLAG.SKIP:
                self.skip = True
                self.state = T_STATE.FINISHED
                user_logger.info("Sequence %s skipped", self.name)
            else:
                try:
                    task = self.main_task()
                    result = await task.run()
                except asyncio.CancelledError:
                    self.state = T_STATE.ABORTED
                    raise
                except Exception as e:
                    self.in_error = True
                    self.state = T_STATE.CANCELLED
                    self.exception = e
                    raise
        return result

    async def execute(self, resume=False, propagate=False):
        """Execute the node.

        Registers this node as :attr:`current_seq`, creates the node tasks
        and awaits :meth:`run`. The finish time is stored in ``t_end`` and
        in the context whether or not the run succeeds.

        Args:
            resume: forwarded to :meth:`create_node_tasks`.
            propagate: currently unused; kept for interface compatibility.

        Returns:
            The result of :meth:`run`.
        """
        Sequence.current_seq.set(self)
        self.create_node_tasks(resume=resume)
        try:
            return await self.run()
        finally:
            self.t_end = datetime.now()
            self.context["finished"] = self.t_end

    async def resume(self):
        """Resume node execution.

        After the base-class resume, restarts the sequence in resume mode
        unless the node is still PAUSED.
        """
        super().resume()
        if self.state != T_STATE.PAUSED:
            await self.start(resume=True)

    def main_task(self):
        """Return the task of the end node (the sequence's objective node)."""
        return self.graph.nodes[self.end_node.id]["task"]

    def _name(self, node):
        return node["node"].name

    async def __call__(self, resume=False):
        return await self.execute(resume)

    def abort(self):
        """Abort the sequence.

        Walks the full graph (including nested sequences, as yielded by
        :class:`Visitor`) and aborts the task associated to each node.
        Nodes without a task yet are skipped.
        """
        from .utils import Visitor

        for seq, node_id in Visitor(self):
            try:
                task = seq.get_task(node_id)
            except KeyError:
                continue  # no task created for this node
            if task:
                task.abort()

    def nodes(self):
        """Return the graph's node ids in topological order."""
        return list(nx.topological_sort(self.G))

    def get_node(self, node_id):
        """Get node by id."""
        return self.G.nodes[node_id]["node"]

    def get_task(self, node_id):
        """Get task by node id.

        Raises:
            KeyError: if the node is unknown or has no task yet.
        """
        return self.G.nodes[node_id]["task"]

    @property
    def parameters(self):
        """Return the parameters dictionary."""
        return self._parameters

    def par(self, k):
        """Get the value of parameter ``k``.

        Raises:
            KeyError: if ``k`` is not a declared parameter.
        """
        return self._parameters[k]["value"]

    def set(self, p):
        """Set the value of an existing parameter.

        Args:
            p: object with ``name`` and ``value`` attributes.

        Raises:
            KeyError: if ``p.name`` is not a declared parameter.
        """
        if p.name not in self._parameters:
            raise KeyError(p.name)
        logger.debug("Set var: %s -> %s", p.name, p.value)
        self._parameters[p.name]["value"] = p.value

    def _publish_state(self):
        """Notify the controller (if any) that this node's state changed."""
        c = ob.OB.controller.get()
        if c:
            c.notify_state_change(self, datetime.now().isoformat())

    @property
    def state(self):
        return super().state

    @state.setter
    def state(self, value):
        """Set the node state and publish it to the controller."""
        # Property setters cannot be reached through super() directly, so
        # fetch the base-class property object and call its fset.
        super(Sequence, self.__class__).state.fset(self, value)
        self._publish_state()

    @property
    def full_state(self):
        return super().full_state

    @staticmethod
    def get_context():
        """Return the context dictionary of the currently running sequence.

        Raises:
            RuntimeError: if no sequence is set in :attr:`current_seq`.
        """
        current = Sequence.current_seq.get()
        if current is None:
            raise RuntimeError(
                "Sequence.get_context() called outside of a running Sequence"
            )
        return current.context

    @staticmethod
    def create(*args, **kw):
        """Sequence node constructor.

        Args:
            *args: Variable length list of nodes or coroutines that compose
                the sequence.

        Keyword Args:
            id: Node id
            name: node name

        Returns:
            The new :class:`Sequence`.
        """
        tpl = Sequence(**kw)
        tpl.seq = list(args)
        return tpl