Selaa lähdekoodia

Fixed errors reported by mypy type checker + parallelism (worker pool) in render script.

Joeri Exelmans 5 vuotta sitten
vanhempi
commit
3148ee5da3

+ 3 - 1
.gitignore

@@ -17,4 +17,6 @@ doc/_build
 .project
 .pydevproject
 src/MANIFEST
-__pycache__/
+__pycache__/
+.mypy_cache/
+*.smcat

+ 3 - 3
src/sccd/runtime/controller.py

@@ -1,6 +1,6 @@
 import queue
 import dataclasses
-from typing import Dict, List
+from typing import Dict, List, Optional
 from sccd.runtime.event_queue import EventQueue, EventQueueDeque, Timestamp
 from sccd.runtime.event import *
 from sccd.runtime.object_manager import ObjectManager
@@ -43,7 +43,7 @@ class Controller:
     # regardless of the platform ()
 
     # Get timestamp of next entry in event queue
-    def next_wakeup(self) -> Timestamp:
+    def next_wakeup(self) -> Optional[Timestamp]:
         return self.queue.earliest_timestamp()
 
     # Run until given timestamp.
@@ -51,7 +51,7 @@ class Controller:
     # Output generated while running is written to 'pipe' so it can be heard by another thread.
     def run_until(self, now: Timestamp, pipe: queue.Queue):
 
-        unstable = []
+        unstable: List[Instance] = []
 
         # Helper. Put big step output events in the event queue or add them to the right output listeners.
         def process_big_step_output(events: List[OutputEvent]):

+ 1 - 1
src/sccd/runtime/debug.py

@@ -1,7 +1,7 @@
 import os
 
 try:
-  DEBUG = os.environ['SCCDDEBUG']
+  DEBUG = bool(os.environ['SCCDDEBUG'])
 except KeyError:
   DEBUG = False
 def print_debug(msg):

+ 8 - 10
src/sccd/runtime/event_queue.py

@@ -1,7 +1,6 @@
-from sccd.runtime.infinity import INFINITY
 from heapq import heappush, heappop, heapify
 from abc import ABC
-from typing import List, Set, Tuple, Deque, Any, TypeVar, Generic, Generator
+from typing import List, Set, Tuple, Deque, Any, TypeVar, Generic, Generator, Optional
 from collections import deque
 
 Timestamp = int
@@ -20,14 +19,14 @@ class EventQueue(Generic[Item]):
     def is_empty(self) -> bool:
         return not [item for item in self.queue if not item[2] in self.removed]
     
-    def earliest_timestamp(self) -> Timestamp:
+    def earliest_timestamp(self) -> Optional[Timestamp]:
         while self.queue and (self.queue[0] in self.removed):
             item = heappop(self.queue)
-            self.removed.remove(item)
+            self.removed.remove(item[2])
         try:
             return self.queue[0][0]
         except IndexError:
-            return INFINITY
+            return None
     
     def add(self, timestamp: Timestamp, item: Item):
         self.counters[timestamp] = self.counters.setdefault(timestamp, 0) + 1
@@ -60,10 +59,9 @@ class EventQueue(Generic[Item]):
 
 # Alternative implementation: A heapq with unique entries for each timestamp, and a deque with items for each timestamp.
 class EventQueueDeque(Generic[Item]):
-    Entry = Tuple[Timestamp, Deque[Item]]
 
     def __init__(self):
-        self.queue: List[Entry] = []
+        self.queue: List[Tuple[Timestamp, Deque[Item]]] = []
         self.entries: Dict[Timestamp, Deque[Item]] = {}
 
         # performance optimization:
@@ -78,12 +76,12 @@ class EventQueueDeque(Generic[Item]):
                     return False
         return True
 
-    def earliest_timestamp(self) -> Timestamp:
+    def earliest_timestamp(self) -> Optional[Timestamp]:
         try:
             earliest, _ = self.queue[0]
             return earliest
         except IndexError:
-            return INFINITY
+            return None
 
     def add(self, timestamp: Timestamp, item: Item):
         try:
@@ -103,7 +101,7 @@ class EventQueueDeque(Generic[Item]):
             for i in range(len(self.queue)-1, -1, -1):
                 queue_entry = self.queue[i]
                 timestamp, old_deque = queue_entry
