浏览代码

can call custom functions from condition code while pattern matching / rewriting + added example to runner_translate.py

Joeri Exelmans 11 月之前
父节点
当前提交
9e74075066

+ 18 - 8
examples/semantics/translational/runner_translate.py

@@ -5,18 +5,15 @@ from concrete_syntax.plantuml.renderer import render_object_diagram, render_clas
 from concrete_syntax.plantuml.make_url import make_url
 from api.od import ODAPI
 
-from transformation.ramify import ramify
-from transformation.topify.topify import Topifier
-from transformation.merger import merge_models
 from transformation.ramify import ramify
 from transformation.rule import RuleMatcherRewriter
 
 from util import loader
+from util.module_to_dict import module_to_dict
 
-from examples.semantics.operational.simulator import Simulator, RandomDecisionMaker, InteractiveDecisionMaker
-from examples.semantics.operational.port import models
-from examples.semantics.operational.port.helpers import design_to_state, state_to_design, get_time
+from examples.semantics.operational.port import models, helpers
 from examples.semantics.operational.port.renderer import render_port_textual, render_port_graphviz
+from examples.semantics.translational.renderer import show_port_and_petri_net
 from examples.petrinet.renderer import render_petri_net
 
 import os
@@ -76,7 +73,14 @@ if __name__ == "__main__":
     print('ready!')
 
     port_m_rt = port_m_rt_initial
-    matcher_rewriter = RuleMatcherRewriter(state, merged_mm, ramified_merged_mm)
+    eval_context = {
+        # make all the functions defined in 'helpers' module available to 'condition'-code in LHS/NAC/RHS:
+        **module_to_dict(helpers),
+        # another example: in all 'condition'-code, there will be a global variable 'meaning_of_life', equal to 42:
+        'meaning_of_life': 42, # just to demonstrate - feel free to remove this
+    }
+    print('The following additional globals are available:', ', '.join(list(eval_context.keys())))
+    matcher_rewriter = RuleMatcherRewriter(state, merged_mm, ramified_merged_mm, eval_context=eval_context)
 
     ###################################
     # Because the matching of many different rules can be slow,
@@ -104,7 +108,7 @@ if __name__ == "__main__":
         try:
             with open(filename, "r") as file:
                 port_m_rt = parser.parse_od(state, file.read(), merged_mm)
-                print('loaded', filename)
+            print(f'skip rule (found {filename})')
         except FileNotFoundError:
             # Fire every rule until it cannot match any longer:
             while True:
@@ -123,6 +127,12 @@ if __name__ == "__main__":
                 print('wrote', filename)
                 render_petri_net(ODAPI(state, port_m_rt, merged_mm))
 
+                # Uncomment to show also the port model:
+                # show_port_and_petri_net(state, port_m_rt, merged_mm)
+
+                # Uncomment to pause after each rendering:
+                # input()
+
     ###################################
     # Once you have generated a Petri Net, you can execute the petri net:
     #

+ 31 - 11
transformation/matcher.py

@@ -168,7 +168,14 @@ def _cannot_call_matched(_):
 
 # This function returns a Generator of matches.
 # The idea is that the user can iterate over the match set, lazily generating it: if only interested in the first match, the entire match set doesn't have to be generated.
-def match_od(state, host_m, host_mm, pattern_m, pattern_mm, pivot={}):
+def match_od(state,
+    host_m, # the host graph, in which to search for matches
+    host_mm, # meta-model of the host graph
+    pattern_m, # the pattern to look for
+    pattern_mm, # the meta-model of the pattern (typically the RAMified version of host_mm)
+    pivot={}, # optional: a partial match (restricts possible matches, and speeds up the match process)
+    eval_context={}, # optional: additional variables, functions, ... to be available while evaluating condition-code in the pattern. Will be available as global variables in the condition-code.
+):
     bottom = Bottom(state)
 
     # compute subtype relations and such:
@@ -177,6 +184,21 @@ def match_od(state, host_m, host_mm, pattern_m, pattern_mm, pivot={}):
     pattern_odapi = ODAPI(state, pattern_m, pattern_mm)
     pattern_mm_odapi = ODAPI(state, pattern_mm, cdapi.mm)
 
+    # 'globals'-dict used when eval'ing conditions
+    bound_api = bind_api_readonly(odapi)
+    builtin = {
+        **bound_api,
+        'matched': _cannot_call_matched,
+        'odapi': odapi,
+    }
+    for key in eval_context:
+        if key in builtin:
+            print(f"WARNING: custom global '{key}' overrides pre-defined API function. Consider renaming it.")
+    eval_globals = {
+        **builtin,
+        **eval_context,
+    }
+
     # Function object for pattern matching. Decides whether to match host and guest vertices, where guest is a RAMified instance (e.g., the attributes are all strings with Python expressions), and the host is an instance (=object diagram) of the original model (=class diagram)
     class RAMCompare:
         def __init__(self, bottom, host_od):
