Pārlūkot izejas kodu

Add version 2 of test framework

Joeri Exelmans 5 gadi atpakaļ
vecāks
revīzija
7d1f877767

+ 14 - 5
src/sccd/runtime/controller.py

@@ -59,7 +59,7 @@ class Controller:
 
     # Run until the event queue has no more due events wrt given timestamp and until all instances are stable.
     # If no timestamp is given (now = None), run until event queue is empty.
-    def run_until(self, now: Optional[Timestamp], pipe: queue.Queue):
+    def run_until(self, now: Optional[Timestamp], pipe: queue.Queue, interrupt: queue.Queue = queue.SimpleQueue()):
 
         unstable: List[Instance] = []
 
@@ -78,7 +78,7 @@ class Controller:
                 pipe.put(pipe_events, block=True, timeout=None)
 
         # Helper. Let all unstable instances execute big steps until they are stable
-        def stabilize():
+        def do_stabilize():
             while unstable:
                 for i in reversed(range(len(unstable))):
                     instance = unstable[i]
@@ -86,9 +86,16 @@ class Controller:
                     process_big_step_output(output)
                     if stable:
                         del unstable[i]
+                try:
+                    interrupt.get_nowait()
+                    return False # interrupted
+                except queue.Empty:
+                    pass
             else:
-                return
+                # already stable
+                return True
             print_debug("all instances stabilized.")
+            return True
 
 
         if not self.initialized:
@@ -114,7 +121,8 @@ class Controller:
                 # check if there's a time leap
                 if timestamp is not self.simulated_time:
                     # before every "time leap", continue to run instances until they are stable.
-                    stabilize()
+                    if not do_stabilize():
+                        return
                     # make time leap
                     self.simulated_time = timestamp
                 # run all instances for whom there are events
@@ -124,6 +132,7 @@ class Controller:
                     if not stable:
                         unstable.append(instance)
             # 2. No more due events -> stabilize
-            stabilize()
+            if not do_stabilize():
+                return
 
         self.simulated_time = now

+ 31 - 0
src/sccd/runtime/model.py

@@ -0,0 +1,31 @@
+from dataclasses import *
+from typing import *
+from sccd.runtime.statechart_syntax import *
+from sccd.runtime.semantic_options import *
+
+# Mapping from event name to event ID
+class EventNamespace:
+  def __init__(self):
+    self.mapping: Dict[str, int] = {}
+
+  def assign_id(self, event: str) -> int:
+    return self.mapping.setdefault(event, len(self.mapping))
+
+  def get_id(self, event: str) -> int:
+    return self.mapping[event]
+
+@dataclass
+class Statechart:
+  root: State
+  states: Dict[str, State] # mapping from state "full name" (e.g. "/parallel/ortho1/a") to state
+  state_list: List[State] # depth-first order
+  transition_list: List[Transition] # source state depth-first order, then document order
+  semantics: SemanticConfiguration = SemanticConfiguration()
+
+class Model:
+  def __init__(self):
+    self.event_namespace: EventNamespace = EventNamespace()
+    self.inports: List[str] = []
+    self.outports: List[str] = []
+    self.classes: Dict[str, Statechart] = {}
+    self.default_class: Optional[str] = None

+ 2 - 5
src/sccd/runtime/object_manager.py

@@ -16,16 +16,13 @@ class ObjectManager(Instance):
         # we need to maintain this set in order to do broadcasts
         self.instances = [self] # object manager is an instance too!
 
-        self._classmodels = {}
-
         self._create(model.default_class)
 
     def _create(self, class_name) -> StatechartInstance:
         # Instantiate the model for each class at most once:
         # The model is shared between instances of the same type.
-        self._classmodels.setdefault(class_name, self.model.classes[class_name]())
-        model = self._classmodels[class_name]
-        i = StatechartInstance(model.statechart, self)
+        statechart = self.model.classes[class_name]
+        i = StatechartInstance(statechart, self)
         self.instances.append(i)
         return i
 