-                new_deque = deque([])
+                new_deque: Deque[Item] = deque([])
                 for item in old_deque:
                     if item not in self.removed:
                         new_deque.append(item)

+ 185 - 156
src/sccd/runtime/object_manager.py

@@ -1,10 +1,37 @@
 import re
 import abc
 from typing import List, Tuple
-from sccd.runtime.event import Instance, Event, OutputEvent
+from sccd.runtime.event import Instance, Event, OutputEvent, InstancesTarget
 from sccd.runtime.event_queue import Timestamp
 from sccd.runtime.statecharts_core import StatechartInstance
 
+class RuntimeException(Exception):
+    """
+    Base class for runtime exceptions.
+    """
+    def __init__(self, message):
+        self.message = message
+    def __str__(self):
+        return repr(self.message)
+
+class AssociationException(RuntimeException):
+    """
+    Exception class thrown when an error occurs in a CRUD operation on associations.
+    """
+    pass
+
+class AssociationReferenceException(RuntimeException):
+    """
+    Exception class thrown when an error occurs when resolving an association reference.
+    """
+    pass
+
+class ParameterException(RuntimeException):
+    """
+    Exception class thrown when an error occurs when passing parameters.
+    """
+    pass
+
 # TODO: Clean this mess up. Look at all object management operations and see how they can be improved.
 class ObjectManager(Instance):
     _regex_pattern = re.compile("^([a-zA-Z_]\w*)(?:\[(\d+)\])?$")
@@ -68,174 +95,176 @@ class ObjectManager(Instance):
             raise ParameterException ("The broadcast event needs 2 parameters (source of event and event name).")
         return OutputEvent(parameters[1], InstancesTarget(self.instances))
 
-    def _handle_create(self, timestamp, parameters) -> List[OutputEvent]:
-        if len(parameters) < 2:
-            raise ParameterException ("The create event needs at least 2 parameters.")
+    # def _handle_create(self, timestamp, parameters) -> List[OutputEvent]:
+    #     if len(parameters) < 2:
+    #         raise ParameterException ("The create event needs at least 2 parameters.")
 
-        source = parameters[0]
-        association_name = parameters[1]
+    #     source = parameters[0]
+    #     association_name = parameters[1]
         
-        traversal_list = self._assoc_ref(association_name)
-        instances = self._get_instances(source, traversal_list)
+    #     traversal_list = self._assoc_ref(association_name)
+    #     instances = self._get_instances(source, traversal_list)
         
-        association = source.associations[association_name]
-        # association = self.instances_map[source].getAssociation(association_name)
-        if association.allowedToAdd():
-            ''' allow subclasses to be instantiated '''
-            class_name = association.to_class if len(parameters) == 2 else parameters[2]
-            instance = self._create(class_name)
-            # new_instance = self.model.classes[class_name](parameters[3:])
-            if not instance:
-                raise ParameterException("Creating instance: no such class: " + class_name)
-            output_events = instance.initialize(timestamp)
-            try:
-                index = association.addInstance(instance)
-            except AssociationException as exception:
-                raise RuntimeException("Error adding instance to association '" + association_name + "': " + str(exception))
-            p = instance.associations.get("parent")
-            if p:
-                p.addInstance(source)
-            return output_events.append(OutputEvent(Event("instance_created", None, [association_name+"["+str(index)+"]"]), InstancesTarget([source])))
-        else:
-            return OutputEvent(Event("instance_creation_error", None, [association_name]), InstancesTarget([source]))
-
-    def _handle_delete(self, timestamp, parameters) -> OutputEvent:
-        if len(parameters) < 2:
-            raise ParameterException ("The delete event needs at least 2 parameters.")
-        else:
-            source = parameters[0]
-            association_name = parameters[1]
+    #     association = source.associations[association_name]
+    #     # association = self.instances_map[source].getAssociation(association_name)
+    #     if association.allowedToAdd():
+    #         ''' allow subclasses to be instantiated '''
+    #         class_name = association.to_class if len(parameters) == 2 else parameters[2]
+    #         instance = self._create(class_name)
+    #         # new_instance = self.model.classes[class_name](parameters[3:])
+    #         if not instance:
+    #             raise ParameterException("Creating instance: no such class: " + class_name)
+    #         output_events = instance.initialize(timestamp)
+    #         try:
+    #             index = association.addInstance(instance)
+    #         except AssociationException as exception:
+    #             raise RuntimeException("Error adding instance to association '" + association_name + "': " + str(exception))
+    #         p = instance.associations.get("parent")
+    #         if p:
+    #             p.addInstance(source)
+    #         return output_events.append(OutputEvent(Event("instance_created", None, [association_name+"["+str(index)+"]"]), InstancesTarget([source])))
+    #     else:
+    #         return OutputEvent(Event("instance_creation_error", None, [association_name]), InstancesTarget([source]))
+
+    # def _handle_delete(self, timestamp, parameters) -> OutputEvent:
+    #     if len(parameters) < 2:
+    #         raise ParameterException ("The delete event needs at least 2 parameters.")
+    #     else:
+    #         source = parameters[0]
+    #         association_name = parameters[1]
             
