|
@@ -0,0 +1,88 @@
|
|
|
+"""Defines optimizations that replace Modelverse data structures by Python data structures."""
|
|
|
+
|
|
|
+import modelverse_jit.cfg_ir as cfg_ir
|
|
|
+
|
|
|
+SET_DEF_REWRITE_RULES = {
|
|
|
+ 'dict_keys':
|
|
|
+ lambda def_def:
|
|
|
+ def_def.redefine(
|
|
|
+ cfg_ir.DirectFunctionCall(
|
|
|
+ cfg_ir.REVERSE_LIST_MACRO_NAME,
|
|
|
+ [('seq',
|
|
|
+ def_def.insert_before(
|
|
|
+ cfg_ir.DirectFunctionCall(
|
|
|
+ cfg_ir.READ_DICT_KEYS_MACRO_NAME,
|
|
|
+ cfg_ir.get_def_value(def_def).argument_list,
|
|
|
+ cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION,
|
|
|
+ has_side_effects=False)))],
|
|
|
+ cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION,
|
|
|
+ has_side_effects=False))
|
|
|
+}
|
|
|
+
|
|
|
+def __redefine_as_list_len(use_def, def_def):
|
|
|
+ use_def.redefine(
|
|
|
+ cfg_ir.CreateNode(
|
|
|
+ use_def.insert_before(
|
|
|
+ cfg_ir.DirectFunctionCall(
|
|
|
+ 'len', [('seq', def_def)],
|
|
|
+ cfg_ir.SIMPLE_POSITIONAL_CALLING_CONVENTION,
|
|
|
+ has_side_effects=False))))
|
|
|
+
|
|
|
+SET_USE_REWRITE_RULES = {
|
|
|
+ ('set_pop', 'a'):
|
|
|
+ lambda use_def, def_def:
|
|
|
+ use_def.redefine(
|
|
|
+ cfg_ir.DirectFunctionCall(
|
|
|
+ 'pop', [('self', def_def)],
|
|
|
+ cfg_ir.SELF_POSITIONAL_CALLING_CONVENTION)),
|
|
|
+ ('list_len', 'a'): __redefine_as_list_len,
|
|
|
+ ('read_nr_out', 'a'): __redefine_as_list_len
|
|
|
+}
|
|
|
+
|
|
|
+def get_call_def_rewriter(definition, def_rewrite_rules):
|
|
|
+ """Gets an appropriate rewrite rule from the given dictionary of call rewrite rules."""
|
|
|
+ if cfg_ir.is_call(definition, calling_convention=cfg_ir.JIT_CALLING_CONVENTION):
|
|
|
+ call = cfg_ir.get_def_value(definition)
|
|
|
+ if call.target_name in def_rewrite_rules:
|
|
|
+ return def_rewrite_rules[call.target_name]
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+def get_call_use_rewriter(use_definition, def_definition, use_rewrite_rules):
|
|
|
+ """Gets an appropriate rewrite rule from the given dictionary of call rewrite rules."""
|
|
|
+ if cfg_ir.is_call(use_definition, calling_convention=cfg_ir.JIT_CALLING_CONVENTION):
|
|
|
+ call = cfg_ir.get_def_value(use_definition)
|
|
|
+ for arg_name, arg_def in call.argument_list:
|
|
|
+ if arg_def == def_definition:
|
|
|
+ key = (call.target_name, arg_name)
|
|
|
+ if key in use_rewrite_rules:
|
|
|
+ return use_rewrite_rules[key]
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+def apply_rewrite_rules(
|
|
|
+ entry_point,
|
|
|
+ get_def_rewriter,
|
|
|
+ get_use_rewriter):
|
|
|
+ """Applies the given definition and use rewrite rules to all definitions and uses where a
|
|
|
+ rewrite rule can be found for both the definitions and the uses."""
|
|
|
+ # pylint: disable=I0011,W0108
|
|
|
+ cfg_ir.match_and_rewrite(
|
|
|
+ entry_point,
|
|
|
+ lambda def_def:
|
|
|
+ get_def_rewriter(def_def) is not None,
|
|
|
+ lambda use_def, def_def:
|
|
|
+ get_use_rewriter(use_def, def_def) is not None,
|
|
|
+ lambda def_def:
|
|
|
+ get_def_rewriter(def_def)(def_def),
|
|
|
+ lambda use_def, def_def:
|
|
|
+ get_use_rewriter(use_def, def_def)(use_def, def_def))
|
|
|
+
|
|
|
+def optimize_data_structures(entry_point):
|
|
|
+ """Optimizes data structures in the graph defined by the given entry point."""
|
|
|
+ apply_rewrite_rules(
|
|
|
+ entry_point,
|
|
|
+ lambda def_def:
|
|
|
+ get_call_def_rewriter(def_def, SET_DEF_REWRITE_RULES),
|
|
|
+ lambda use_def, def_def:
|
|
|
+ get_call_use_rewriter(use_def, def_def, SET_USE_REWRITE_RULES))
|