+ 6 - 7
src/sccd/runtime/statechart_instance.py

@@ -143,11 +143,13 @@ class StatechartInstance(Instance):
     # Alternative implementation of candidate generation using mapping from set of enabled events to enabled transitions
     def _transition_candidates2(self) -> Iterable[Transition]:
         enabled_events = self._enabled_events()
-        key = Bitmap.from_list(e.id for e in enabled_events)
+        enabled_events_bitmap = Bitmap.from_list(e.id for e in enabled_events)
+        changed_bitmap = self._combo_step.changed_bitmap
+        key = (enabled_events_bitmap, changed_bitmap)
         try:
             transitions = self.event_mem[key]
         except KeyError:
-            self.event_mem[key] = transitions = [t for t in self.model.transition_list if (not t.trigger or key.has(t.trigger.id))]
+            self.event_mem[key] = transitions = [t for t in self.model.transition_list if (not t.trigger or enabled_events_bitmap.has(t.trigger.id)) and not changed_bitmap.has(t.source.state_id)]
             if self.model.semantics.priority == Priority.SOURCE_CHILD:
                 # Transitions are already in parent -> child (depth-first) order
                 # Only the first transition of the candidates will be executed.
@@ -155,7 +157,7 @@ class StatechartInstance(Instance):
                 transitions.reverse()
 
         def filter_f(t):
-            return self._check_source(t) and self._check_arena(t) and self._check_guard(t, enabled_events)
+            return self._check_source(t) and self._check_guard(t, enabled_events)
         return filter(filter_f, transitions)
 
     def _check_trigger(self, t, events) -> bool:
@@ -175,9 +177,6 @@ class StatechartInstance(Instance):
     def _check_source(self, t) -> bool:
         return self.configuration_bitmap.has(t.source.state_id)
 
-    def _check_arena(self, t) -> bool:
-        return not self._combo_step.changed_bitmap.has(t.source.state_id)
-
     # List of current small step enabled events
     def _enabled_events(self) -> List[Event]:
         events = self._small_step.current_events + self._combo_step.current_events
@@ -223,7 +222,7 @@ class StatechartInstance(Instance):
                     f = lambda s0: not s0.descendants and s0 in s.descendants
                 self.history_values[h.state_id] = list(filter(f, self.configuration))
         print_debug('')
-        print_debug(termcolor.colored('transition %s:  %s 🡪 %s'%(self.model._class.name, t.source.name, t.targets[0].name), 'green'))
+        print_debug(termcolor.colored('transition  %s 🡪 %s'%(t.source.name, t.targets[0].name), 'green'))
         for s in exit_set:
             print_debug(termcolor.colored('  EXIT %s' % s.name, 'green'))
             self.eventless_states -= s.has_eventless_transitions

+ 102 - 0
src/sccd/runtime/test.py