-            traversal_list = self._assoc_ref(association_name)
-            instances = self._get_instances(source, traversal_list)
-            # association = self.instances_map[source].getAssociation(traversal_list[0][0])
-            association = source.associations[traversal_list[0][0]]
+    #         traversal_list = self._assoc_ref(association_name)
+    #         instances = self._get_instances(source, traversal_list)
+    #         # association = self.instances_map[source].getAssociation(traversal_list[0][0])
+    #         association = source.associations[traversal_list[0][0]]
             
-            for i in instances:
-                try:
-                    for assoc_name in i["instance"].associations:
-                        if assoc_name != 'parent':
-                            traversal_list = self._assoc_ref(assoc_name)
-                            instances = self._get_instances(i["instance"], traversal_list)
-                            if len(instances) > 0:
-                                raise RuntimeException("Error removing instance from association %s, still %i children left connected with association %s" % (association_name, len(instances), assoc_name))
-                    # del i["instance"].controller.input_ports[i["instance"].narrow_cast_port]
-                    association.removeInstance(i["instance"])
-                    self.instances.remove(i["instance"])
-                except AssociationException as exception:
-                    raise RuntimeException("Error removing instance from association '" + association_name + "': " + str(exception))
-                i["instance"].user_defined_destructor()
-                i["instance"].stop()
+    #         for i in instances:
+    #             try:
+    #                 for assoc_name in i["instance"].associations:
+    #                     if assoc_name != 'parent':
+    #                         traversal_list = self._assoc_ref(assoc_name)
+    #                         instances = self._get_instances(i["instance"], traversal_list)
+    #                         if len(instances) > 0:
+    #                             raise RuntimeException("Error removing instance from association %s, still %i children left connected with association %s" % (association_name, len(instances), assoc_name))
+    #                 # del i["instance"].controller.input_ports[i["instance"].narrow_cast_port]
+    #                 association.removeInstance(i["instance"])
+    #                 self.instances.remove(i["instance"])
+    #             except AssociationException as exception:
+    #                 raise RuntimeException("Error removing instance from association '" + association_name + "': " + str(exception))
+    #             i["instance"].user_defined_destructor()
+    #             i["instance"].stop()
             
-            return OutputEvent(Event("instance_deleted", parameters = [parameters[1]]), InstancesTarget([source]))
+    #         return OutputEvent(Event("instance_deleted", parameters = [parameters[1]]), InstancesTarget([source]))
                 
-    def _handle_associate(self, timestamp, parameters) -> OutputEvent:
-        if len(parameters) != 3:
-            raise ParameterException ("The associate_instance event needs 3 parameters.")
-        else:
-            source = parameters[0]
-            to_copy_list = self._get_instances(source, self._assoc_ref(parameters[1]))
-            if len(to_copy_list) != 1:
-                raise AssociationReferenceException ("Invalid source association reference.")
-            wrapped_to_copy_instance = to_copy_list[0]["instance"]
-            dest_list = self._assoc_ref(parameters[2])
-            if len(dest_list) == 0:
-                raise AssociationReferenceException ("Invalid destination association reference.")
-            last = dest_list.pop()
-            if last[1] != -1:
-                raise AssociationReferenceException ("Last association name in association reference should not be accompanied by an index.")
+    # def _handle_associate(self, timestamp, parameters) -> OutputEvent:
+    #     if len(parameters) != 3:
+    #         raise ParameterException ("The associate_instance event needs 3 parameters.")
+    #     else:
+    #         source = parameters[0]
+    #         to_copy_list = self._get_instances(source, self._assoc_ref(parameters[1]))
+    #         if len(to_copy_list) != 1:
+    #             raise AssociationReferenceException ("Invalid source association reference.")
+    #         wrapped_to_copy_instance = to_copy_list[0]["instance"]
+    #         dest_list = self._assoc_ref(parameters[2])
+    #         if len(dest_list) == 0:
+    #             raise AssociationReferenceException ("Invalid destination association reference.")
+    #         last = dest_list.pop()
+    #         if last[1] != -1:
+    #             raise AssociationReferenceException ("Last association name in association reference should not be accompanied by an index.")
                 
