Yentl Van Tendeloo пре 7 година
родитељ
комит
3e218117b4
2 измењених фајлова са 2 додато и 239 уклоњено
  1. 1 238
      kernel/modelverse_jit/jit.py
  2. 1 1
      wrappers/modelverse_SCCD.py

+ 1 - 238
kernel/modelverse_jit/jit.py

@@ -8,104 +8,9 @@ import modelverse_jit.runtime as jit_runtime
 # in this module.
 # in this module.
 JitCompilationFailedException = jit_runtime.JitCompilationFailedException
 JitCompilationFailedException = jit_runtime.JitCompilationFailedException
 
 
-def map_and_simplify_generator(function, instruction):
-    """Applies the given mapping function to every instruction in the tree
-       that has the given instruction as root, and simplifies it on-the-fly.
-
-       This is at least as powerful as first mapping and then simplifying, as
-       maps and simplifications are interspersed.
-
-       This function assumes that function creates a generator that returns by
-       raising a primitive_functions.PrimitiveFinished."""
-
-    # First handle the children by mapping on them and then simplifying them.
-    new_children = []
-    for inst in instruction.get_children():
-        new_inst, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, inst)])]
-        new_children.append(new_inst)
-
-    # Then apply the function to the top-level node.
-    transformed, = yield [("CALL_ARGS", [function, (instruction.create(new_children),)])]
-    # Finally, simplify the transformed top-level node.
-    raise primitive_functions.PrimitiveFinished(transformed.simplify_node())
-
-def expand_constant_read(instruction):
-    """Tries to replace a read of a constant node by a literal."""
-    if isinstance(instruction, tree_ir.ReadValueInstruction) and \
-        isinstance(instruction.node_id, tree_ir.LiteralInstruction):
-        val, = yield [("RV", [instruction.node_id.literal])]
-        raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val))
-    else:
-        raise primitive_functions.PrimitiveFinished(instruction)
-
-def optimize_tree_ir(instruction):
-    """Optimizes an IR tree."""
-    return map_and_simplify_generator(expand_constant_read, instruction)
-
-def create_bare_function(function_name, parameter_list, function_body):
-    """Creates a function definition from the given function name, parameter list
-       and function body. No prolog is included."""
-    # Wrap the IR in a function definition, give it a unique name.
-    return tree_ir.DefineFunctionInstruction(
-        function_name,
-        parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
-        function_body)
-
-def create_function(
-        function_name, parameter_list, param_dict,
-        body_param_dict, function_body, source_map_name=None,
-        compatible_temporary_protects=False):
-    """Creates a function from the given function name, parameter list,
-       variable-to-parameter name map, variable-to-local name map and
-       function body. An optional source map can be included, too."""
-    # Write a prologue and prepend it to the generated function body.
-    prolog_statements = []
-    # If the source map is not None, then we should generate a "DEBUG_INFO"
-    # request.
-    if source_map_name is not None:
-        prolog_statements.append(
-            tree_ir.RegisterDebugInfoInstruction(
-                tree_ir.LiteralInstruction(function_name),
-                tree_ir.LoadGlobalInstruction(source_map_name),
-                tree_ir.LiteralInstruction(jit_runtime.BASELINE_JIT_ORIGIN_NAME)))
-
-    # Create a LOCALS_NODE_NAME node, and connect it to the user root.
-    prolog_statements.append(
-        tree_ir.create_new_local_node(
-            jit_runtime.LOCALS_NODE_NAME,
-            tree_ir.LoadIndexInstruction(
-                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
-                tree_ir.LiteralInstruction('task_root')),
-            jit_runtime.LOCALS_EDGE_NAME))
-    for (key, val) in list(param_dict.items()):
-        arg_ptr = tree_ir.create_new_local_node(
-            body_param_dict[key],
-            tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
-        prolog_statements.append(arg_ptr)
-        prolog_statements.append(
-            tree_ir.CreateDictionaryEdgeInstruction(
-                tree_ir.LoadLocalInstruction(body_param_dict[key]),
-                tree_ir.LiteralInstruction('value'),
-                tree_ir.LoadLocalInstruction(val)))
-
-    constructed_body = tree_ir.create_block(
-        *(prolog_statements + [function_body]))
-
-    # Shield temporaries from the GC.
-    constructed_body = tree_ir.protect_temporaries_from_gc(
-        constructed_body,
-        tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME),
-        compatible_temporary_protects)
-
-    return create_bare_function(function_name, parameter_list, constructed_body)
-
-def print_value(val):
-    """A thin wrapper around 'print'."""
-    print(val)
-
 class ModelverseJit(object):
 class ModelverseJit(object):
     """A high-level interface to the modelverse JIT compiler."""
     """A high-level interface to the modelverse JIT compiler."""