@@ -0,0 +1,102 @@
+import unittest
+from dataclasses import *
+from sccd.runtime.model import *
+from sccd.runtime.controller import *
+
+import threading
+import queue
+
+TestInput = List[InputEvent]
+TestOutput = List[List[Event]]
+
+class Test(unittest.TestCase):
+  def __init__(self, name: str, model: Model, input: TestInput, output: TestOutput):
+    super().__init__()
+    self.name = name
+    self.model = model
+    self.input = input
+    self.output = output
+
+  def __str__(self):
+    return self.name
+
+  def runTest(self):
+    pipe = queue.SimpleQueue()
+    interrupt = queue.SimpleQueue()
+
+    controller = Controller(self.model)
+
+    for i in self.input:
+      controller.add_input(i)
+
+    def controller_thread():
+      try:
+        # Run as-fast-as-possible, always advancing time to the next item in event queue, no sleeping.
+        # The call returns when the event queue is empty and therefore the simulation is finished.
+        controller.run_until(None, pipe, interrupt)
+      except Exception as e:
+        pipe.put(e, block=True, timeout=None)
+        return
+      # Signal end of output
+      pipe.put(None, block=True, timeout=None)
+
+    # start the controller
+    thread = threading.Thread(target=controller_thread)
+    thread.start()
+
+    # check output
+    expected = self.output
+    actual = []
+
+    def fail(msg, kill=False):
+      if kill:
+        interrupt.put(None)
+      thread.join()
+      self.fail(msg + "\nExpected: " + str(expected) + "\nActual: " + str(actual) + ("\n(killed)" if kill else ""))
+
+    while True:
+      data = pipe.get(block=True, timeout=None)
+
+      if isinstance(data, Exception):
+        raise data # Exception was caught in Controller thread, throw it here instead.
+
+      elif data is None:
+        # End of output
+        if len(actual) < len(expected):
+          fail("Less output than expected.")
+        else:
+          return
+
+      else:
+        big_step = data
+        big_step_index = len(actual)
+        actual.append(big_step)
+
+        if len(actual) > len(expected):
+          fail("More output than expected.", kill=True)
+
+        actual_bag = actual[big_step_index]
+        expected_bag = expected[big_step_index]
+
+        if len(actual_bag) != len(expected_bag):
+          fail("Big step %d: output differs." % big_step_index, kill=True)
+
+        # Sort both expected and actual lists of events before comparing.
+        # In theory the set of events at the end of a big step is unordered.
+        key_f = lambda e: "%s.%s"%(e.port, e.name)
+        actual_bag.sort(key=key_f)
+        expected_bag.sort(key=key_f)
+
+        for (act_event, exp_event) in zip(actual_bag, expected_bag):
+          matches = True
+          if exp_event.name != act_event.name :
+            matches = False
+          if exp_event.port != act_event.port :
+            matches = False
+          if len(exp_event.parameters) != len(act_event.parameters) :
+            matches = False
+          for index in range(len(exp_event.parameters)) :
+            if exp_event.parameters[index] !=  act_event.parameters[index]:
+              matches = False
+          if not matches:
+            fail("Big step %d: output differs." % big_step_index, kill=True)

+ 3 - 44
src/sccd/runtime/xml_loader.py

@@ -8,6 +8,7 @@ from sccd.runtime.statechart_syntax import *
 from sccd.runtime.event import Event
 from sccd.runtime.semantic_options import SemanticConfiguration
 from sccd.runtime.controller import InputEvent
+from sccd.runtime.model import *
 import sccd.schema
 
 schema_dir = os.path.dirname(sccd.schema.__file__)
@@ -20,54 +21,17 @@ schema = ET.XMLSchema(ET.parse(schema_path))
 grammar = open(os.path.join(schema_dir,"grammar.g"))
 parser = Lark(grammar, parser="lalr", start=["state_ref", "expr"])
 
-# Mapping from event name to event ID
-class EventNamespace:
-  def __init__(self):
-    self.mapping: Dict[str, int] = {}
-
-  def assign_id(self, event: str) -> int:
-    return self.mapping.setdefault(event, len(self.mapping))
-
-  def get_id(self, event: str) -> int:
-    return self.mapping[event]
-
-# Some types immitating the types that are produced by the compiler
-@dataclass
-class Statechart:
-  root: State
-  states: Dict[str, State] # mapping from state "full name" (e.g. "/parallel/ortho1/a") to state
-  state_list: List[State] # depth-first order
-  transition_list: List[Transition] # source state depth-first order, then document order
-
-  semantics: SemanticConfiguration = SemanticConfiguration()
-  _class: Any = None
-
-@dataclass
-class Class:
-  name: str
-  statechart: Statechart
-
-@dataclass
-class Model:
-  event_namespace: EventNamespace
-  inports: List[str]
-  outports: List[str]
-  classes: Dict[str, Any]
-  default_class: str
-
 @dataclass
 class Test:
   input_events: List[InputEvent]
   expected_events: List[Event]
 
