Browse Source

Define an adaptive JIT mode

jonathanvdc 8 years ago
parent
commit
0f17d07805

+ 77 - 0
kernel/modelverse_jit/bytecode_ir.py

@@ -6,6 +6,24 @@ class Instruction(object):
         self.next_instruction = None
         self.next_instruction = None
         self.debug_information = None
         self.debug_information = None
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        raise NotImplementedError()
+
+    def get_reachable(self):
+        """Gets the set of all instructions that are reachable from the given instruction, including
+           this instruction."""
+        results = set()
+        stack = [self]
+        while len(stack) > 0:
+            instr = stack.pop()
+            results.add(instr)
+            for other in instr.get_directly_reachable():
+                if other not in results:
+                    stack.append(other)
+
+        return results
+
 class VariableNode(object):
 class VariableNode(object):
     """Represents a variable node, which has an identifier and an optional name."""
     """Represents a variable node, which has an identifier and an optional name."""
     def __init__(self, node_id, name):
     def __init__(self, node_id, name):
@@ -26,6 +44,13 @@ class SelectInstruction(Instruction):
         self.if_clause = if_clause
         self.if_clause = if_clause
         self.else_clause = else_clause
         self.else_clause = else_clause
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        if self.else_clause is None:
+            return (self.condition, self.if_clause)
+        else:
+            return (self.condition, self.if_clause, self.else_clause)
+
     constructor_parameters = (
     constructor_parameters = (
         ('cond', Instruction),
         ('cond', Instruction),
         ('then', Instruction),
         ('then', Instruction),
@@ -42,6 +67,10 @@ class WhileInstruction(Instruction):
         self.condition = condition
         self.condition = condition
         self.body = body
         self.body = body
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.condition, self.body)
+
     constructor_parameters = (
     constructor_parameters = (
         ('cond', Instruction),
         ('cond', Instruction),
         ('body', Instruction))
         ('body', Instruction))
@@ -56,6 +85,10 @@ class BreakInstruction(Instruction):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.loop = loop
         self.loop = loop
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.loop,)
+
     constructor_parameters = (('while', WhileInstruction),)
     constructor_parameters = (('while', WhileInstruction),)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -68,6 +101,10 @@ class ContinueInstruction(Instruction):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.loop = loop
         self.loop = loop
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.loop,)
+
     constructor_parameters = (('while', WhileInstruction),)
     constructor_parameters = (('while', WhileInstruction),)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -81,6 +118,10 @@ class ReturnInstruction(Instruction):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.value = value
         self.value = value
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.value,)
+
     constructor_parameters = (('value', Instruction),)
     constructor_parameters = (('value', Instruction),)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -95,6 +136,10 @@ class CallInstruction(Instruction):
         self.target = target
         self.target = target
         self.argument_list = argument_list
         self.argument_list = argument_list
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.target,) + tuple((value for _, value in self.argument_list))
+
     def __repr__(self):
     def __repr__(self):
         return '@%r: CallInstruction(@%r, [%s])' % (
         return '@%r: CallInstruction(@%r, [%s])' % (
             id(self), id(self.target),
             id(self), id(self.target),
@@ -108,6 +153,10 @@ class ConstantInstruction(Instruction):
         self.constant_id = constant_id
         self.constant_id = constant_id
         assert self.constant_id is not None
         assert self.constant_id is not None
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = (('node', int),)
     constructor_parameters = (('node', int),)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -119,6 +168,10 @@ class InputInstruction(Instruction):
     def __init__(self):
     def __init__(self):
         Instruction.__init__(self)
         Instruction.__init__(self)
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = ()
     constructor_parameters = ()
 
 
     def __repr__(self):
     def __repr__(self):
@@ -131,6 +184,10 @@ class OutputInstruction(Instruction):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.value = value
         self.value = value
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.value,)
+
     constructor_parameters = (('value', Instruction),)
     constructor_parameters = (('value', Instruction),)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -143,6 +200,10 @@ class DeclareInstruction(Instruction):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.variable = variable
         self.variable = variable
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = (('var', VariableNode),)
     constructor_parameters = (('var', VariableNode),)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -155,6 +216,10 @@ class GlobalInstruction(Instruction):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.variable = variable
         self.variable = variable
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = (('var', VariableNode),)
     constructor_parameters = (('var', VariableNode),)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -168,6 +233,10 @@ class ResolveInstruction(Instruction):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.variable = variable
         self.variable = variable
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = (('var', VariableNode),)
     constructor_parameters = (('var', VariableNode),)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -181,6 +250,10 @@ class AccessInstruction(Instruction):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.pointer = pointer
         self.pointer = pointer
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.pointer,)
+
     constructor_parameters = (('var', Instruction),)
     constructor_parameters = (('var', Instruction),)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -195,6 +268,10 @@ class AssignInstruction(Instruction):
         self.pointer = pointer
         self.pointer = pointer
         self.value = value
         self.value = value
 
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.pointer, self.value)
+
     constructor_parameters = (
     constructor_parameters = (
         ('var', Instruction),
         ('var', Instruction),
         ('value', Instruction))
         ('value', Instruction))