-            added_links = []
-            for i in self._get_instances(source, dest_list):
-                association = i["instance"].associations[last[0]]
-                if association.allowedToAdd():
-                    index = association.addInstance(wrapped_to_copy_instance)
-                    added_links.append(i["path"] + ("" if i["path"] == "" else "/") + last[0] + "[" + str(index) + "]")
+    #         added_links = []
+    #         for i in self._get_instances(source, dest_list):
+    #             association = i["instance"].associations[last[0]]
+    #             if association.allowedToAdd():
+    #                 index = association.addInstance(wrapped_to_copy_instance)
+    #                 added_links.append(i["path"] + ("" if i["path"] == "" else "/") + last[0] + "[" + str(index) + "]")
             
-            return OutputEvent(Event("instance_associated", parameters = [added_links]), InstancesTarget([source]))
+    #         return OutputEvent(Event("instance_associated", parameters = [added_links]), InstancesTarget([source]))
                 
-    def _handle_disassociate(self, timestamp, parameters) -> OutputEvent:
-        if len(parameters) < 2:
-            raise ParameterException ("The disassociate_instance event needs at least 2 parameters.")
-        else:
-            source = parameters[0]
-            association_name = parameters[1]
-            if not isinstance(association_name, list):
-                association_name = [association_name]
-            deleted_links = []
+    # def _handle_disassociate(self, timestamp, parameters) -> OutputEvent:
+    #     if len(parameters) < 2:
+    #         raise ParameterException ("The disassociate_instance event needs at least 2 parameters.")
+    #     else:
+    #         source = parameters[0]
+    #         association_name = parameters[1]
+    #         if not isinstance(association_name, list):
+    #             association_name = [association_name]
+    #         deleted_links = []
             
-            for a_n in association_name:
-                traversal_list = self._assoc_ref(a_n)
-                instances = self._get_instances(source, traversal_list)
+    #         for a_n in association_name:
+    #             traversal_list = self._assoc_ref(a_n)
+    #             instances = self._get_instances(source, traversal_list)
                 
-                for i in instances:
-                    try:
-                        index = i['ref'].associations[i['assoc_name']].removeInstance(i["instance"])
-                        deleted_links.append(a_n +  "[" + str(index) + "]")
-                    except AssociationException as exception:
-                        raise RuntimeException("Error disassociating '" + a_n + "': " + str(exception))
+    #             for i in instances:
+    #                 try:
+    #                     index = i['ref'].associations[i['assoc_name']].removeInstance(i["instance"])
+    #                     deleted_links.append(a_n +  "[" + str(index) + "]")
+    #                 except AssociationException as exception:
+    #                     raise RuntimeException("Error disassociating '" + a_n + "': " + str(exception))
             
-            return OutputEvent(Event("instance_disassociated", parameters = [deleted_links]), InstancesTarget([source]))
+    #         return OutputEvent(Event("instance_disassociated", parameters = [deleted_links]), InstancesTarget([source]))
         
-    def _handle_narrowcast(self, timestamp, parameters) -> OutputEvent:
-        if len(parameters) != 3:
-            raise ParameterException ("The narrow_cast event needs 3 parameters.")
-        source, targets, cast_event = parameters
+    # def _handle_narrowcast(self, timestamp, parameters) -> OutputEvent:
+    #     if len(parameters) != 3:
+    #         raise ParameterException ("The narrow_cast event needs 3 parameters.")
+    #     source, targets, cast_event = parameters
         
-        if not isinstance(targets, list):
-            targets = [targets]
-
-        all_instances = []
-        for target in targets:
-            traversal_list = self._assoc_ref(target)
-            instances = self._get_instances(source, traversal_list)
-            all_instances.extend(instances)
-        return OutputEvent(cast_event, instances)
+    #     if not isinstance(targets, list):
+    #         targets = [targets]
+
+    #     all_instances = []
+    #     for target in targets:
+    #         traversal_list = self._assoc_ref(target)
+    #         instances = self._get_instances(source, traversal_list)
+    #         all_instances.extend(instances)
+    #     return OutputEvent(cast_event, instances)
         