-
 def load_model(src_file) -> Tuple[Model, Optional[Test]]:
   tree = ET.parse(src_file)
   schema.assertValid(tree)
   root = tree.getroot()
 
-  model = Model(event_namespace=EventNamespace(),
-    inports=[], outports=[], classes={}, default_class="")
+  model = Model()
 
   classes = root.findall(".//class", root.nsmap)
   for c in classes:
@@ -77,10 +41,7 @@ def load_model(src_file) -> Tuple[Model, Optional[Test]]:
     scxml_node = c.find("scxml", root.nsmap)
     statechart = load_statechart(scxml_node, model.event_namespace)
 
-    _class = Class(class_name, statechart)
-    statechart._class = _class
-
-    model.classes[class_name] = lambda: _class
+    model.classes[class_name] = statechart
     if default or len(classes) == 1:
       model.default_class = class_name
 
@@ -244,8 +205,6 @@ def load_statechart(scxml_node, event_namespace: EventNamespace) -> Statechart:
   transition_list: List[Transition] = []
   root.init_tree(0, "", states, state_list, transition_list)
 
-  print(transition_list)
-
   for t in transition_list:
     t.optimize()
 

+ 240 - 0
src/sccd/runtime/xml_loader2.py

@@ -0,0 +1,240 @@
+import os
+import lxml.etree as ET
+from lark import Lark
+from sccd.runtime.test import *
+from sccd.runtime.model import *
+from sccd.runtime.statechart_syntax import *
+
+import sccd.schema
+schema_dir = os.path.dirname(sccd.schema.__file__)
+
+# Grammar for parsing state references and expressions
+grammar = open(os.path.join(schema_dir,"grammar.g"))
+parser = Lark(grammar, parser="lalr", start=["state_ref", "expr"])
+
+class ParseError(Exception):
+  def __init__(self, msg):
+    self.msg = msg
+
+def load_expression(parse_node) -> Expression:
+  if parse_node.data == "func_call":
+    function = load_expression(parse_node.children[0])
+    parameters = [load_expression(e) for e in parse_node.children[1].children]
+    return FunctionCall(function, parameters)
+  elif parse_node.data == "string":
+    return StringLiteral(parse_node.children[0].value[1:-1])
+  elif parse_node.data == "identifier":
+    return Identifier(parse_node.children[0].value)
+  elif parse_node.data == "array":
+    elements = [load_expression(e) for e in parse_node.children]
+    return Array(elements)
+  raise ParseError("Can't handle expression type: "+parse_node.data)
+
+# A statechart can only be loaded within the context of a model
+def load_statechart(model: Model, dir, sc_node, name="", default: bool = False):
+
+  def _load(sc_node) -> Statechart:
+
+    def load_action(action_node) -> Optional[Action]:
+      tag = ET.QName(action_node).localname
+      if tag == "raise":
+        name = action_node.get("event")
+        port = action_node.get("port")
+        if not port:
+          event_id = model.event_namespace.assign_id(name)
+          return RaiseInternalEvent(name=name, parameters=[], event_id=event_id)
+        else:
+          if port not in model.outports:
+            model.outports.append(port)
+          return RaiseOutputEvent(name=name, parameters=[], outport=port, time_offset=0)
+      else:
+        raise None
+
+    # parent_node: XML node containing any number of action nodes as direct children
+    def load_actions(parent_node) -> List[Action]:
+      return list(filter(lambda x: x is not None, map(lambda child: load_action(child), parent_node)))
+
+    transition_nodes: List[Tuple[Any, State]] = [] # List of (<transition>, State) tuples
+
+    # Recursively create state hierarchy from XML node
+    # Adding <transition> elements to the 'transitions' list as a side effect
+    def load_state(state_node) -> Optional[State]:
+      state = None
+      name = state_node.get("id", "")
+      tag = ET.QName(state_node).localname
+      if tag == "state":
+          state = State(name)
+      elif tag == "parallel" : 
+          state = ParallelState(name)
+      elif tag == "history":
+        is_deep = state_node.get("type", "shallow") == "deep"
+        if is_deep:
+          state = DeepHistoryState(name)
+        else:
+          state = ShallowHistoryState(name)
+      else:
+        return None
+
+      initial = state_node.get("initial", "")
+      for xml_child in state_node.getchildren():
+          child = load_state(xml_child) # may throw
+          if child:
+            state.addChild(child)
+            if child.short_name == initial:
+              state.default_state = child
+      if not initial and len(state.children) == 1:
+          state.default_state = state.children[0]
+
+      for xml_t in state_node.findall("transition", state_node.nsmap):
+        transition_nodes.append((xml_t, state))
+
+      # Parse enter/exit actions
+      def _get_enter_exit(tag, setter):
+        node = state_node.find(tag, state_node.nsmap)
+        if node is not None:
+          actions = load_actions(node)
+          setter(actions)
+
+      _get_enter_exit("onentry", state.setEnter)
+      _get_enter_exit("onexit", state.setExit)
+
+      return state
+
+    # Build tree structure
+    tree_node = sc_node.find("tree")
+    root_node = tree_node.find("state")
+    root = load_state(root_node)
+
+    # Add transitions
+    next_after_id = 0
+    for t_node, source in transition_nodes:
+      # Parse and find target state
+      target_string = t_node.get("target", "")
+      parse_tree = parser.parse(target_string, start="state_ref")
+      def find_state(sequence) -> State:
+        if sequence.data == "relative_path":
+          el = source
+        elif sequence.data == "absolute_path":
+          el = root
+        for item in sequence.children:
+          if item.type == "PARENT_NODE":
+            el = el.parent
+          elif item.type == "CURRENT_NODE":
+            continue
+          elif item.type == "IDENTIFIER":
+            el = [x for x in el.children if x.short_name == item.value][0]
+        return el
+      targets = [find_state(seq) for seq in parse_tree.children]
+
+      transition = Transition(source, targets)
+
+      # Trigger
+      event = t_node.get("event")
+      port = t_node.get("port")
+      after = t_node.get("after")
+      if after is not None:
+        event = "_after%d" % next_after_id # transition gets unique event name
+        next_after_id += 1
+        trigger = AfterTrigger(event_namespace.assign_id(event), event, Timestamp(after))
+      elif event is not None:
+        trigger = Trigger(event_namespace.assign_id(event), event, port)
+        if port not in model.inports:
+            model.inports.append(port)
+      else:
+        trigger = None
+      transition.setTrigger(trigger)
+      # Actions
+      actions = load_actions(t_node)
+      transition.setActions(actions)
+      # Guard
+      cond = t_node.get("cond")
+      if cond is not None:
+        parse_tree = parser.parse(cond, start="expr")
+        # print(parse_tree)
+        # print(parse_tree.pretty())
+        cond_expr = load_expression(parse_tree)
+        transition.setGuard(cond_expr)
+      source.addTransition(transition)
+
+    # Calculate stuff like list of ancestors, descendants, etc.
+    # Also get depth-first ordered lists of states and transitions (by source)
+    states: Dict[str, State] = {}
+    state_list: List[State] = []
+    transition_list: List[Transition] = []
+    root.init_tree(0, "", states, state_list, transition_list)
+
+    for t in transition_list:
+      t.optimize()
+
+    # Semantics - We use reflection to find the xml attribute names and values
+    semantics_node = sc_node.find("semantics")
+    semantics = SemanticConfiguration()
+    load_semantics(semantics_node, semantics)
+
+    # TODO: process datamodel node
+    datamodel_node = sc_node.find("datamodel")
+
+    statechart = Statechart(root=root, states=states, state_list=state_list, transition_list=transition_list, semantics=semantics)
+
+    model.classes[name] = statechart
+    if default:
+      model.default_class = name
+    return statechart
+
+  # Start of function:
+  src = sc_node.get("src")
+  if src is None:
+    _load(sc_node)
+  else:
+    external_sc_node = ET.parse(os.path.join(dir, src)).getroot()
+    statechart = _load(external_sc_node)
+
+    semantics_node = sc_node.find("override-semantics")
+    load_semantics(semantics_node, statechart.semantics)
+
+def load_semantics(semantics_node, semantics: SemanticConfiguration):
+  if semantics_node is not None:
+    # Use reflection to find the possible XML attributes and their values
+    for aspect in dataclasses.fields(SemanticConfiguration):
+      key = semantics_node.get(aspect.name)
+      if key is not None:
+        value = aspect.type[key.upper()]
+        setattr(semantics, aspect.name, value)
+
+def load_test(src_file) -> Test:
+  # We'll create a model with one statechart
+  model = Model()
+  test_node = ET.parse(src_file).getroot()
+  sc_node = test_node.find("statechart")
+  load_statechart(model, os.path.dirname(src_file), sc_node, name="??", default=True)
+
+  input_node = test_node.find("input")
+  output_node = test_node.find("output")
+
+  input = load_input(input_node)
+  output = load_output(output_node)
+
+  return Test(src_file, model, input, output)
+
+def load_input(input_node) -> TestInput:
+  input = []
+  if input_node is not None:
+    for event_node in input_node:
+      name = event_node.get("name")
+      port = event_node.get("port")
+      time = int(event_node.get("time"))
+      input.append(InputEvent(name, port, [], time))
+  return input
+
+def load_output(output_node) -> TestOutput:
+  output = []
+  if output_node is not None: 
+    for big_step_node in output_node:
+      big_step = []
+      for event_node in big_step_node:
+        name = event_node.get("name")
+        port = event_node.get("port")
+        parameters = [] # todo: read params
+        big_step.append(Event(id=0, name=name, port=port, parameters=parameters))
+      output.append(big_step)
+  return output

