12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049 |
- import math
- import keyword
- import time
- from collections import defaultdict
- import modelverse_kernel.primitives as primitive_functions
- import modelverse_jit.bytecode_parser as bytecode_parser
- import modelverse_jit.bytecode_to_tree as bytecode_to_tree
- import modelverse_jit.bytecode_to_cfg as bytecode_to_cfg
- import modelverse_jit.bytecode_ir as bytecode_ir
- import modelverse_jit.bytecode_interpreter as bytecode_interpreter
- import modelverse_jit.cfg_optimization as cfg_optimization
- import modelverse_jit.cfg_to_tree as cfg_to_tree
- import modelverse_jit.cfg_ir as cfg_ir
- import modelverse_jit.tree_ir as tree_ir
- import modelverse_jit.runtime as jit_runtime
- # Import JitCompilationFailedException because it used to be defined
- # in this module.
- JitCompilationFailedException = jit_runtime.JitCompilationFailedException
def map_and_simplify_generator(function, instruction):
    """Maps `function` over the IR tree rooted at `instruction` and simplifies
    every node right after it has been mapped.

    Children are processed first (recursively), then the root is rebuilt on top
    of the mapped children, mapped itself, and simplified. Because mapping and
    simplification are interleaved, this is at least as powerful as mapping the
    whole tree first and simplifying it afterwards.

    `function` must be a generator function that finishes by raising
    primitive_functions.PrimitiveFinished with the transformed node. This
    generator follows the same protocol: the resulting tree is delivered by
    raising primitive_functions.PrimitiveFinished."""
    # Recursively map-and-simplify each child, one CALL_ARGS request per child.
    mapped_children = []
    for child in instruction.get_children():
        request = ("CALL_ARGS", [map_and_simplify_generator, (function, child)])
        mapped_child, = yield [request]
        mapped_children.append(mapped_child)

    # Rebuild the root on top of the mapped children, then map the root itself.
    rebuilt_root = instruction.create(mapped_children)
    mapped_root, = yield [("CALL_ARGS", [function, (rebuilt_root,)])]

    # Simplify the freshly mapped root before handing it back to the caller.
    raise primitive_functions.PrimitiveFinished(mapped_root.simplify_node())
def expand_constant_read(instruction):
    """Replaces a read of a constant node by a literal holding its value.

    If `instruction` is a ReadValueInstruction whose node id is itself a
    LiteralInstruction, the node's value is fetched with an "RV" request and
    the read is replaced by a LiteralInstruction wrapping that value. Any other
    instruction is passed through unchanged. The result is delivered by raising
    primitive_functions.PrimitiveFinished."""
    is_constant_read = (
        isinstance(instruction, tree_ir.ReadValueInstruction)
        and isinstance(instruction.node_id, tree_ir.LiteralInstruction))
    if not is_constant_read:
        raise primitive_functions.PrimitiveFinished(instruction)

    # The node id is a literal, so the node's value can be read right away.
    constant_value, = yield [("RV", [instruction.node_id.literal])]
    raise primitive_functions.PrimitiveFinished(
        tree_ir.LiteralInstruction(constant_value))
def optimize_tree_ir(instruction):
    """Optimizes the IR tree rooted at `instruction`.

    Constant reads are replaced by literals (see expand_constant_read) and every
    node is simplified along the way. Returns a generator that delivers the
    optimized tree by raising primitive_functions.PrimitiveFinished."""
    optimizer = expand_constant_read
    return map_and_simplify_generator(optimizer, instruction)
def create_bare_function(function_name, parameter_list, function_body):
    """Wraps `function_body` in a tree-IR function definition.

    The definition is named `function_name`. It takes the parameters in
    `parameter_list` plus a trailing catch-all `**<KWARGS_PARAMETER_NAME>`
    parameter. No prolog is added to the body."""
    kwargs_parameter = '**' + jit_runtime.KWARGS_PARAMETER_NAME
    all_parameters = parameter_list + [kwargs_parameter]
    return tree_ir.DefineFunctionInstruction(function_name, all_parameters, function_body)
def create_function(
        function_name, parameter_list, param_dict,
        body_param_dict, function_body, source_map_name=None,
        compatible_temporary_protects=False):
    """Creates a function definition whose body is preceded by a prolog.

    Arguments:
    function_name -- name of the generated function.
    parameter_list -- the function's parameter names.
    param_dict -- variable-to-parameter name map.
    body_param_dict -- variable-to-local name map (same keys as param_dict).
    function_body -- tree-IR body that runs after the prolog.
    source_map_name -- if not None, the global name of a source map that is
        registered as debug info for this function.
    compatible_temporary_protects -- forwarded to
        tree_ir.protect_temporaries_from_gc.

    The prolog optionally registers debug info, creates a LOCALS_NODE_NAME
    node connected to kwargs['task_root'], and creates one local node per
    parameter with a 'value' edge pointing at that parameter. The whole body is
    then shielded from the GC."""
    prolog = []

    # Optionally emit a DEBUG_INFO registration for the source map.
    if source_map_name is not None:
        debug_info = tree_ir.RegisterDebugInfoInstruction(
            tree_ir.LiteralInstruction(function_name),
            tree_ir.LoadGlobalInstruction(source_map_name),
            tree_ir.LiteralInstruction(jit_runtime.BASELINE_JIT_ORIGIN_NAME))
        prolog.append(debug_info)

    # Create a LOCALS_NODE_NAME node, and connect it to the task root.
    task_root = tree_ir.LoadIndexInstruction(
        tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
        tree_ir.LiteralInstruction('task_root'))
    prolog.append(
        tree_ir.create_new_local_node(
            jit_runtime.LOCALS_NODE_NAME, task_root, jit_runtime.LOCALS_EDGE_NAME))

    # Give every parameter its own local node, whose 'value' edge refers to the
    # parameter's value.
    for variable, parameter_name in param_dict.items():
        local_name = body_param_dict[variable]
        prolog.append(
            tree_ir.create_new_local_node(
                local_name,
                tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME)))
        prolog.append(
            tree_ir.CreateDictionaryEdgeInstruction(
                tree_ir.LoadLocalInstruction(local_name),
                tree_ir.LiteralInstruction('value'),
                tree_ir.LoadLocalInstruction(parameter_name)))

    body_with_prolog = tree_ir.create_block(*(prolog + [function_body]))

    # Shield temporaries from the GC.
    protected_body = tree_ir.protect_temporaries_from_gc(
        body_with_prolog,
        tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME),
        compatible_temporary_protects)
    return create_bare_function(function_name, parameter_list, protected_body)