-    def _get_instances(self, source, traversal_list):
-        print("_get_instances(source=",source,"traversal_list=",traversal_list)
-        currents = [{
-            "instance": source,
-            "ref": None,
-            "assoc_name": None,
-            "assoc_index": None,
-            "path": ""
-        }]
-        # currents = [source]
-        for (name, index) in traversal_list:
-            nexts = []
-            for current in currents:
-                association = current["instance"].associations[name]
-                if (index >= 0 ):
-                    try:
-                        nexts.append({
-                            "instance": association.instances[index],
-                            "ref": current["instance"],
-                            "assoc_name": name,
-                            "assoc_index": index,
-                            "path": current["path"] + ("" if current["path"] == "" else "/") + name + "[" + str(index) + "]"
-                        })
-                    except KeyError:
-                        # Entry was removed, so ignore this request
-                        pass
-                elif (index == -1):
-                    for i in association.instances:
-                        nexts.append({
-                            "instance": association.instances[i],
-                            "ref": current["instance"],
-                            "assoc_name": name,
-                            "assoc_index": index,
-                            "path": current["path"] + ("" if current["path"] == "" else "/") + name + "[" + str(index) + "]"
-                        })
-                    #nexts.extend( association.instances.values() )
-                else:
-                    raise AssociationReferenceException("Incorrect index in association reference.")
-            currents = nexts
-        return currents
-
-    _handlers = {"narrow_cast": _handle_narrowcast,
-                         "broad_cast": _handle_broadcast,
-                         "create_instance": _handle_create,
-                         "associate_instance": _handle_associate,
-                         "disassociate_instance": _handle_disassociate,
-                         "delete_instance": _handle_delete}
+    # def _get_instances(self, source, traversal_list):
+    #     print("_get_instances(source=",source,"traversal_list=",traversal_list)
+    #     currents = [{
+    #         "instance": source,
+    #         "ref": None,
+    #         "assoc_name": None,
+    #         "assoc_index": None,
+    #         "path": ""
+    #     }]
+    #     # currents = [source]
+    #     for (name, index) in traversal_list:
+    #         nexts = []
+    #         for current in currents:
+    #             association = current["instance"].associations[name]
+    #             if (index >= 0 ):
+    #                 try:
+    #                     nexts.append({
+    #                         "instance": association.instances[index],
+    #                         "ref": current["instance"],
+    #                         "assoc_name": name,
+    #                         "assoc_index": index,
+    #                         "path": current["path"] + ("" if current["path"] == "" else "/") + name + "[" + str(index) + "]"
+    #                     })
+    #                 except KeyError:
+    #                     # Entry was removed, so ignore this request
+    #                     pass
+    #             elif (index == -1):
+    #                 for i in association.instances:
+    #                     nexts.append({
+    #                         "instance": association.instances[i],
+    #                         "ref": current["instance"],
+    #                         "assoc_name": name,
+    #                         "assoc_index": index,
+    #                         "path": current["path"] + ("" if current["path"] == "" else "/") + name + "[" + str(index) + "]"
+    #                     })
+    #                 #nexts.extend( association.instances.values() )
+    #             else:
+    #                 raise AssociationReferenceException("Incorrect index in association reference.")
+    #         currents = nexts
+    #     return currents
+
+    _handlers = {
+        # "narrow_cast": _handle_narrowcast,
+        "broad_cast": _handle_broadcast,
+        # "create_instance": _handle_create,
+        # "associate_instance": _handle_associate,
+        # "disassociate_instance": _handle_disassociate,
+        # "delete_instance": _handle_delete
+    }

+ 10 - 32
src/sccd/runtime/statecharts_core.py

@@ -5,6 +5,7 @@ The classes and functions needed to run (compiled) SCCD models.
 import os
 import termcolor
 from typing import List, Tuple
+from enum import Enum
 from sccd.runtime.infinity import INFINITY
 from sccd.runtime.event_queue import Timestamp
 from sccd.runtime.event import Event, OutputEvent, Instance, InstancesTarget
@@ -13,32 +14,6 @@ from collections import Counter
 
 ELSE_GUARD = "ELSE_GUARD"
 