+ 136 - 9
kernel/modelverse_jit/jit.py

@@ -117,6 +117,8 @@ class ModelverseJit(object):
             jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
             jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
             jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant_function,
             jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant_function,
             jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global,
             jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global,
+            jit_runtime.JIT_REJIT_FUNCTION_NAME : self.jit_rejit,
+            jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME : compile_function_body_fast,
             jit_runtime.UNREACHABLE_FUNCTION_NAME : jit_runtime.unreachable
             jit_runtime.UNREACHABLE_FUNCTION_NAME : jit_runtime.unreachable
         }
         }
         # jitted_entry_points maps body ids to values in jit_globals.
         # jitted_entry_points maps body ids to values in jit_globals.
@@ -454,10 +456,6 @@ class ModelverseJit(object):
            non-jittable, then a `JitCompilationFailedException` exception is thrown."""
            non-jittable, then a `JitCompilationFailedException` exception is thrown."""
         if body_id is None:
         if body_id is None:
             raise ValueError('body_id cannot be None')
             raise ValueError('body_id cannot be None')
-        elif body_id in self.jitted_entry_points:
-            # We have already compiled this function.
-            raise primitive_functions.PrimitiveFinished(
-                self.jit_globals[self.jitted_entry_points[body_id]])
         elif body_id in self.no_jit_entry_points:
         elif body_id in self.no_jit_entry_points:
             # We're not allowed to jit this function or have tried and failed before.
             # We're not allowed to jit this function or have tried and failed before.
             raise JitCompilationFailedException(
             raise JitCompilationFailedException(
@@ -471,9 +469,12 @@ class ModelverseJit(object):
                     '' if suggested_name is None else "'" + suggested_name + "'",
                     '' if suggested_name is None else "'" + suggested_name + "'",
                     body_id))
                     body_id))
 
 
-    def jit_recompile(self, task_root, body_id, function_name):
+    def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
         """Replaces the function with the given name by compiling the bytecode at the given
         """Replaces the function with the given name by compiling the bytecode at the given
            body id."""
            body id."""
+        if compile_function_body is None:
+            compile_function_body = self.compile_function_body
+
         self.check_jittable(body_id, function_name)
         self.check_jittable(body_id, function_name)
 
 
         # Generate a name for the function we're about to analyze, and pretend that
         # Generate a name for the function we're about to analyze, and pretend that
@@ -510,7 +511,7 @@ class ModelverseJit(object):
                 "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
                 "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
 
 
         constructed_function, = yield [
         constructed_function, = yield [
-            ("CALL_ARGS", [self.compile_function_body, (self, function_name, body_id, task_root)])]
+            ("CALL_ARGS", [compile_function_body, (self, function_name, body_id, task_root)])]
 
 
         yield [("END_TRY", [])]
         yield [("END_TRY", [])]
         del self.compilation_dependencies[body_id]
         del self.compilation_dependencies[body_id]
@@ -532,6 +533,10 @@ class ModelverseJit(object):
         else:
         else:
             return None
             return None
 
 
+    def get_can_rejit_name(self, function_name):
+        """Gets the name of the given jitted function's can-rejit flag."""
+        return function_name + "_can_rejit"
+
     def jit_define_function(self, function_name, function_def):
     def jit_define_function(self, function_name, function_def):
         """Converts the given tree-IR function definition to Python code, defines it,
         """Converts the given tree-IR function definition to Python code, defines it,
            and extracts the resulting function."""
            and extracts the resulting function."""
@@ -558,11 +563,35 @@ class ModelverseJit(object):
 
 
     def jit_compile(self, task_root, body_id, suggested_name=None):
     def jit_compile(self, task_root, body_id, suggested_name=None):
         """Tries to jit the function defined by the given entry point id and parameter list."""
         """Tries to jit the function defined by the given entry point id and parameter list."""