-    def __init__(self, max_instructions=None, compiled_function_lookup=None):
+    def __init__(self):
         self.todo_entry_points = set()
         self.todo_entry_points = set()
         self.no_jit_entry_points = set()
         self.no_jit_entry_points = set()
         self.jitted_parameters = {}
         self.jitted_parameters = {}
@@ -118,22 +23,11 @@ class ModelverseJit(object):
         self.global_functions = {}
         self.global_functions = {}
         # global_functions_inv maps body ids to global value names.
         # global_functions_inv maps body ids to global value names.
         self.global_functions_inv = {}
         self.global_functions_inv = {}
-        # bytecode_graphs maps body ids to their parsed bytecode graphs.
-        self.bytecode_graphs = {}
         # jitted_function_aliases maps body ids to known aliases.
         # jitted_function_aliases maps body ids to known aliases.
         self.jitted_function_aliases = defaultdict(set)
         self.jitted_function_aliases = defaultdict(set)
         self.jit_count = 0
         self.jit_count = 0
-        self.max_instructions = max_instructions
-        self.compiled_function_lookup = compiled_function_lookup
         self.compilation_dependencies = {}
         self.compilation_dependencies = {}
         self.jit_enabled = True
         self.jit_enabled = True
-        self.direct_calls_allowed = True
-        self.tracing_enabled = False
-        self.source_maps_enabled = True
-        self.input_function_enabled = False
-        self.nop_insertion_enabled = True
-        self.jit_success_log_function = None
-        self.jit_code_log_function = None
 
 
     def set_jit_enabled(self, is_enabled=True):
     def set_jit_enabled(self, is_enabled=True):
         """Enables or disables the JIT."""
         """Enables or disables the JIT."""
@@ -166,16 +60,6 @@ class ModelverseJit(object):
            the currently running code."""
            the currently running code."""
         self.nop_insertion_enabled = is_enabled
         self.nop_insertion_enabled = is_enabled
 
 
-    def set_jit_success_log(self, log_function=print_value):
-        """Configures this JIT instance with a function that prints output to a log.
-           Success and failure messages for specific functions are then sent to said log."""
-        self.jit_success_log_function = log_function
-
-    def set_jit_code_log(self, log_function=print_value):
-        """Configures this JIT instance with a function that prints output to a log.
-           Function definitions of jitted functions are then sent to said log."""
-        self.jit_code_log_function = log_function
-
     def set_function_body_compiler(self, compile_function_body):
     def set_function_body_compiler(self, compile_function_body):
         """Sets the function that the JIT uses to compile function bodies."""
         """Sets the function that the JIT uses to compile function bodies."""
         self.compile_function_body = compile_function_body
         self.compile_function_body = compile_function_body
@@ -348,16 +232,6 @@ class ModelverseJit(object):
 
 
         raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
         raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
 
 
-    def jit_parse_bytecode(self, body_id):
-        """Parses the given function body as a bytecode graph."""
-        if body_id in self.bytecode_graphs:
-            raise primitive_functions.PrimitiveFinished(self.bytecode_graphs[body_id])
-
-        parser = bytecode_parser.BytecodeParser()
-        result, = yield [("CALL_ARGS", [parser.parse_instruction, (body_id,)])]
-        self.bytecode_graphs[body_id] = result
-        raise primitive_functions.PrimitiveFinished(result)
-
     def check_jittable(self, body_id, suggested_name=None):
     def check_jittable(self, body_id, suggested_name=None):
         """Checks if the function with the given body id is obviously non-jittable. If it's
         """Checks if the function with the given body id is obviously non-jittable. If it's
            non-jittable, then a `JitCompilationFailedException` exception is thrown."""
            non-jittable, then a `JitCompilationFailedException` exception is thrown."""
@@ -376,73 +250,6 @@ class ModelverseJit(object):
                     '' if suggested_name is None else "'" + suggested_name + "'",
                     '' if suggested_name is None else "'" + suggested_name + "'",
                     body_id))
                     body_id))
 
 
