Prechádzať zdrojové kódy

Re-write the block-merging optimization

jonathanvdc 8 rokov pred
rodič
commit
2a077896c6
1 zmenil súbory, kde vykonal 18 pridanie a 9 odobranie
  1. 18 9
      kernel/modelverse_jit/cfg_optimization.py

+ 18 - 9
kernel/modelverse_jit/cfg_optimization.py

@@ -52,10 +52,12 @@ def merge_blocks(entry_point):
     """Merges blocks which have exactly one predecessor with said predecessor, if the
        predecessor has a jump flow instruction."""
     predecessor_map = cfg_ir.get_all_predecessor_blocks(entry_point)
-    queue = list(predecessor_map.keys())
+    queue = set(predecessor_map.keys())
+    queue.add(entry_point)
     def __do_merge(source, target):
         target_params = list(target.parameters)
-        for target_param, branch_arg in zip(target_params, source.flow.branch.arguments):
+        branch_args = list(source.flow.branch.arguments)
+        for target_param, branch_arg in zip(target_params, branch_args):
             target.remove_parameter(target_param)
             target_param.redefine(branch_arg)
             source.append_definition(target_param)
@@ -73,11 +75,18 @@ def merge_blocks(entry_point):
 
     while len(queue) > 0:
         block = queue.pop()
-        preds = predecessor_map[block]
-        if len(preds) == 1 and block != entry_point:
-            single_pred = next(iter(preds))
-            if single_pred != block and isinstance(single_pred.flow, cfg_ir.JumpFlow):
-                __do_merge(single_pred, block)
+        if isinstance(block.flow, cfg_ir.JumpFlow):
+            next_block = block.flow.branch.block
+            preds = predecessor_map[next_block]
+            if (len(preds) == 1
+                    and next(iter(preds)) == block
+                    and block != next_block
+                    and next_block != entry_point):
+                __do_merge(block, next_block)
+                del predecessor_map[next_block]
+                queue.add(block)
+                if next_block in queue:
+                    queue.remove(next_block)
 
 def elide_local_checks(entry_point):
     """Tries to elide redundant checks on local variables."""
@@ -341,13 +350,13 @@ def optimize(entry_point, jit):
     elide_local_checks(entry_point)
     optimize_graph_flow(entry_point)
     eliminate_trivial_phis(entry_point)
-    # entry_point = cfg_ssa_construction.construct_ssa_form(entry_point)
+    entry_point = cfg_ssa_construction.construct_ssa_form(entry_point)
     optimize_calls(entry_point, jit)
     yield [("CALL_ARGS", [inline_constants, (entry_point,)])]
     simplify_values(entry_point)
     eliminate_unused_definitions(entry_point)
     optimize_graph_flow(entry_point)
-    merge_blocks(entry_point)
     expand_indirect_definitions(entry_point)
     eliminate_unused_definitions(entry_point)
+    merge_blocks(entry_point)
     raise primitive_functions.PrimitiveFinished(entry_point)