@@ -234,10 +256,7 @@ def match_od(state, host_m, host_mm, pattern_m, pattern_mm, pivot={}):
                     #   - incompatible slots may be matched (it is only when their AttributeLinks are matched, that we know the types will be compatible)
                     with Timer(f'EVAL condition {g_vtx.name}'):
                         ok = exec_then_eval(python_code,
-                            _globals={
-                                **bind_api_readonly(odapi),
-                                'matched': _cannot_call_matched,
-                            },
+                            _globals=eval_globals,
                             _locals={'this': h_vtx.node_id})
                     self.conditions_to_check.pop(g_vtx.name, None)
                     return ok
@@ -324,13 +343,14 @@ def match_od(state, host_m, host_mm, pattern_m, pattern_mm, pivot={}):
 
 
     def check_conditions(name_mapping):
+        eval_globals = {
+            **bound_api,
+            # this time, the real 'matched'-function can be used:
+            'matched': lambda name: bottom.read_outgoing_elements(host_m, name_mapping[name])[0],
+            **eval_context,
+        }
         def check(python_code: str, loc):
-            return exec_then_eval(python_code,
-                _globals={
-                    **bind_api_readonly(odapi),
-                    'matched': lambda name: bottom.read_outgoing_elements(host_m, name_mapping[name])[0],
-                },
-                _locals=loc)
+            return exec_then_eval(python_code, _globals=eval_globals, _locals=loc)
 
         # Attribute conditions
         for pattern_name, host_name in name_mapping.items():

+ 27 - 22
transformation/rewriter.py

@@ -18,7 +18,15 @@ class TryAgainNextRound(Exception):
     pass
 
 # Rewrite is performed in-place (modifying `host_m`)
-def rewrite(state, lhs_m: UUID, rhs_m: UUID, pattern_mm: UUID, lhs_match: dict, host_m: UUID, host_mm: UUID):
+def rewrite(state,
+    lhs_m: UUID, # LHS-pattern
+    rhs_m: UUID, # RHS-pattern
+    pattern_mm: UUID, # meta-model of both patterns (typically the RAMified host_mm)
+    lhs_match: dict, # a match, morphism, from lhs_m to host_m (mapping pattern name -> host name), typically found by the 'match_od'-function.
+    host_m: UUID, # host model
+    host_mm: UUID, # host meta-model
+    eval_context={}, # optional: additional variables/functions to be available while executing condition-code. These will be seen as global variables.
+):
     bottom = Bottom(state)
 
     # Need to come up with a new, unique name when creating new element in host-model:
@@ -74,6 +82,19 @@ def rewrite(state, lhs_m: UUID, rhs_m: UUID, pattern_mm: UUID, lhs_match: dict,
     # to be grown
     rhs_match = { name : lhs_match[name] for name in common }
 
+    builtin = {
+        **bind_api(host_odapi),
+        'matched': matched_callback,
+        'odapi': host_odapi,
+    }
+    for key in eval_context:
+        if key in builtin:
+            print(f"WARNING: custom global '{key}' overrides pre-defined API function. Consider renaming it.")
+    eval_globals = {
+        **builtin,
+        **eval_context,
+    }
+
     # 1. Perform creations - in the right order!
     remaining_to_create = list(to_create)
     while len(remaining_to_create) > 0:
@@ -86,11 +107,7 @@ def rewrite(state, lhs_m: UUID, rhs_m: UUID, pattern_mm: UUID, lhs_match: dict,
                 name_expr = rhs_odapi.get_slot_value(rhs_obj, "name")
             except:
                 name_expr = f'"{rhs_name}"' # <- if the 'name' slot doesnt exist, use the pattern element name
-            suggested_name = exec_then_eval(name_expr,
-                _globals={
-                    **bind_api(host_odapi),
-                    'matched': matched_callback,
-                })
+            suggested_name = exec_then_eval(name_expr, _globals=eval_globals)
             rhs_type = rhs_odapi.get_type(rhs_obj)
             host_type = ramify.get_original_type(bottom, rhs_type)
             # for debugging:
@@ -157,10 +174,7 @@ def rewrite(state, lhs_m: UUID, rhs_m: UUID, pattern_mm: UUID, lhs_match: dict,
                     host_attr_name = host_mm_odapi.get_slot_value(host_attr_link, "name")
                     val_name = f"{host_src_name}.{host_attr_name}"
                     python_expr = ActionCode(UUID(bottom.read_value(rhs_obj)), bottom.state).read()
-                    result = exec_then_eval(python_expr, _globals={
-                        **bind_api(host_odapi),
-                        'matched': matched_callback,
-                    })
+                    result = exec_then_eval(python_expr, _globals=eval_globals)
                     host_odapi.create_primitive_value(val_name, result, is_code=False)
                     rhs_match[rhs_name] = val_name
                 else:
@@ -192,10 +206,7 @@ def rewrite(state, lhs_m: UUID, rhs_m: UUID, pattern_mm: UUID, lhs_match: dict,
             rhs_obj = rhs_odapi.get(common_name)
             python_expr = ActionCode(UUID(bottom.read_value(rhs_obj)), bottom.state).read()
             result = exec_then_eval(python_expr,
-                _globals={
-                    **bind_api(host_odapi),
-                    'matched': matched_callback,
-                },
+                _globals=eval_globals,
                 _locals={'this': host_obj}) # 'this' can be used to read the previous value of the slot
             host_odapi.overwrite_primitive_value(host_obj_name, result, is_code=False)
         else:
@@ -235,18 +246,12 @@ def rewrite(state, lhs_m: UUID, rhs_m: UUID, pattern_mm: UUID, lhs_match: dict,
             # rhs_obj is an object or link (because association is subtype of class)
             python_code = rhs_odapi.get_slot_value_default(rhs_obj, "condition", default="")
             simply_exec(python_code,
-                _globals={
-                    **bind_api(host_odapi),
-                    'matched': matched_callback,
-                },
+                _globals=eval_globals,
                 _locals={'this': host_obj})
 
     # 5. Execute global actions
     for cond_name, cond in rhs_odapi.get_all_instances("GlobalCondition"):
         python_code = rhs_odapi.get_slot_value(cond, "condition")
-        simply_exec(python_code, _globals={
-            **bind_api(host_odapi),
-            'matched': matched_callback,
-        })
+        simply_exec(python_code, _globals=eval_globals)
 
     return rhs_match

+ 11 - 4
transformation/rule.py

@@ -26,10 +26,11 @@ class _NAC_MATCHED(Exception):
 
 # Helper for executing NAC/LHS/RHS-type rules
 class RuleMatcherRewriter:
-    def __init__(self, state, mm: UUID, mm_ramified: UUID):
+    def __init__(self, state, mm: UUID, mm_ramified: UUID, eval_context={}):
         self.state = state
         self.mm = mm
         self.mm_ramified = mm_ramified
+        self.eval_context = eval_context
 
     # Generates matches.
     # Every match is a dictionary with entries LHS_element_name -> model_element_name
@@ -38,7 +39,9 @@ class RuleMatcherRewriter:
             host_m=m,
             host_mm=self.mm,
             pattern_m=lhs,
-            pattern_mm=self.mm_ramified)
+            pattern_mm=self.mm_ramified,
+            eval_context=self.eval_context,
+        )
 
         try:
             # First we iterate over LHS-matches:
@@ -64,7 +67,9 @@ class RuleMatcherRewriter:
                                     host_mm=self.mm,
                                     pattern_m=nac,
                                     pattern_mm=self.mm_ramified,
-                                    pivot=lhs_match) # try to "grow" LHS-match with NAC-match
+                                    pivot=lhs_match, # try to "grow" LHS-match with NAC-match
+                                    eval_context=self.eval_context,
+                                )
 
                                 try:
                                     # for nac_match in nac_matcher:
@@ -117,7 +122,9 @@ class RuleMatcherRewriter:
                 pattern_mm=self.mm_ramified,
                 lhs_match=lhs_match,
                 host_m=cloned_m,
-                host_mm=self.mm)
+                host_mm=self.mm,
+                eval_context=self.eval_context,
+            )
         except Exception as e:
             # Make exceptions raised in eval'ed code easier to trace:
             e.add_note(f"while executing RHS of '{rule_name}'")

+ 8 - 0
util/module_to_dict.py

@@ -0,0 +1,8 @@
+# Based on: https://stackoverflow.com/a/46263657
+def module_to_dict(module):
+    context = {}
+    for name in dir(module):
+        # this will filter out 'private' functions, as well as __builtins__, __name__, __package__, etc.:
+        if not name.startswith('_'):
+            context[name] = getattr(module, name)
+    return context