-    def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
-        """Replaces the function with the given name by compiling the bytecode at the given
-           body id."""
-        if compile_function_body is None:
-            compile_function_body = self.compile_function_body
-
-        self.check_jittable(body_id, function_name)
-
-        # Generate a name for the function we're about to analyze, and pretend that
-        # it already exists. (we need to do this for recursive functions)
-        self.jitted_entry_points[body_id] = function_name
-        self.jit_globals[function_name] = None
-
-        (_, _, is_mutable), = yield [
-            ("CALL_ARGS", [self.jit_signature, (body_id,)])]
-
-        dependencies = set([body_id])
-        self.compilation_dependencies[body_id] = dependencies
-
-        def handle_jit_exception(exception):
-            # If analysis fails, then a JitCompilationFailedException will be thrown.
-            print("EXCEPTION with mutable")
-            del self.compilation_dependencies[body_id]
-            for dep in dependencies:
-                self.mark_no_jit(dep)
-                if dep in self.jitted_entry_points:
-                    del self.jitted_entry_points[dep]
-
-            failure_message = "%s (function '%s' at %d)" % (
-                str(exception), function_name, body_id)
-            if self.jit_success_log_function is not None:
-                self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
-            raise JitCompilationFailedException(failure_message)
-
-        # Try to analyze the function's body.
-        yield [("TRY", [])]
-        yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
-        if is_mutable:
-            # We can't just JIT mutable functions. That'd be dangerous.
-            raise JitCompilationFailedException(
-                "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
-
-        compiled_function, = yield [
-            ("CALL_ARGS", [compile_function_body, (self, function_name, body_id, task_root)])]
-
-        yield [("END_TRY", [])]
-        del self.compilation_dependencies[body_id]
-
-        if self.jit_success_log_function is not None:
-            assert self.jitted_entry_points[body_id] == function_name
-            self.jit_success_log_function(
-                "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
-
-        raise primitive_functions.PrimitiveFinished(compiled_function)
-
-    def get_source_map_name(self, function_name):
-        """Gets the name of the given jitted function's source map. None is returned if source maps
-           are disabled."""
-        if self.source_maps_enabled:
-            return function_name + "_source_map"
-        else:
-            return None
-
-    def get_can_rejit_name(self, function_name):
-        """Gets the name of the given jitted function's can-rejit flag."""
-        return function_name + "_can_rejit"
-
     def jit_define_function(self, function_name, function_def):
     def jit_define_function(self, function_name, function_def):
         """Converts the given tree-IR function definition to Python code, defines it,
         """Converts the given tree-IR function definition to Python code, defines it,
            and extracts the resulting function."""
            and extracts the resulting function."""
@@ -465,47 +272,3 @@ class ModelverseJit(object):
     def jit_delete_function(self, function_name):
     def jit_delete_function(self, function_name):
         """Deletes the function with the given function name."""
         """Deletes the function with the given function name."""
         del self.jit_globals[function_name]
         del self.jit_globals[function_name]
-
-    def jit_compile(self, task_root, body_id, suggested_name=None):
-        """Tries to jit the function defined by the given entry point id and parameter list."""
-        if body_id is None:
-            raise ValueError('body_id cannot be None: ' + str(suggested_name))
-        elif body_id in self.jitted_entry_points:
-            raise primitive_functions.PrimitiveFinished(
-                self.jit_globals[self.jitted_entry_points[body_id]])
-
-        compiled_func = self.lookup_compiled_body(body_id)
-        if compiled_func is not None:
-            raise primitive_functions.PrimitiveFinished(compiled_func)
-
-        # Generate a name for the function we're about to analyze, and 're-compile'
-        # it for the first time.
-        function_name = self.generate_function_name(body_id, suggested_name)
-        yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]
-
-    def jit_rejit(self, task_root, body_id, function_name, compile_function_body=None):
-        """Re-compiles the given function. If compilation fails, then the can-rejit
-           flag is set to false."""
-        old_jitted_func = self.jitted_entry_points[body_id]
-        def __handle_jit_failed(_):
-            self.jit_globals[self.get_can_rejit_name(function_name)] = False
-            self.jitted_entry_points[body_id] = old_jitted_func
-            self.no_jit_entry_points.remove(body_id)
-            raise primitive_functions.PrimitiveFinished(None)
-
-        yield [("TRY", [])]
-        yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
-        jitted_function, = yield [
-            ("CALL_ARGS",
-             [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
-        yield [("END_TRY", [])]
-
-        # Update all aliases.
-        for function_alias in self.jitted_function_aliases[body_id]:
-            self.jit_globals[function_alias] = jitted_function
-
-    def new_compile(self, body_id):
-        print("Compiling body ID " + str(body_id))
-        raise JitCompilationFailedException("Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
-
-        #raise primitive_functions.PrimitiveFinished("pass")

+ 1 - 1
wrappers/modelverse_SCCD.py

@@ -1,7 +1,7 @@
 """
 """
 Generated by Statechart compiler by Glenn De Jonghe, Joeri Exelmans, Simon Van Mierlo, and Yentl Van Tendeloo (for the inspiration)
 Generated by Statechart compiler by Glenn De Jonghe, Joeri Exelmans, Simon Van Mierlo, and Yentl Van Tendeloo (for the inspiration)
 
 
-Date:   Tue Apr 24 10:00:08 2018
+Date:   Tue Apr 24 10:04:54 2018
 
 
 Model author: Yentl Van Tendeloo
 Model author: Yentl Van Tendeloo
 Model name:   MvK Server
 Model name:   MvK Server