+ 0 - 1
test/new_test_files/take_many.test.xml

@@ -4,7 +4,6 @@
     <override-semantics
       big_step_maximality="take_many"/>
   </statechart>
-  <scheduler stabilize="false"/>
   <input/>
   <output>
     <big_step>

+ 3 - 1
test/new_test_files/take_one.test.xml

@@ -4,11 +4,13 @@
     <override-semantics
       big_step_maximality="take_one"/>
   </statechart>
-  <scheduler stabilize="false"/>
   <input/>
   <output>
     <big_step>
       <event name="in_b" port="out"/>
     </big_step>
+    <big_step>
+      <event name="in_c" port="out"/>
+    </big_step>
   </output>
 </test>

+ 25 - 0
test/test2.py

@@ -0,0 +1,25 @@
+import argparse
+from lib.os_tools import *
+from sccd.runtime.test import *
+from sccd.runtime.xml_loader2 import *
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="Run SCCD tests.",
+        epilog="Set environment variable SCCDDEBUG=1 to display debug information about the inner workings of state machines.")
+    parser.add_argument('path', metavar='PATH', type=str, nargs='*', help="Tests to run. Can be a XML file or a directory. If a directory, it will be recursively scanned for XML files.")
+    parser.add_argument('--build-dir', metavar='BUILD_DIR', type=str, default='build', help="Directory for built tests. Defaults to 'build'")
+    args = parser.parse_args()
+
+    src_files = get_files(args.path, filter=lambda file: file.endswith(".test.xml"))
+
+    suite = unittest.TestSuite()
+    for src_file in src_files:
+        suite.addTest(load_test(src_file))
+
+    if len(src_files) == 0:
+        print("No input files specified.")
+        print()
+        parser.print_usage()
+    else:
+        unittest.TextTestRunner(verbosity=2).run(suite)