def print_value(val):
    """Writes `val` to standard output, followed by a newline.

    This is the default sink for the JIT's log-configuration methods."""
    print(val)
- class ModelverseJit(object):
- """A high-level interface to the modelverse JIT compiler."""
- def __init__(self, max_instructions=None, compiled_function_lookup=None):
- self.todo_entry_points = set()
- self.no_jit_entry_points = set()
- self.jitted_parameters = {}
- self.jit_globals = {
- 'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
- jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
- jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
- jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant_function,
- jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global,
- jit_runtime.JIT_REJIT_FUNCTION_NAME : self.jit_rejit,
- jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME : compile_function_body_fast,
- jit_runtime.UNREACHABLE_FUNCTION_NAME : jit_runtime.unreachable
- }
- # jitted_entry_points maps body ids to values in jit_globals.
- self.jitted_entry_points = {}
- # global_functions maps global value names to body ids.
- self.global_functions = {}
- # global_functions_inv maps body ids to global value names.
- self.global_functions_inv = {}
- # bytecode_graphs maps body ids to their parsed bytecode graphs.
- self.bytecode_graphs = {}
- # jitted_function_aliases maps body ids to known aliases.
- self.jitted_function_aliases = defaultdict(set)
- self.jit_count = 0
- self.max_instructions = max_instructions
- self.compiled_function_lookup = compiled_function_lookup
- # jit_intrinsics is a function name -> intrinsic map.
- self.jit_intrinsics = {}
- # cfg_jit_intrinsics is a function name -> intrinsic map.
- self.cfg_jit_intrinsics = {}
- self.compilation_dependencies = {}
- self.jit_enabled = True
- self.direct_calls_allowed = True
- self.tracing_enabled = False
- self.source_maps_enabled = True
- self.input_function_enabled = False
- self.nop_insertion_enabled = True
- self.thunks_enabled = True
- self.jit_success_log_function = None
- self.jit_code_log_function = None
- self.jit_timing_log = None
- self.compile_function_body = compile_function_body_baseline
- def set_jit_enabled(self, is_enabled=True):
- """Enables or disables the JIT."""
- self.jit_enabled = is_enabled
- def allow_direct_calls(self, is_allowed=True):
- """Allows or disallows direct calls from jitted to jitted code."""
- self.direct_calls_allowed = is_allowed
- def use_input_function(self, is_enabled=True):
- """Configures the JIT to compile 'input' instructions as function calls."""
- self.input_function_enabled = is_enabled
- def enable_tracing(self, is_enabled=True):
- """Enables or disables tracing for jitted code."""
- self.tracing_enabled = is_enabled
- def enable_source_maps(self, is_enabled=True):
- """Enables or disables the creation of source maps for jitted code. Source maps
- convert lines in the generated code to debug information.
- Source maps are enabled by default."""
- self.source_maps_enabled = is_enabled
- def enable_nop_insertion(self, is_enabled=True):
- """Enables or disables nop insertion for jitted code. If enabled, the JIT will
- insert nops at loop back-edges. Inserting nops sacrifices performance to
- keep the jitted code from blocking the thread of execution and consuming
- all resources; nops give the Modelverse server an opportunity to interrupt
- the currently running code."""
- self.nop_insertion_enabled = is_enabled
- def enable_thunks(self, is_enabled=True):
- """Enables or disables thunks for jitted code. Thunks delay the compilation of
- functions until they are actually used. Thunks generally reduce start-up
- time.
- Thunks are enabled by default."""
- self.thunks_enabled = is_enabled
- def set_jit_success_log(self, log_function=print_value):
- """Configures this JIT instance with a function that prints output to a log.
- Success and failure messages for specific functions are then sent to said log."""
- self.jit_success_log_function = log_function
- def set_jit_code_log(self, log_function=print_value):
- """Configures this JIT instance with a function that prints output to a log.
- Function definitions of jitted functions are then sent to said log."""
- self.jit_code_log_function = log_function
- def set_function_body_compiler(self, compile_function_body):
- """Sets the function that the JIT uses to compile function bodies."""
- self.compile_function_body = compile_function_body
- def set_jit_timing_log(self, log_function=print_value):
- """Configures this JIT instance with a function that prints output to a log.
- The time it takes to compile functions is then sent to this log."""
- self.jit_timing_log = log_function
- def mark_entry_point(self, body_id):
- """Marks the node with the given identifier as a function entry point."""
- if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
- self.todo_entry_points.add(body_id)
- def is_entry_point(self, body_id):
- """Tells if the node with the given identifier is a function entry point."""
- return body_id in self.todo_entry_points or \
- body_id in self.no_jit_entry_points or \
- body_id in self.jitted_entry_points
- def is_jittable_entry_point(self, body_id):
- """Tells if the node with the given identifier is a function entry point that
- has not been marked as non-jittable. This only returns `True` if the JIT
- is enabled and the function entry point has been marked jittable, or if
- the function has already been compiled."""
- return ((self.jit_enabled and body_id in self.todo_entry_points) or
- self.has_compiled(body_id))
- def has_compiled(self, body_id):
- """Tests if the function belonging to the given body node has been compiled yet."""
- return body_id in self.jitted_entry_points
- def get_compiled_name(self, body_id):
- """Gets the name of the compiled version of the given body node in the JIT
- global state."""
- if body_id in self.jitted_entry_points:
- return self.jitted_entry_points[body_id]
- else:
- return None
- def mark_no_jit(self, body_id):
- """Informs the JIT that the node with the given identifier is a function entry
- point that must never be jitted."""
- self.no_jit_entry_points.add(body_id)
- if body_id in self.todo_entry_points:
- self.todo_entry_points.remove(body_id)
- def generate_name(self, infix, suggested_name=None):
- """Generates a new name or picks the suggested name if it is still
- available."""
- if suggested_name is not None \
- and suggested_name not in self.jit_globals \
- and not keyword.iskeyword(suggested_name):
- self.jit_count += 1
- return suggested_name
- else:
- function_name = 'jit_%s%d' % (infix, self.jit_count)
- self.jit_count += 1
- return function_name
- def generate_function_name(self, body_id, suggested_name=None):
- """Generates a new function name or picks the suggested name if it is still
- available."""
- if suggested_name is None:
- suggested_name = self.get_global_name(body_id)
- return self.generate_name('func', suggested_name)
- def register_global(self, body_id, global_name):
- """Associates the given body id with the given global name."""
- self.global_functions[global_name] = body_id
- self.global_functions_inv[body_id] = global_name
- def get_global_name(self, body_id):
- """Gets the name of the global function with the given body id.
- Returns None if no known global exists with the given id."""
- if body_id in self.global_functions_inv:
- return self.global_functions_inv[body_id]
- else:
- return None
- def get_global_body_id(self, global_name):
- """Gets the body id of the global function with the given name.
- Returns None if no known global exists with the given name."""
- if global_name in self.global_functions:
- return self.global_functions[global_name]
- else:
- return None
- def register_compiled(self, body_id, compiled_function, function_name=None):
- """Registers a compiled entry point with the JIT."""
- # Get the function's name.
- actual_function_name = self.generate_function_name(body_id, function_name)
- # Map the body id to the given parameter list.
- self.jitted_entry_points[body_id] = actual_function_name
- self.jit_globals[actual_function_name] = compiled_function
- if function_name is not None:
- self.register_global(body_id, function_name)
- if body_id in self.todo_entry_points:
- self.todo_entry_points.remove(body_id)
- def import_value(self, value, suggested_name=None):
- """Imports the given value into the JIT's global scope, with the given suggested name.
- The actual name of the value (within the JIT's global scope) is returned."""
- actual_name = self.generate_name('import', suggested_name)
- self.jit_globals[actual_name] = value
- return actual_name
- def __lookup_compiled_body_impl(self, body_id):
- """Looks up a compiled function by body id. Returns a matching function,
- or None if no function was found."""
- if body_id is not None and body_id in self.jitted_entry_points:
- return self.jit_globals[self.jitted_entry_points[body_id]]
- else:
- return None
- def __lookup_external_body_impl(self, global_name, body_id):
- """Looks up an external function by global name. Returns a matching function,
- or None if no function was found."""
- if global_name is not None and self.compiled_function_lookup is not None:
- result = self.compiled_function_lookup(global_name)
- if result is not None and body_id is not None:
- self.register_compiled(body_id, result, global_name)
- return result
- else:
- return None
- def lookup_compiled_body(self, body_id):
- """Looks up a compiled function by body id. Returns a matching function,
- or None if no function was found."""
- result = self.__lookup_compiled_body_impl(body_id)
- if result is not None:
- return result
- else:
- global_name = self.get_global_name(body_id)
- return self.__lookup_external_body_impl(global_name, body_id)
- def lookup_compiled_function(self, global_name):
- """Looks up a compiled function by global name. Returns a matching function,
- or None if no function was found."""
- body_id = self.get_global_body_id(global_name)
- result = self.__lookup_compiled_body_impl(body_id)
- if result is not None:
- return result
- else:
- return self.__lookup_external_body_impl(global_name, body_id)
- def get_intrinsic(self, name):
- """Tries to find an intrinsic version of the function with the
- given name."""
- if name in self.jit_intrinsics:
- return self.jit_intrinsics[name]
- else:
- return None
- def get_cfg_intrinsic(self, name):
- """Tries to find an intrinsic version of the function with the
- given name that is specialized for CFGs."""
- if name in self.cfg_jit_intrinsics:
- return self.cfg_jit_intrinsics[name]
- else:
- return None
- def register_intrinsic(self, name, intrinsic_function, cfg_intrinsic_function=None):
- """Registers the given intrisic with the JIT. This will make the JIT replace calls to
- the function with the given entry point by an application of the specified function."""
- self.jit_intrinsics[name] = intrinsic_function
- if cfg_intrinsic_function is not None:
- self.register_cfg_intrinsic(name, cfg_intrinsic_function)
- def register_cfg_intrinsic(self, name, cfg_intrinsic_function):
- """Registers the given intrisic with the JIT. This will make the JIT replace calls to
- the function with the given entry point by an application of the specified function."""
- self.cfg_jit_intrinsics[name] = cfg_intrinsic_function
- def register_binary_intrinsic(self, name, operator):
- """Registers an intrinsic with the JIT that represents the given binary operation."""
- self.register_intrinsic(
- name,
- lambda a, b:
- tree_ir.CreateNodeWithValueInstruction(
- tree_ir.BinaryInstruction(
- tree_ir.ReadValueInstruction(a),
- operator,
- tree_ir.ReadValueInstruction(b))),
- lambda original_def, a, b:
- original_def.redefine(
- cfg_ir.CreateNode(
- original_def.insert_before(
- cfg_ir.Binary(
- original_def.insert_before(cfg_ir.Read(a)),
- operator,
- original_def.insert_before(cfg_ir.Read(b)))))))
- def register_unary_intrinsic(self, name, operator):
- """Registers an intrinsic with the JIT that represents the given unary operation."""
- self.register_intrinsic(
- name,
- lambda a:
- tree_ir.CreateNodeWithValueInstruction(
- tree_ir.UnaryInstruction(
- operator,
- tree_ir.ReadValueInstruction(a))),
- lambda original_def, a:
- original_def.redefine(
- cfg_ir.CreateNode(
- original_def.insert_before(
- cfg_ir.Unary(
- operator,
- original_def.insert_before(cfg_ir.Read(a)))))))
- def register_cast_intrinsic(self, name, target_type):
- """Registers an intrinsic with the JIT that represents a unary conversion operator."""
- self.register_intrinsic(
- name,
- lambda a:
- tree_ir.CreateNodeWithValueInstruction(
- tree_ir.CallInstruction(
- tree_ir.LoadGlobalInstruction(target_type.__name__),
- [tree_ir.ReadValueInstruction(a)])),
- lambda original_def, a:
- original_def.redefine(
- cfg_ir.CreateNode(
- original_def.insert_before(
- cfg_ir.create_pure_simple_call(
- target_type.__name__,
- original_def.insert_before(cfg_ir.Read(a)))))))
- def jit_signature(self, body_id):
- """Acquires the signature for the given body id node, which consists of the
- parameter variables, parameter name and a flag that tells if the given function
- is mutable."""
- if body_id not in self.jitted_parameters:
- signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
- signature_id = signature_id[0]
- param_set_id, is_mutable = yield [
- ("RD", [signature_id, "params"]),
- ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
- if param_set_id is None:
- self.jitted_parameters[body_id] = ([], [], is_mutable)
- else:
- param_name_ids, = yield [("RDK", [param_set_id])]
- param_names = yield [("RV", [n]) for n in param_name_ids]
- param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
- self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)
- raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
- def jit_parse_bytecode(self, body_id):
- """Parses the given function body as a bytecode graph."""
- if body_id in self.bytecode_graphs:
- raise primitive_functions.PrimitiveFinished(self.bytecode_graphs[body_id])
- parser = bytecode_parser.BytecodeParser()
- result, = yield [("CALL_ARGS", [parser.parse_instruction, (body_id,)])]
- self.bytecode_graphs[body_id] = result
- raise primitive_functions.PrimitiveFinished(result)
- def check_jittable(self, body_id, suggested_name=None):
- """Checks if the function with the given body id is obviously non-jittable. If it's
- non-jittable, then a `JitCompilationFailedException` exception is thrown."""
- if body_id is None:
- raise ValueError('body_id cannot be None')
- elif body_id in self.no_jit_entry_points:
- # We're not allowed to jit this function or have tried and failed before.
- raise JitCompilationFailedException(
- 'Cannot jit function %s at %d because it is marked non-jittable.' % (
- '' if suggested_name is None else "'" + suggested_name + "'",
- body_id))
- elif not self.jit_enabled:
- # We're not allowed to jit anything.
- raise JitCompilationFailedException(
- 'Cannot jit function %s at %d because the JIT has been disabled.' % (
- '' if suggested_name is None else "'" + suggested_name + "'",
- body_id))
- def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
- """Replaces the function with the given name by compiling the bytecode at the given
- body id."""
- if self.jit_timing_log is not None:
- start_time = time.time()
- if compile_function_body is None:
- compile_function_body = self.compile_function_body
- self.check_jittable(body_id, function_name)
- # Generate a name for the function we're about to analyze, and pretend that
- # it already exists. (we need to do this for recursive functions)
- self.jitted_entry_points[body_id] = function_name
- self.jit_globals[function_name] = None
- (_, _, is_mutable), = yield [
- ("CALL_ARGS", [self.jit_signature, (body_id,)])]
- dependencies = set([body_id])
- self.compilation_dependencies[body_id] = dependencies
- def handle_jit_exception(exception):
- # If analysis fails, then a JitCompilationFailedException will be thrown.
- del self.compilation_dependencies[body_id]
- for dep in dependencies:
- self.mark_no_jit(dep)
- if dep in self.jitted_entry_points:
- del self.jitted_entry_points[dep]
- failure_message = "%s (function '%s' at %d)" % (
- exception.message, function_name, body_id)
- if self.jit_success_log_function is not None:
- self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
- raise JitCompilationFailedException(failure_message)
- # Try to analyze the function's body.
- yield [("TRY", [])]
- yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
- if is_mutable:
- # We can't just JIT mutable functions. That'd be dangerous.
- raise JitCompilationFailedException(
- "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
- compiled_function, = yield [
- ("CALL_ARGS", [compile_function_body, (self, function_name, body_id, task_root)])]
- yield [("END_TRY", [])]
- del self.compilation_dependencies[body_id]
- if self.jit_success_log_function is not None:
- assert self.jitted_entry_points[body_id] == function_name
- self.jit_success_log_function(
- "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
- if self.jit_timing_log is not None:
- end_time = time.time()
- compile_time = end_time - start_time
- self.jit_timing_log('Compile time for %s:%f' % (function_name, compile_time))
- raise primitive_functions.PrimitiveFinished(compiled_function)
- def get_source_map_name(self, function_name):
- """Gets the name of the given jitted function's source map. None is returned if source maps
- are disabled."""
- if self.source_maps_enabled:
- return function_name + "_source_map"
- else:
- return None
- def get_can_rejit_name(self, function_name):
- """Gets the name of the given jitted function's can-rejit flag."""
- return function_name + "_can_rejit"
- def jit_define_function(self, function_name, function_def):
- """Converts the given tree-IR function definition to Python code, defines it,
- and extracts the resulting function."""
- # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
- # pylint: disable=I0011,W0122
- if self.jit_code_log_function is not None:
- self.jit_code_log_function(function_def)
- # Convert the function definition to Python code, and compile it.
- code_generator = tree_ir.PythonGenerator()
- function_def.generate_python_def(code_generator)
- source_map_name = self.get_source_map_name(function_name)
- if source_map_name is not None:
- self.jit_globals[source_map_name] = code_generator.source_map_builder.source_map
- exec(str(code_generator), self.jit_globals)
- # Extract the compiled function from the JIT global state.
- return self.jit_globals[function_name]
- def jit_delete_function(self, function_name):
- """Deletes the function with the given function name."""
- del self.jit_globals[function_name]
- def jit_compile(self, task_root, body_id, suggested_name=None):
- """Tries to jit the function defined by the given entry point id and parameter list."""
- if body_id is None:
- raise ValueError('body_id cannot be None')
- elif body_id in self.jitted_entry_points:
- raise primitive_functions.PrimitiveFinished(
- self.jit_globals[self.jitted_entry_points[body_id]])
- compiled_func = self.lookup_compiled_body(body_id)
- if compiled_func is not None:
- raise primitive_functions.PrimitiveFinished(compiled_func)
- # Generate a name for the function we're about to analyze, and 're-compile'
- # it for the first time.
- function_name = self.generate_function_name(body_id, suggested_name)
- yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]
- def jit_rejit(self, task_root, body_id, function_name, compile_function_body=None):
- """Re-compiles the given function. If compilation fails, then the can-rejit
- flag is set to false."""
- old_jitted_func = self.jitted_entry_points[body_id]
- def __handle_jit_failed(_):
- self.jit_globals[self.get_can_rejit_name(function_name)] = False
- self.jitted_entry_points[body_id] = old_jitted_func
- self.no_jit_entry_points.remove(body_id)
- raise primitive_functions.PrimitiveFinished(None)
- yield [("TRY", [])]
- yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
- jitted_function, = yield [
- ("CALL_ARGS",
- [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
- yield [("END_TRY", [])]
- # Update all aliases.
- for function_alias in self.jitted_function_aliases[body_id]:
- self.jit_globals[function_alias] = jitted_function
- def jit_thunk(self, get_function_body, global_name=None):
- """Creates a thunk from the given IR tree that computes the function's body id.
- This thunk is a function that will invoke the function whose body id is retrieved.
- The thunk's name in the JIT's global context is returned."""
- # The general idea is to first create a function that looks a bit like this:
- #
- # def jit_get_function_body(**kwargs):
- # raise primitive_functions.PrimitiveFinished(<get_function_body>)
- #
- get_function_body_name = self.generate_name('get_function_body')
- get_function_body_func_def = create_function(
- get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))
- get_function_body_func = self.jit_define_function(
- get_function_body_name, get_function_body_func_def)
- # Next, we want to create a thunk that invokes said function, and then replaces itself.
- thunk_name = self.generate_name('thunk', global_name)
- def __jit_thunk(**kwargs):
- # Compute the body id, and delete the function that computes the body id; we won't
- # be needing it anymore after this call.
- body_id, = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
- self.jit_delete_function(get_function_body_name)
- # Try to associate the global name with the body id, if that's at all possible.
- if global_name is not None:
- self.register_global(body_id, global_name)
- compiled_function = self.lookup_compiled_body(body_id)
- if compiled_function is not None:
- # Replace this thunk by the compiled function.
- self.jit_globals[thunk_name] = compiled_function
- self.jitted_function_aliases[body_id].add(thunk_name)
- else:
- def __handle_jit_exception(_):
- # Replace this thunk by a different thunk: one that calls the interpreter
- # directly, without checking if the function is jittable.
- (_, parameter_names, _), = yield [
- ("CALL_ARGS", [self.jit_signature, (body_id,)])]
- def __interpreter_thunk(**new_kwargs):
- named_arg_dict = {name : new_kwargs[name] for name in parameter_names}
- return jit_runtime.interpret_function_body(
- body_id, named_arg_dict, **new_kwargs)
- self.jit_globals[thunk_name] = __interpreter_thunk
- yield [("TRY", [])]
- yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
- compiled_function, = yield [
- ("CALL_ARGS",
- [self.jit_recompile, (kwargs['task_root'], body_id, thunk_name)])]
- yield [("END_TRY", [])]
- # Call the compiled function.
- yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]
- self.jit_globals[thunk_name] = __jit_thunk
- return thunk_name
- def jit_thunk_constant_body(self, body_id):
- """Creates a thunk from the given body id.
- This thunk is a function that will invoke the function whose body id is given.
- The thunk's name in the JIT's global context is returned."""
- self.lookup_compiled_body(body_id)
- compiled_name = self.get_compiled_name(body_id)
- if compiled_name is not None:
- # We might have compiled the function with the given body id already. In that case,
- # we need not bother with constructing the thunk; we can return the compiled function
- # right away.
- return compiled_name
- else:
- # Looks like we'll just have to build that thunk after all.
- return self.jit_thunk(tree_ir.LiteralInstruction(body_id))
def jit_thunk_constant_function(self, body_id):
    """Return the name of a thunk that invokes the function with the given
    function id.

    Despite its name, `body_id` holds a *function* id here. The thunk is built
    from a tree that reads that function node's
    `jit_runtime.FUNCTION_BODY_KEY` entry to obtain the body id.
    """
    function_id = tree_ir.LiteralInstruction(body_id)
    body_key = tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)
    body_lookup = tree_ir.ReadDictionaryValueInstruction(function_id, body_key)
    return self.jit_thunk(body_lookup)
def jit_thunk_global(self, global_name):
    """Return the name of a callable that invokes the global function
    `global_name`.

    If the global is already associated with a body id whose compiled function
    is known, that compiled function's name is returned. Otherwise a thunk is
    built whose body id is resolved through the task's globals, equivalent to:

        _globals, = yield [("RD", [kwargs['task_root'], "globals"])]
        global_var, = yield [("RD", [_globals, global_name])]
        function_id, = yield [("RD", [global_var, "value"])]
        body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
    """
    known_body_id = self.get_global_body_id(global_name)
    if known_body_id is not None:
        self.lookup_compiled_body(known_body_id)
        compiled_name = self.get_compiled_name(known_body_id)
        if compiled_name is not None:
            # Already compiled: skip building a thunk.
            return compiled_name

    # Start from kwargs['task_root'] and walk down one dictionary key at a time.
    lookup = tree_ir.LoadIndexInstruction(
        tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
        tree_ir.LiteralInstruction('task_root'))
    for key in ('globals', global_name, 'value', jit_runtime.FUNCTION_BODY_KEY):
        lookup = tree_ir.ReadDictionaryValueInstruction(
            lookup, tree_ir.LiteralInstruction(key))
    return self.jit_thunk(lookup, global_name)
def compile_function_body_interpret(jit, function_name, body_id, task_root, header=None):
    """Create a function that runs the body `body_id` through the bytecode
    interpreter.

    `header`, if given, is invoked (via a CALL_KWARGS request) with the call's
    keyword arguments before interpretation. It must produce a
    `(done, result)` pair; when `done` is true, `result` is returned and the
    body is not interpreted.

    The created function is stored in `jit.jit_globals[function_name]` and is
    also the result of this generator (raised as PrimitiveFinished).
    """
    (parameter_ids, parameter_list, _), = yield [
        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
    param_dict = dict(zip(parameter_ids, parameter_list))
    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]

    def __interpret_function(**kwargs):
        if header is not None:
            (done, result), = yield [("CALL_KWARGS", [header, kwargs])]
            if done:
                raise primitive_functions.PrimitiveFinished(result)

        # Declared parameters become interpreter locals keyed by parameter id.
        # Every other keyword argument is forwarded untouched.
        remaining_kwargs = dict(kwargs)
        local_args = {}
        for param_id, param_name in param_dict.items():
            local_args[param_id] = remaining_kwargs.pop(param_name)

        yield [("TAIL_CALL_ARGS",
                [bytecode_interpreter.interpret_bytecode_function,
                 (function_name, body_bytecode, local_args, remaining_kwargs)])]

    jit.jit_globals[function_name] = __interpret_function
    raise primitive_functions.PrimitiveFinished(__interpret_function)
def compile_function_body_baseline(
        jit, function_name, body_id, task_root,
        header=None, compatible_temporary_protects=False):
    """Have the baseline JIT compile the function with the given name and body id.

    The body is analyzed into tree IR. The optional tree IR `header` is
    prepended to it, the result is optimized and wrapped in a function
    definition, and `jit.jit_define_function`'s result is raised as
    PrimitiveFinished.
    """
    (parameter_ids, parameter_list, _), = yield [
        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
    param_dict = dict(zip(parameter_ids, parameter_list))
    # Inside the analyzed body, parameter `p` is referred to as `p_ptr`.
    body_param_dict = {
        param_id: name + "_ptr"
        for param_id, name in zip(parameter_ids, parameter_list)}

    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
    analysis = bytecode_to_tree.AnalysisState(
        jit, body_id, task_root, body_param_dict, jit.max_instructions)
    tree, = yield [("CALL_ARGS", [analysis.analyze, (body_bytecode,)])]
    if header is not None:
        tree = tree_ir.create_block(header, tree)

    tree, = yield [("CALL_ARGS", [optimize_tree_ir, (tree,)])]

    source_map_name = jit.get_source_map_name(function_name)
    function_definition = create_function(
        function_name, parameter_list, param_dict, body_param_dict, tree,
        source_map_name, compatible_temporary_protects)
    # Turn the definition into Python code and compile it.
    raise primitive_functions.PrimitiveFinished(
        jit.jit_define_function(function_name, function_definition))
def compile_function_body_fast(jit, function_name, body_id, _):
    """Have the fast JIT compile the function with the given name and body id.

    The pipeline is: bytecode -> CFG -> CFG optimization -> tree IR -> tree IR
    optimization -> Python function. The fourth positional argument (callers
    pass the task root) is not used.
    """
    (parameter_ids, parameter_list, _unused), = yield [
        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
    param_dict = dict(zip(parameter_ids, parameter_list))
    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]

    analyzer = bytecode_to_cfg.AnalysisState(jit, function_name, param_dict)
    analyzer.analyze(body_bytecode)
    entry_point, = yield [
        ("CALL_ARGS", [cfg_optimization.optimize, (analyzer.entry_point, jit)])]

    log_code = jit.jit_code_log_function
    if log_code is not None:
        reachable_blocks = cfg_ir.get_all_reachable_blocks(entry_point)
        log_code(
            "CFG for function '%s' at '%d':\n%s" % (
                function_name, body_id,
                '\n'.join(str(block) for block in reachable_blocks)))

    tree = cfg_to_tree.lower_flow_graph(entry_point, jit)
    tree, = yield [("CALL_ARGS", [optimize_tree_ir, (tree,)])]
    function_definition = create_bare_function(function_name, parameter_list, tree)
    # Turn the definition into Python code and compile it.
    raise primitive_functions.PrimitiveFinished(
        jit.jit_define_function(function_name, function_definition))
def favor_large_functions(body_bytecode):
    """Temperature heuristic that favors large functions.

    Returns `(initial_temperature, increment)`. The initial temperature equals
    the number of instructions reachable in the body; the increment is 1 per
    call.
    """
    # Rationale: this is damage control. Wrongly deciding not to fast-jit a
    # small function is cheap, because it can be fast-jitted later and it is
    # unlikely to cost much in the meantime. Wrongly skipping a large function
    # is costly: it may run for a long time before we notice it should have
    # been jitted. So large functions start out hot.
    instruction_count = len(body_bytecode.get_reachable())
    return instruction_count, 1
def favor_small_functions(body_bytecode):
    """Temperature heuristic that favors small functions.

    Returns `(initial_temperature, increment)`. The initial temperature is
    ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD minus the number of reachable
    instructions, so smaller bodies start closer to the fast-JIT threshold.
    The increment is 1 per call.
    """
    # Rationale: small functions are unlikely to hit the non-linear complexity
    # of fast-jit's algorithms, so fast-jitting them may pay off more cheaply
    # than fast-jitting large ones.
    threshold = ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD
    instruction_count = len(body_bytecode.get_reachable())
    return threshold - instruction_count, 1
# Weight given to loop-body instruction counts by the loop-favoring temperature
# heuristics: favor_loops applies it linearly, favor_small_loops squares it.
ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER = 4

# Temperature thresholds used by the adaptive JIT. The baseline threshold is
# checked first, then the (higher) fast threshold.
ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD = 100
"""The threshold temperature at which the adaptive JIT will use the baseline JIT."""

ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD = 250
"""The threshold temperature at which the adaptive JIT will use the fast JIT."""
def favor_loops(body_bytecode):
    """Temperature heuristic in which code inside loops makes a function hotter.

    Returns `(initial_temperature, increment)`.
    - The temperature starts at ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD
      minus the number of reachable instructions.
    - Every while-loop then adds ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER per
      instruction reachable in its body. Nested loops are counted again for
      each enclosing loop, by design.
    - The increment is 1 per call.
    """
    def _not_break_or_continue(instruction):
        # Predicate handed to get_reachable for loop bodies; presumably it keeps
        # the traversal from following break/continue -- TODO confirm.
        return not isinstance(
            instruction, (bytecode_ir.BreakInstruction, bytecode_ir.ContinueInstruction))

    reachable_instructions = body_bytecode.get_reachable()
    temperature = ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD - len(reachable_instructions)
    for instruction in reachable_instructions:
        if not isinstance(instruction, bytecode_ir.WhileInstruction):
            continue
        loop_body = instruction.body.get_reachable(_not_break_or_continue)
        temperature += ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER * len(loop_body)
    return temperature, 1
def favor_small_loops(body_bytecode):
    """Temperature heuristic that favors functions with (small) loops.

    Returns `(initial_temperature, increment)`.
    - The temperature starts at 50 degrees below
      ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD, minus the number of reachable
      instructions.
    - Every while-loop adds ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER ** 2 times
      the integer square root of its body's reachable-instruction count, so
      loop size contributes sub-linearly. Nested loops are counted again for
      each enclosing loop, by design.
    - The per-call increment is floor(log2(instruction count)), but at least 1.
    """
    reachable_instructions = body_bytecode.get_reachable()
    instruction_count = len(reachable_instructions)
    temperature = ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD - 50 - instruction_count
    for instruction in reachable_instructions:
        if isinstance(instruction, bytecode_ir.WhileInstruction):
            # The predicate presumably stops traversal at break/continue, so only
            # the loop body proper is counted -- TODO confirm against get_reachable.
            loop_body_instructions = instruction.body.get_reachable(
                lambda x: not isinstance(
                    x, (bytecode_ir.BreakInstruction, bytecode_ir.ContinueInstruction)))
            temperature += (
                (ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER ** 2) *
                int(math.sqrt(len(loop_body_instructions))))
    # floor(log2(n)) computed exactly with integers. The float form
    # math.log(n, 2) raised ValueError for an empty body (n == 0) and depended
    # on floating-point rounding. For n == 0 this yields -1, which the max()
    # clamps to the minimum increment of 1.
    temperature_increment = max(instruction_count.bit_length() - 1, 1)
    return temperature, temperature_increment
class AdaptiveJitState(object):
    """Shared state for adaptive JIT compilation.

    Records the names of two JIT globals (the temperature counter and the
    "can rejit" flag) and the amount the temperature rises on each call.
    """

    def __init__(
            self, temperature_counter_name,
            temperature_increment, can_rejit_name):
        # Name of the JIT global that holds the current temperature.
        self.temperature_counter_name = temperature_counter_name
        # Amount added to the temperature on every call.
        self.temperature_increment = temperature_increment
        # Name of the JIT global whose truthiness gates temperature tracking and re-jitting.
        self.can_rejit_name = can_rejit_name

    def compile_interpreter(
            self, jit, function_name, body_id, task_root):
        """Compile the function for the bytecode interpreter, behind a header
        that maintains the temperature counter.

        The header works as follows on each call, while re-jitting is allowed:
        - The temperature is raised by `temperature_increment`.
        - At or above ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD, the function is
          re-jitted with the fast JIT.
        - Otherwise, at or above ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD, it
          is re-jitted with the baseline JIT.
        - After re-jitting, the call is forwarded to the new function and the
          header reports `(True, result)`.
        - In every other case the header reports `(False, None)` and the
          interpreter runs the body.
        """
        def _bump_temperature(**kwargs):
            if not jit.jit_globals[self.can_rejit_name]:
                raise primitive_functions.PrimitiveFinished((False, None))

            temperature = (
                jit.jit_globals[self.temperature_counter_name] + self.temperature_increment)
            jit.jit_globals[self.temperature_counter_name] = temperature
            if temperature < ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD:
                raise primitive_functions.PrimitiveFinished((False, None))

            if temperature >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
                recompile = compile_function_body_fast
            else:
                recompile = self.compile_baseline
            yield [
                ("CALL_ARGS",
                 [jit.jit_rejit, (task_root, body_id, function_name, recompile)])]
            # Forward this call to whatever is now bound to function_name.
            result, = yield [("CALL_KWARGS", [jit.jit_globals[function_name], kwargs])]
            raise primitive_functions.PrimitiveFinished((True, result))

        yield [
            ("TAIL_CALL_ARGS",
             [compile_function_body_interpret,
              (jit, function_name, body_id, task_root, _bump_temperature)])]

    def compile_baseline(
            self, jit, function_name, body_id, task_root):
        """Compile the function with the baseline JIT, prefixed by a tree IR
        header that maintains the temperature counter.

        The header corresponds to:

            if can_rejit:
                global temperature_counter
                temperature_counter = temperature_counter + temperature_increment
                if temperature_counter >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
                    yield [("CALL_KWARGS", [jit_runtime.JIT_REJIT_FUNCTION_NAME, {...}])]
                    yield [("TAIL_CALL_KWARGS", [function_name, {...}])]
        """
        (_, parameter_list, _), = yield [
            ("CALL_ARGS", [jit.jit_signature, (body_id,)])]

        # The pieces below are built in the same order as one nested expression
        # would build them.
        can_rejit = tree_ir.LoadGlobalInstruction(self.can_rejit_name)
        declare_counter = tree_ir.DeclareGlobalInstruction(self.temperature_counter_name)
        bump_counter = tree_ir.IgnoreInstruction(
            tree_ir.StoreGlobalInstruction(
                self.temperature_counter_name,
                tree_ir.BinaryInstruction(
                    tree_ir.LoadGlobalInstruction(self.temperature_counter_name),
                    '+',
                    tree_ir.LiteralInstruction(self.temperature_increment))))
        is_hot = tree_ir.BinaryInstruction(
            tree_ir.LoadGlobalInstruction(self.temperature_counter_name),
            '>=',
            tree_ir.LiteralInstruction(ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD))

        rejit_function = tree_ir.LoadGlobalInstruction(jit_runtime.JIT_REJIT_FUNCTION_NAME)
        rejit_arguments = tree_ir.DictionaryLiteralInstruction([
            (tree_ir.LiteralInstruction('task_root'),
             bytecode_to_tree.load_task_root()),
            (tree_ir.LiteralInstruction('body_id'),
             tree_ir.LiteralInstruction(body_id)),
            (tree_ir.LiteralInstruction('function_name'),
             tree_ir.LiteralInstruction(function_name)),
            (tree_ir.LiteralInstruction('compile_function_body'),
             tree_ir.LoadGlobalInstruction(
                 jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME))])
        rejit_with_fast_jit = tree_ir.RunGeneratorFunctionInstruction(
            rejit_function, rejit_arguments, result_type=tree_ir.NO_RESULT_TYPE)

        forward_call = bytecode_to_tree.create_return(
            tree_ir.create_jit_call(
                tree_ir.LoadGlobalInstruction(function_name),
                [(name, tree_ir.LoadLocalInstruction(name)) for name in parameter_list],
                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME)))

        header = tree_ir.SelectInstruction(
            can_rejit,
            tree_ir.create_block(
                declare_counter,
                bump_counter,
                tree_ir.SelectInstruction(
                    is_hot,
                    tree_ir.create_block(rejit_with_fast_jit, forward_call),
                    tree_ir.EmptyInstruction())),
            tree_ir.EmptyInstruction())

        # Hand off to the baseline JIT with the header and compatible temporary protects.
        yield [
            ("TAIL_CALL_ARGS",
             [compile_function_body_baseline,
              (jit, function_name, body_id, task_root, header, True)])]
def compile_function_body_adaptive(
        jit, function_name, body_id, task_root,
        temperature_heuristic=favor_loops):
    """Compile the function with the given name and body id, picking an
    execution engine automatically. The function may be recompiled later.

    `temperature_heuristic(body_bytecode)` returns
    `(initial_temperature, temperature_increment)`. The engine is chosen as
    follows:
    - At or above ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD, use the fast JIT.
    - At or above ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD, use the
      baseline JIT behind a temperature-tracking header.
    - Otherwise, use the bytecode interpreter behind a temperature-tracking
      header.
    """
    def _log_success(message_format, *arguments):
        # Only format the message when success logging is enabled.
        log_function = jit.jit_success_log_function
        if log_function is not None:
            log_function(message_format % arguments)

    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
    initial_temperature, temperature_increment = temperature_heuristic(body_bytecode)
    _log_success("Initial temperature for '%s': %d", function_name, initial_temperature)

    if initial_temperature >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
        # Already past the fast-JIT threshold: no temperature tracking needed.
        # (The tail call presumably does not resume this generator -- TODO confirm.)
        _log_success("Compiling '%s' with fast-jit.", function_name)
        yield [
            ("TAIL_CALL_ARGS",
             [compile_function_body_fast, (jit, function_name, body_id, task_root)])]

    temperature_counter_name = jit.import_value(
        initial_temperature, function_name + "_temperature_counter")
    can_rejit_name = jit.get_can_rejit_name(function_name)
    jit.jit_globals[can_rejit_name] = True
    state = AdaptiveJitState(temperature_counter_name, temperature_increment, can_rejit_name)

    if initial_temperature >= ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD:
        _log_success("Compiling '%s' with baseline-jit.", function_name)
        compile_with_state = state.compile_baseline
    else:
        _log_success("Compiling '%s' with bytecode-interpreter.", function_name)
        compile_with_state = state.compile_interpreter
    yield [
        ("TAIL_CALL_ARGS",
         [compile_with_state, (jit, function_name, body_id, task_root)])]
|