Ver código fonte

Fix a number of CFG JIT bugs

jonathanvdc 8 anos atrás
pai
commit
78848e484f

+ 13 - 6
kernel/modelverse_jit/cfg_optimization.py

@@ -4,6 +4,7 @@ from collections import defaultdict
 import modelverse_jit.cfg_ir as cfg_ir
 import modelverse_jit.cfg_dominators as cfg_dominators
 import modelverse_jit.cfg_ssa_construction as cfg_ssa_construction
+import modelverse_kernel.primitives as primitive_functions
 
 def is_empty_block(block):
     """Tests if the given block contains no parameters or definitions."""
@@ -53,11 +54,15 @@ def merge_blocks(entry_point):
     predecessor_map = cfg_ir.get_all_predecessor_blocks(entry_point)
     queue = list(predecessor_map.keys())
     def __do_merge(source, target):
-        for target_param, branch_arg in zip(target.parameters, source.flow.branch.arguments):
-            source.append_definition(target_param)
+        target_params = list(target.parameters)
+        for target_param, branch_arg in zip(target_params, source.flow.branch.arguments):
+            target.remove_parameter(target_param)
             target_param.redefine(branch_arg)
+            source.append_definition(target_param)
 
-        for target_def in target.definitions:
+        target_defs = list(target.definitions)
+        for target_def in target_defs:
+            target.remove_definition(target_def)
             source.append_definition(target_def)
 
         source.flow = target.flow
@@ -69,9 +74,9 @@ def merge_blocks(entry_point):
     while len(queue) > 0:
         block = queue.pop()
         preds = predecessor_map[block]
-        if len(preds) == 1:
+        if len(preds) == 1 and block != entry_point:
             single_pred = next(iter(preds))
-            if isinstance(single_pred.flow, cfg_ir.JumpFlow):
+            if single_pred != block and isinstance(single_pred.flow, cfg_ir.JumpFlow):
                 __do_merge(single_pred, block)
 
 def elide_local_checks(entry_point):
@@ -330,7 +335,8 @@ def expand_indirect_definitions(entry_point):
         block.flow = __expand_indirect_defs(block.flow)
 
 def optimize(entry_point, jit):
-    """Optimizes the control-flow graph defined by the given entry point."""
+    """Optimizes the control-flow graph defined by the given entry point.
+       A potentially altered entry point is returned."""
     optimize_graph_flow(entry_point)
     elide_local_checks(entry_point)
     optimize_graph_flow(entry_point)
@@ -344,3 +350,4 @@ def optimize(entry_point, jit):
     merge_blocks(entry_point)
     expand_indirect_definitions(entry_point)
     eliminate_unused_definitions(entry_point)
+    raise primitive_functions.PrimitiveFinished(entry_point)

+ 18 - 5
kernel/modelverse_jit/cfg_to_tree.py

@@ -785,10 +785,21 @@ class LoweringState(object):
 
     def lower_select(self, flow):
         """Lowers the given 'select' flow instruction to a tree."""
-        return tree_ir.SelectInstruction(
-            self.use_definition(flow.condition),
-            self.lower_branch(flow.if_branch),
-            self.lower_branch(flow.else_branch))
+        # Schedule all branch arguments, so their definitions don't end up in the
+        # 'if' statement.
+        statements = []
+        for branch in (flow.if_branch, flow.else_branch):
+            for arg in branch.arguments:
+                if not self.scheduler.has_scheduled(arg):
+                    statements.append(self.scheduler.schedule(arg, True))
+
+        statements.append(
+            tree_ir.SelectInstruction(
+                self.use_definition(flow.condition),
+                self.lower_branch(flow.if_branch),
+                self.lower_branch(flow.else_branch)))
+
+        return tree_ir.create_block(*statements)
 
     def lower_return(self, flow):
         """Lowers the given 'return' flow instruction to a tree."""
@@ -800,7 +811,9 @@ class LoweringState(object):
 
     def lower_unreachable(self, _):
         """Lowers the given 'unreachable' flow instruction to a tree."""
-        return tree_ir.EmptyInstruction()
+        return tree_ir.IgnoreInstruction(
+            tree_ir.CallInstruction(
+                tree_ir.LoadGlobalInstruction(jit_runtime.UNREACHABLE_FUNCTION_NAME), []))
 
     def lower_break(self, _):
         """Lowers the given 'break' flow instruction to a tree."""

+ 5 - 8
kernel/modelverse_jit/jit.py

@@ -107,7 +107,8 @@ class ModelverseJit(object):
             jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
             jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
             jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant,
-            jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global
+            jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global,
+            jit_runtime.UNREACHABLE_FUNCTION_NAME : jit_runtime.unreachable
         }
         # jitted_entry_points maps body ids to values in jit_globals.
         self.jitted_entry_points = {}
@@ -648,18 +649,14 @@ def compile_function_body_fast(jit, function_name, body_id, _):
     body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
     bytecode_analyzer = bytecode_to_cfg.AnalysisState(jit, function_name, param_dict)
     bytecode_analyzer.analyze(body_bytecode)
-    yield [
+    entry_point, = yield [
         ("CALL_ARGS", [cfg_optimization.optimize, (bytecode_analyzer.entry_point, jit)])]
     if jit.jit_code_log_function is not None:
         jit.jit_code_log_function(
             "CFG for function '%s' at '%d':\n%s" % (
                 function_name, body_id,
-                '\n'.join(
-                    map(
-                        str,
-                        cfg_ir.get_all_reachable_blocks(
-                            bytecode_analyzer.entry_point)))))
+                '\n'.join(map(str, cfg_ir.get_all_reachable_blocks(entry_point)))))
     raise primitive_functions.PrimitiveFinished(
         create_bare_function(
             function_name, parameter_list,
-            cfg_to_tree.lower_flow_graph(bytecode_analyzer.entry_point, jit)))
+            cfg_to_tree.lower_flow_graph(entry_point, jit)))

+ 7 - 0
kernel/modelverse_jit/runtime.py

@@ -25,6 +25,9 @@ JIT_THUNK_CONSTANT_FUNCTION_NAME = "__jit_thunk_constant"
 JIT_THUNK_GLOBAL_FUNCTION_NAME = "__jit_thunk_global"
 """The name of the jit_thunk_global function in the JIT's global context."""
 
+UNREACHABLE_FUNCTION_NAME = "__unreachable"
+"""The name of the unreachable function in the JIT's global context."""
+
 LOCALS_NODE_NAME = "jit_locals"
 """The name of the node that is connected to all JIT locals in a given function call."""
 
@@ -145,6 +148,10 @@ def interpret_function_body(body_id, named_arguments, **kwargs):
         # An instruction has completed. Forward it.
         yield result
 
+def unreachable():
+    """Marks unreachable code."""
+    raise ValueError('An unreachable statement was reached.')
+
 def get_input(**parameters):
     """Retrieves input."""
     mvk = parameters["mvk"]