-        # Generate a name for the function we're about to analyze, and pretend that
-        # it already exists. (we need to do this for recursive functions)
+        if body_id is None:
+            raise ValueError('body_id cannot be None')
+        elif body_id in self.jitted_entry_points:
+            # We have already compiled this function.
+            raise primitive_functions.PrimitiveFinished(
+                self.jit_globals[self.jitted_entry_points[body_id]])
+
+        # Generate a name for the function we're about to analyze, and 're-compile'
+        # it for the first time.
         function_name = self.generate_function_name(body_id, suggested_name)
         function_name = self.generate_function_name(body_id, suggested_name)
         yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]
         yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]
 
 
+    def jit_rejit(self, task_root, body_id, function_name, compile_function_body=None):
+        """Re-compiles the given function. If compilation fails, then the can-rejit
+           flag is set to false."""
+        old_jitted_func = self.jitted_entry_points[body_id]
+        def __handle_jit_failed(_):
+            self.jit_globals[self.get_can_rejit_name(function_name)] = False
+            self.jitted_entry_points[body_id] = old_jitted_func
+            self.no_jit_entry_points.remove(body_id)
+            raise primitive_functions.PrimitiveFinished(None)
+
+        yield [("TRY", [])]
+        yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
+        yield [
+            ("CALL_ARGS",
+             [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
+        yield [("END_TRY", [])]
+
     def jit_thunk(self, get_function_body, global_name=None):
     def jit_thunk(self, get_function_body, global_name=None):
         """Creates a thunk from the given IR tree that computes the function's body id.
         """Creates a thunk from the given IR tree that computes the function's body id.
            This thunk is a function that will invoke the function whose body id is retrieved.
            This thunk is a function that will invoke the function whose body id is retrieved.
@@ -680,7 +709,7 @@ class ModelverseJit(object):
                 tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
                 tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
             global_name)
             global_name)
 
 
-def compile_function_body_baseline(jit, function_name, body_id, task_root):
+def compile_function_body_baseline(jit, function_name, body_id, task_root, header=None):
     """Have the baseline JIT compile the function with the given name and body id."""
     """Have the baseline JIT compile the function with the given name and body id."""
     (parameter_ids, parameter_list, _), = yield [
     (parameter_ids, parameter_list, _), = yield [
         ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
         ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
@@ -692,6 +721,9 @@ def compile_function_body_baseline(jit, function_name, body_id, task_root):
         jit.max_instructions)
         jit.max_instructions)
     constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
     constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
 
 
+    if header is not None:
+        constructed_body = tree_ir.create_block(header, constructed_body)
+
     # Optimize the function's body.
     # Optimize the function's body.
     constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
     constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
 
 
@@ -726,3 +758,98 @@ def compile_function_body_fast(jit, function_name, body_id, _):
         create_bare_function(
         create_bare_function(
             function_name, parameter_list,
             function_name, parameter_list,
             constructed_body))
             constructed_body))