-class RuntimeException(Exception):
-    """
-    Base class for runtime exceptions.
-    """
-    def __init__(self, message):
-        self.message = message
-    def __str__(self):
-        return repr(self.message)
-
-class AssociationException(RuntimeException):
-    """
-    Exception class thrown when an error occurs in a CRUD operation on associations.
-    """
-    pass
-
-class AssociationReferenceException(RuntimeException):
-    """
-    Exception class thrown when an error occurs when resolving an association reference.
-    """
-    pass
-
-class ParameterException(RuntimeException):
-    """
-    Exception class thrown when an error occurs when passing parameters.
-    """
-    pass
 
 class Association(object):
     """
@@ -537,7 +512,6 @@ class SmallStepState(object):
     def __init__(self):
         self.current_events = [] # set of enabled events during small step
         self.next_events = [] # events to become 'current' in the next small step
-        self.candidates = [] # document-ordered(!) list of transitions that can potentially be executed concurrently, or preempt each other, depending on concurrency semantics. If no concurrency is used and there are multiple candidates, the first one is chosen. Source states of candidates are *always* orthogonal to each other.
         self.has_stepped = True
 
     def reset(self):
@@ -547,15 +521,19 @@ class SmallStepState(object):
     def next(self):
         self.current_events = self.next_events # raised events from previous small step
         self.next_events = []
-        self.candidates = []
         self.has_stepped = False
 
     def addNextEvent(self, event):
         self.next_events.append(event)
 
-    def addCandidate(self, t, p):
-        self.candidates.append((t, p))
+class Maximality(Enum):
+    TAKE_ONE = 0
+    TAKE_MANY = 2
 
-    def hasCandidates(self):
-        return len(self.candidates) > 0
+class Round:
+    def __init__(self, maximality: Maximality):
+        self.current_events: List[Event] = []
+        self.next_events: List[Event] = []
+        self.has_stepped: bool = True
+        self.maximality: Maximality
 

+ 5 - 5
test/README.md

@@ -4,18 +4,18 @@ The Python program `test.py` replaces the old `run_tests.py`. It takes test inpu
 
 For example, to run the "semantics" tests:
 ```
-python3 test.py semantics
+python3 test.py test_files/semantics
 ```
 This will create a 'build' directory with compiled statechart models. It is always safe to remove this directory, it merely serves as a 'cache' for build artifacts.
 
 ## render.py
 
-The Python program `render.py` renders SVG graphs for test files. Rendered SVG files are checked in to this repository. If you wish to re-render them, you need the NPM (NodeJS) package 'state-machine-cat'. Install NodeJS and NPM, and then install the NPM package 'state-machine-cat':
+The Python program `render.py` renders SVG graphs for test files. Rendered SVG files are already checked in to this repository. If you wish to re-render them, you need the NPM (NodeJS) package [state-machine-cat](https://github.com/sverweij/state-machine-cat/). Install NodeJS and NPM, and then install the NPM package 'state-machine-cat':
 ```
 npm i -g state-machine-cat
 ```
-You can now render all the tests in the 'semantics' dir:
+Now, e.g. render the "semantics" tests:
 ```
-python3 render.py semantics
+python3 render.py test_files/semantics
 ```
-By default, the SVG files are stored next to the test XML files.
+By default, the SVG files are stored next to the test XML files.

+ 0 - 0
test/lib/__init__.py


+ 1 - 1
test/lib/builder.py

@@ -23,7 +23,7 @@ class Builder:
 
     # Get src_file and target_file modification times
     src_file_mtime = os.path.getmtime(src_file)
-    target_file_mtime = 0
+    target_file_mtime = 0.0
     try:
         target_file_mtime = os.path.getmtime(target_file)
     except FileNotFoundError:

+ 2 - 2
test/lib/os_tools.py

@@ -1,9 +1,9 @@
 import os
-from typing import List, Callable
+from typing import List, Callable, Set
 
 # For a given list of files and or directories, get all the 
 def get_files(paths: List[str], filter: Callable[[str], bool]) -> List[str]:
-  already_have = set()
+  already_have: Set[str] = set()
   src_files = []
 
   def add_file(path):

+ 39 - 18
test/render.py

@@ -1,6 +1,7 @@
 import argparse
 import sys
 import subprocess
+import multiprocessing
 from lib.os_tools import *
 from lib.builder import *
 from sccd.compiler.utils import FormattedWriter
@@ -10,31 +11,43 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser(
         description="Render statecharts as SVG images.")
     parser.add_argument('path', metavar='PATH', type=str, nargs='*', help="Models to render. Can be a XML file or a directory. If a directory, it will be recursively scanned for XML files.")
-    parser.add_argument('--build-dir', metavar='DIR', type=str, default='build', help="Directory for built tests. Defaults to 'build'")
-    parser.add_argument('--render-dir', metavar='DIR', type=str, default='.', help="Directory for SVG rendered output. Defaults to '.' (putting the SVG files with the XML source files)")
+    parser.add_argument('--build-dir', metavar='DIR', type=str, default='build', help="As a first step, input XML files first must be compiled to python files. Directory to store these files. Defaults to 'build'")
+    parser.add_argument('--output-dir', metavar='DIR', type=str, default='', help="Directory for SVG rendered output. Defaults to '.' (putting the SVG files with the XML source files)")
+    parser.add_argument('--keep-smcat', action='store_true', help="Whether to NOT delete intermediary SMCAT files after producing SVG output. Default = off (delete files)")
+    parser.add_argument('--no-svg', action='store_true', help="Don't produce SVG output. This option only makes sense in combination with the --keep-smcat option. Default = off")
+    parser.add_argument('--pool-size', metavar='INT', type=int, default=multiprocessing.cpu_count()+1, help="Number of worker processes. Default = CPU count + 1.")
     args = parser.parse_args()
 
-    try:
-      subprocess.run(["state-machine-cat", "-h"], capture_output=True)
-    except:
-        print("Failed to run 'state-machine-cat'. Make sure this application is installed on your system.")
-        exit()
 
-    builder = Builder(args.build_dir)
-    render_builder = Builder(args.render_dir)
+    py_builder = Builder(args.build_dir)
+    svg_builder = Builder(args.output_dir)
     srcs = get_files(args.path, filter=xml_filter)
 
-    for src in srcs:
-      module = builder.build_and_load(src)
+    if len(srcs):
+      if not args.no_svg:
+        try:
+          subprocess.run(["state-machine-cat", "-h"], capture_output=True)
+        except:
+            print("Failed to run 'state-machine-cat'. Make sure this application is installed on your system.")
+            exit()
+    else:
+      print("No input files specified.")      
+      print()
+      parser.print_usage()
+
+
+    def process(src):
+      module = py_builder.build_and_load(src)
       model = module.Model()
 
+      # Produce an output file for each class in the src file
       for class_name, _class in model.classes.items():
-        target = render_builder.target_file(src, '_'+class_name+'.smcat')
-        svg_target = render_builder.target_file(src, '_'+class_name+'.svg')
+        smcat_target = svg_builder.target_file(src, '_'+class_name+'.smcat')
+        svg_target = svg_builder.target_file(src, '_'+class_name+'.svg')
         
-        make_dirs(target)
+        make_dirs(smcat_target)
 
-        f = open(target, 'w')
+        f = open(smcat_target, 'w')
         w = FormattedWriter(f)
         sc = _class().statechart
 
@@ -108,6 +121,14 @@ if __name__ == '__main__':
             ctr += 1
 
         f.close()
-        subprocess.run(["state-machine-cat", target, "-o", svg_target])
-        os.remove(target)
-        print("Rendered "+svg_target)
+        if args.keep_smcat:
+          print("Wrote "+smcat_target)
+        if not args.no_svg:
+          subprocess.run(["state-machine-cat", smcat_target, "-o", svg_target])
+          print("Wrote "+svg_target)
+        if not args.keep_smcat:
+          os.remove(smcat_target)
+
+    with multiprocessing.Pool(processes=args.pool_size) as pool:
+      print("Created a pool of %d processes."%args.pool_size)
+      pool.map(process, srcs)

+ 3 - 3
test/test.py

@@ -104,8 +104,8 @@ if __name__ == '__main__':
         suite.addTest(PyTestCase(src_file, builder))
 
     if len(src_files) == 0:
-        print("Note: no test files specified.")
+        print("No input files specified.")
         print()
         parser.print_usage()
-
-    unittest.TextTestRunner(verbosity=2).run(suite)
+    else:
+        unittest.TextTestRunner(verbosity=2).run(suite)