Browse Source

Add an optimization that replaces indirect calls by direct calls

jonathanvdc 8 years ago
parent
commit
cc101d3772

+ 11 - 12
kernel/modelverse_jit/cfg_ir.py

@@ -571,13 +571,18 @@ def create_jump(block, arguments=None):
     """Creates a jump to the given block with the given argument list."""
     """Creates a jump to the given block with the given argument list."""
     return JumpFlow(Branch(block, arguments))
     return JumpFlow(Branch(block, arguments))
 
 
+def get_def_value(def_or_value):
+    """Returns the given value, or the underlying value of the given definition, whichever is
+       appropriate."""
+    if isinstance(def_or_value, Definition):
+        return get_def_value(def_or_value.value)
+    else:
+        return def_or_value
+
 def apply_to_value(function, def_or_value):
 def apply_to_value(function, def_or_value):
     """Applies the given function to the specified value, or the underlying value of the
     """Applies the given function to the specified value, or the underlying value of the
        given definition."""
        given definition."""
-    if isinstance(def_or_value, Definition):
-        return apply_to_value(function, def_or_value.value)
-    else:
-        return function(def_or_value)
+    return function(get_def_value(def_or_value))
 
 
 def is_literal(value):
 def is_literal(value):
     """Tests if the given value is a literal."""
     """Tests if the given value is a literal."""
@@ -589,18 +594,12 @@ def is_literal_def(def_or_value):
 
 
 def is_value_def(def_or_value, class_or_type_or_tuple=Value):
 def is_value_def(def_or_value, class_or_type_or_tuple=Value):
     """Tests if the given definition or value is a value of the given type."""
     """Tests if the given definition or value is a value of the given type."""
-    if isinstance(def_or_value, Definition):
-        return is_value_def(def_or_value.value, class_or_type_or_tuple)
-    else:
-        return isinstance(def_or_value, class_or_type_or_tuple)
+    return isinstance(get_def_value(def_or_value), class_or_type_or_tuple)
 
 
 def get_def_variable(def_or_value):
 def get_def_variable(def_or_value):
     """Gets the 'variable' attribute of the given value, or the underlying value of the given
     """Gets the 'variable' attribute of the given value, or the underlying value of the given
        definition, whichever is appropriate."""
        definition, whichever is appropriate."""
-    if isinstance(def_or_value, Definition):
-        return get_def_variable(def_or_value.value)
-    else:
-        return def_or_value.variable
+    return get_def_value(def_or_value).variable
 
 
 def get_literal_value(value):
 def get_literal_value(value):
     """Gets the value of the given literal value."""
     """Gets the value of the given literal value."""

+ 40 - 1
kernel/modelverse_jit/cfg_optimization.py

@@ -175,11 +175,50 @@ def eliminate_unused_definitions(entry_point):
     for dead_def in dead_defs:
     for dead_def in dead_defs:
         dead_def.block.remove_definition(dead_def)
         dead_def.block.remove_definition(dead_def)
 
 
-def optimize(entry_point):
+def try_redefine_as_direct_call(definition, jit, called_globals):
+    """Tries to redefine the given indirect call definition as a direct call."""
+    call = definition.value
+    if not isinstance(call, cfg_ir.IndirectFunctionCall):
+        return
+
+    target = cfg_ir.get_def_value(call.target)
+    if isinstance(target, cfg_ir.LoadPointer):
+        loaded_ptr = cfg_ir.get_def_value(target.pointer)
+        if isinstance(loaded_ptr, cfg_ir.ResolveGlobal):
+            resolved_var_name = loaded_ptr.variable.name
+
+            # # Try to resolve the callee as an intrinsic.
+            # intrinsic = jit.get_intrinsic(resolved_var_name)
+            # if intrinsic is not None:
+            #     return redefine_as_intrinsic(definition, intrinsic, call.argument_list)
+
+            # Otherwise, build a thunk.
+            thunk_name = jit.jit_thunk_global(resolved_var_name)
+            definition.redefine(
+                cfg_ir.DirectFunctionCall(
+                    thunk_name, call.argument_list, cfg_ir.JIT_CALLING_CONVENTION))
+            called_globals.add(loaded_ptr)
+    elif isinstance(target, cfg_ir.Literal):
+        node_id = target.literal
+        thunk_name = jit.jit_thunk_constant(node_id)
+        definition.redefine(
+            cfg_ir.DirectFunctionCall(
+                thunk_name, call.argument_list, cfg_ir.JIT_CALLING_CONVENTION))
+
+def optimize_calls(entry_point, jit):
+    """Converts indirect calls to direct calls in the control-flow graph defined by the
+       given entry point."""
+    called_globals = set()
+    for block in get_all_blocks(entry_point):
+        for definition in block.definitions:
+            try_redefine_as_direct_call(definition, jit, called_globals)
+
+def optimize(entry_point, jit):
     """Optimizes the control-flow graph defined by the given entry point."""
     """Optimizes the control-flow graph defined by the given entry point."""
     optimize_graph_flow(entry_point)
     optimize_graph_flow(entry_point)
     elide_local_checks(entry_point)
     elide_local_checks(entry_point)
     optimize_graph_flow(entry_point)
     optimize_graph_flow(entry_point)
+    optimize_calls(entry_point, jit)
     eliminate_unused_definitions(entry_point)
     eliminate_unused_definitions(entry_point)
     optimize_graph_flow(entry_point)
     optimize_graph_flow(entry_point)
     merge_blocks(entry_point)
     merge_blocks(entry_point)

+ 1 - 1
kernel/modelverse_jit/jit.py

@@ -440,7 +440,7 @@ class ModelverseJit(object):
         if self.jit_code_log_function is not None:
         if self.jit_code_log_function is not None:
             bytecode_analyzer = bytecode_to_cfg.AnalysisState(param_dict)
             bytecode_analyzer = bytecode_to_cfg.AnalysisState(param_dict)
             bytecode_analyzer.analyze(body_bytecode)
             bytecode_analyzer.analyze(body_bytecode)
-            cfg_optimization.optimize(bytecode_analyzer.entry_point)
+            cfg_optimization.optimize(bytecode_analyzer.entry_point, self)
             self.jit_code_log_function(
             self.jit_code_log_function(
                 "CFG for function '%s' at '%d':\n%s" % (
                 "CFG for function '%s' at '%d':\n%s" % (
                     function_name, body_id,
                     function_name, body_id,