+
+def favor_large_functions(body_bytecode):
+    """Computes the initial temperature of a function based on the size of
+       its body bytecode. Larger functions are favored and the temperature
+       is incremented by one on every call."""
+    return (
+        len(body_bytecode.get_reachable()),
+        lambda old_value:
+        tree_ir.BinaryInstruction(
+            old_value,
+            '+',
+            tree_ir.LiteralInstruction(1)))
+
+ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD = 200
+"""The threshold temperature at which fast-jit will be used."""
+
+def compile_function_body_adaptive(
+        jit, function_name, body_id, task_root,
+        temperature_heuristic=favor_large_functions):
+    """Compile the function with the given name and body id. An execution engine is picked
+       automatically, and the function may be compiled again at a later time."""
+    # The general idea behind this compilation technique is to first use the baseline JIT
+    # to compile a function, and then switch to the fast JIT when we determine that doing
+    # so would be a good idea. We maintain a 'temperature' counter, which has an initial value
+    # and gets incremented every time the function is executed.
+
+    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
+    initial_temperature, increment_temperature = temperature_heuristic(body_bytecode)
+
+    if initial_temperature >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
+        # Initial temperature exceeds the fast-jit threshold.
+        # Compile this thing with fast-jit right away.
+        yield [
+            ("TAIL_CALL_ARGS",
+             [compile_function_body_fast, (jit, function_name, body_id, task_root)])]
+
+    (_, parameter_list, _), = yield [
+        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
+
+    temperature_counter_name = jit.import_value(
+        initial_temperature, function_name + "_temperature_counter")
+
+    can_rejit_name = jit.get_can_rejit_name(function_name)
+    jit.jit_globals[can_rejit_name] = True
+
+    # This tree represents the following logic:
+    #
+    # if can_rejit:
+    #     global temperature_counter
+    #     temperature_counter = increment_temperature(temperature_counter)
+    #     if temperature_counter >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
+    #         yield [("CALL_KWARGS", [jit_runtime.JIT_REJIT_FUNCTION_NAME, {...}])]
+    #         yield [("TAIL_CALL_KWARGS", [function_name, {...}])]
+
+    header = tree_ir.SelectInstruction(
+        tree_ir.LoadGlobalInstruction(can_rejit_name),
+        tree_ir.create_block(
+            tree_ir.DeclareGlobalInstruction(temperature_counter_name),
+            tree_ir.IgnoreInstruction(
+                tree_ir.StoreGlobalInstruction(
+                    temperature_counter_name,
+                    increment_temperature(
+                        tree_ir.LoadGlobalInstruction(temperature_counter_name)))),
+            tree_ir.SelectInstruction(
+                tree_ir.BinaryInstruction(
+                    tree_ir.LoadGlobalInstruction(temperature_counter_name),
+                    '>=',
+                    tree_ir.LiteralInstruction(ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD)),
+                tree_ir.create_block(
+                    tree_ir.RunGeneratorFunctionInstruction(
+                        tree_ir.LoadGlobalInstruction(jit_runtime.JIT_REJIT_FUNCTION_NAME),
+                        tree_ir.DictionaryLiteralInstruction([
+                            (tree_ir.LiteralInstruction('task_root'),
+                             bytecode_to_tree.load_task_root()),
+                            (tree_ir.LiteralInstruction('body_id'),
+                             tree_ir.LiteralInstruction(body_id)),
+                            (tree_ir.LiteralInstruction('function_name'),
+                             tree_ir.LiteralInstruction(function_name)),
+                            (tree_ir.LiteralInstruction('compile_function_body'),
+                             tree_ir.LoadGlobalInstruction(
+                                 jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME))]),
+                        result_type=tree_ir.NO_RESULT_TYPE),
+                    tree_ir.create_jit_call(
+                        tree_ir.LoadGlobalInstruction(function_name),
+                        [(name, tree_ir.LoadLocalInstruction(name)) for name in parameter_list],
+                        tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
+                        tree_ir.RunTailGeneratorFunctionInstruction)),
+                tree_ir.EmptyInstruction())),
+        tree_ir.EmptyInstruction())
+
+    # Compile with the baseline JIT, and insert the header.
+    yield [
+        ("TAIL_CALL_ARGS",
+         [compile_function_body_baseline,
+          (jit, function_name, body_id, task_root, header)])]

+ 6 - 0
kernel/modelverse_jit/runtime.py

@@ -25,6 +25,12 @@ JIT_THUNK_CONSTANT_FUNCTION_NAME = "__jit_thunk_constant_function"
 JIT_THUNK_GLOBAL_FUNCTION_NAME = "__jit_thunk_global"
 JIT_THUNK_GLOBAL_FUNCTION_NAME = "__jit_thunk_global"
 """The name of the jit_thunk_global function in the JIT's global context."""
 """The name of the jit_thunk_global function in the JIT's global context."""
 
 
+JIT_REJIT_FUNCTION_NAME = "__jit_rejit"
+"""The name of the rejit function in the JIT's global context."""
+
+JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME = "__jit_compile_function_body_fast"
+"""The name of the compile_function_body_fast function in the JIT's global context."""
+
 UNREACHABLE_FUNCTION_NAME = "__unreachable"
 UNREACHABLE_FUNCTION_NAME = "__unreachable"
 """The name of the unreachable function in the JIT's global context."""
 """The name of the unreachable function in the JIT's global context."""
 
 

+ 64 - 10
kernel/modelverse_jit/tree_ir.py

@@ -823,12 +823,18 @@ class DictionaryLiteralInstruction(Instruction):
 
 
     def get_children(self):
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
         """Gets this instruction's sequence of child instructions."""
-        return [val for _, val in self.key_value_pairs]
+        results = []
+        for key, val in self.key_value_pairs:
+            results.append(key)
+            results.append(val)
+        return results
 
 
     def create(self, new_children):
     def create(self, new_children):
         """Creates a new instruction of this type from the given sequence of child instructions."""
         """Creates a new instruction of this type from the given sequence of child instructions."""
-        keys = [k for k, _ in self.key_value_pairs]
-        return DictionaryLiteralInstruction(zip(keys, new_children))
+        new_kv_pairs = []
+        for i in xrange(len(self.key_value_pairs)):
+            new_kv_pairs.append((new_children[2 * i], new_children[2 * i + 1]))
+        return DictionaryLiteralInstruction(new_kv_pairs)
 
 
     def has_definition_impl(self):
     def has_definition_impl(self):
         """Tells if this instruction requires a definition."""
         """Tells if this instruction requires a definition."""
@@ -981,14 +987,11 @@ class RunGeneratorFunctionInstruction(StateInstruction):
 
 
 class RunTailGeneratorFunctionInstruction(StateInstruction):
 class RunTailGeneratorFunctionInstruction(StateInstruction):
     """An instruction that runs a generator function."""
     """An instruction that runs a generator function."""
-    def __init__(self, function, argument_dict):
+    def __init__(self, function, argument_dict, result_type=NO_RESULT_TYPE):
         StateInstruction.__init__(self)
         StateInstruction.__init__(self)
         self.function = function
         self.function = function
         self.argument_dict = argument_dict
         self.argument_dict = argument_dict
-
-    def get_result_type_impl(self):
-        """Gets the type of value produced by this instruction."""
-        return NO_RESULT_TYPE
+        self.result_type_cache = result_type
 
 
     def get_opcode(self):
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         """Gets the opcode for this state instruction."""
@@ -1234,6 +1237,55 @@ class LoadGlobalInstruction(VariableInstruction):
     def __repr__(self):
     def __repr__(self):
         return 'LoadGlobalInstruction(%r)' % self.name
         return 'LoadGlobalInstruction(%r)' % self.name
 
 
+class StoreGlobalInstruction(VariableInstruction):
+    """An instruction that assigns a value to a global variable."""
+    def __init__(self, name, value):
+        VariableInstruction.__init__(self, name)
+        self.value = value
+
+    def get_children(self):
+        """Gets this instruction's sequence of child instructions."""
+        return [self.value]
+
+    def create(self, new_children):
+        """Creates a new instruction of this type from the given sequence of child instructions."""
+        val, = new_children
+        return StoreGlobalInstruction(self.name, val)
+
+    def generate_python_def(self, code_generator):
+        """Generates a Python statement that executes this instruction.
+           The statement is appended immediately to the code generator."""
+        code_generator.append_move_definition(self, self.value)
+
+    def __repr__(self):
+        return 'StoreGlobalInstruction(%r, %r)' % (self.name, self.value)
+
+class DeclareGlobalInstruction(VariableInstruction):
+    """An instruction that declares a name as a global variable."""
+    def get_children(self):
+        """Gets this instruction's sequence of child instructions."""
+        return []
+
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NO_RESULT_TYPE
+
+    def has_result_temporary(self):
+        """Tells if this instruction stores its result in a temporary."""
+        return False
+
+    def create(self, new_children):
+        """Creates a new instruction of this type from the given sequence of child instructions."""
+        return self
+
+    def generate_python_def(self, code_generator):
+        """Generates a Python statement that executes this instruction.
+           The statement is appended immediately to the code generator."""
+        code_generator.append_line('global %s' % self.name)
+
+    def __repr__(self):
+        return 'DeclareGlobalInstruction(%r)' % self.name
+
 class LoadIndexInstruction(Instruction):
 class LoadIndexInstruction(Instruction):
     """An instruction that produces a value by indexing a specified expression with
     """An instruction that produces a value by indexing a specified expression with
        a given key."""
        a given key."""
@@ -1602,7 +1654,9 @@ def create_block(*statements):
             statements[0],
             statements[0],
             create_block(*statements[1:]))
             create_block(*statements[1:]))
 
 
-def create_jit_call(target, named_arguments, kwargs):
+def create_jit_call(
+        target, named_arguments, kwargs,
+        create_run_generator=RunGeneratorFunctionInstruction):
     """Creates a call that abides by the JIT's calling convention."""
     """Creates a call that abides by the JIT's calling convention."""
     # A JIT call looks like this:
     # A JIT call looks like this:
     #
     #
@@ -1628,7 +1682,7 @@ def create_jit_call(target, named_arguments, kwargs):
             [kwargs]))
             [kwargs]))
     return CompoundInstruction(
     return CompoundInstruction(
         create_block(*results),
         create_block(*results),
-        RunGeneratorFunctionInstruction(
+        create_run_generator(
             target, arg_dict.create_load(), NODE_RESULT_TYPE))
             target, arg_dict.create_load(), NODE_RESULT_TYPE))
 
 
 def evaluate_and_load(value):
 def evaluate_and_load(value):