12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063 |
- import modelverse_kernel.primitives as primitive_functions
- import modelverse_jit.tree_ir as tree_ir
- import modelverse_jit.runtime as jit_runtime
- import keyword
# JitCompilationFailedException used to be defined in this module; the alias
# keeps the name reachable through this module (presumably for code that still
# refers to it from here).
JitCompilationFailedException = jit_runtime.JitCompilationFailedException

#: The name of the kwargs parameter in jitted functions.
KWARGS_PARAMETER_NAME = "kwargs"

#: The name of the '__call_function' function, in the jitted function scope.
CALL_FUNCTION_NAME = "__call_function"

#: The name of the '__get_input' function, in the jitted function scope.
GET_INPUT_FUNCTION_NAME = "__get_input"

#: The name of the node that is connected to all JIT locals in a given function call.
LOCALS_NODE_NAME = "jit_locals"

#: The name of the edge that connects the LOCALS_NODE_NAME node to a user root.
LOCALS_EDGE_NAME = "jit_locals_edge"
def get_parameter_names(compiled_function):
    """Gets the given compiled function's parameter names.

    For a function (anything with a ``__code__`` object) this is the tuple of
    its positional parameter names. For a type, or any other object whose
    ``__init__`` has a code object, it is the parameter names of ``__init__``
    without the leading ``self``.

    Raises:
        ValueError: if no code object can be found, e.g. for builtins or for
            classes that inherit ``object.__init__``.
    """
    code = getattr(compiled_function, '__code__', None)
    if code is not None:
        return code.co_varnames[:code.co_argcount]

    # Only descend into '__init__' when it is real Python code. Every object has
    # an '__init__' attribute, so recursing unconditionally would chase
    # method-wrappers forever for builtins and classes without their own
    # constructor.
    init_code = getattr(getattr(compiled_function, '__init__', None), '__code__', None)
    if init_code is not None:
        # Skip 'self'.
        return init_code.co_varnames[1:init_code.co_argcount]

    raise ValueError("'compiled_function' must be a function or a type.")
def apply_intrinsic(intrinsic_function, named_args):
    """Apply an intrinsic to a sequence of (name, argument-tree) pairs.

    If the argument names are exactly the intrinsic's parameter names, in
    order, the intrinsic is called with the argument trees directly.
    Otherwise each argument is first stored in an anonymous local. This keeps
    the arguments evaluated in their original order. The intrinsic is then
    applied to loads of those locals.
    """
    expected_names = tuple(get_parameter_names(intrinsic_function))
    if expected_names == tuple(name for name, _ in named_args):
        # Perfect match: no need to spill anything.
        return intrinsic_function(**dict(named_args))

    # Spill every argument first, then create the loads, in argument order.
    spilled = []
    for name, arg in named_args:
        spilled.append((name, tree_ir.StoreLocalInstruction(None, arg)))
    loads = dict((name, store.create_load()) for name, store in spilled)
    stores = [store for _, store in spilled]
    return tree_ir.CompoundInstruction(
        tree_ir.create_block(*stores),
        intrinsic_function(**loads))
def map_and_simplify_generator(function, instruction):
    """Map ``function`` over the IR tree rooted at ``instruction``, simplifying as it goes.

    Children are processed first (recursively). Then ``function`` is applied to
    this node, rebuilt from the processed children, and the result is
    simplified. Because mapping and simplification are interleaved, this is
    at least as powerful as first mapping and then simplifying.

    ``function`` must be a generator that finishes by raising
    primitive_functions.PrimitiveFinished; this generator finishes the same way.
    """
    # Post-order traversal: rewrite every child subtree before this node.
    mapped_children = []
    for child in instruction.get_children():
        mapped_child, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, child)])]
        mapped_children.append(mapped_child)

    rebuilt = instruction.create(mapped_children)
    mapped, = yield [("CALL_ARGS", [function, (rebuilt,)])]
    raise primitive_functions.PrimitiveFinished(mapped.simplify_node())
def expand_constant_read(instruction):
    """Replace a read of a constant node's value by a literal, when possible.

    Only a ReadValueInstruction whose node id is a LiteralInstruction is
    rewritten. Its value is fetched now (an "RV" request) and embedded as a
    literal. Any other instruction is returned unchanged.
    """
    reads_constant_node = (
        isinstance(instruction, tree_ir.ReadValueInstruction)
        and isinstance(instruction.node_id, tree_ir.LiteralInstruction))
    if not reads_constant_node:
        raise primitive_functions.PrimitiveFinished(instruction)

    constant_value, = yield [("RV", [instruction.node_id.literal])]
    raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(constant_value))
def optimize_tree_ir(instruction):
    """Return a generator that optimizes the IR tree rooted at ``instruction``.

    The optimization replaces constant-node reads by literals
    (expand_constant_read) while simplifying the tree bottom-up.
    """
    optimizer = map_and_simplify_generator(expand_constant_read, instruction)
    return optimizer
def print_value(val):
    """Print ``val`` on its own line.

    This is the default log function for ModelverseJit.set_jit_success_log and
    ModelverseJit.set_jit_code_log.
    """
    message = val
    print(message)
class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler."""
    def __init__(self, max_instructions=None, compiled_function_lookup=None):
        """Creates a JIT instance.

        max_instructions is the per-function instruction limit passed on to
        the analysis (None means no limit). compiled_function_lookup, if not
        None, is consulted by lookup_compiled_function for names that are not
        in the JIT's global scope.
        """
        # Entry points that may still be jitted.
        self.todo_entry_points = set()
        # Entry points that must never be jitted (explicitly marked, or their
        # compilation failed).
        self.no_jit_entry_points = set()
        # Maps body ids to the name of their compiled function in jit_globals.
        self.jitted_entry_points = {}
        # Cache of jit_signature results, keyed by body id.
        self.jitted_parameters = {}
        # The global scope in which jitted code is defined and executed.
        self.jit_globals = {
            'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
            CALL_FUNCTION_NAME : jit_runtime.call_function,
            GET_INPUT_FUNCTION_NAME : jit_runtime.get_input
        }
        # Counter used by generate_name to create unique names.
        self.jit_count = 0
        self.max_instructions = max_instructions
        self.compiled_function_lookup = compiled_function_lookup
        # jit_intrinsics is a function name -> intrinsic map.
        self.jit_intrinsics = {}
        # Maps a body id that is being compiled to the set of body ids that are
        # marked non-jittable if that compilation fails.
        self.compilation_dependencies = {}
        self.jit_enabled = True
        self.direct_calls_allowed = True
        self.tracing_enabled = False
        self.input_function_enabled = False
        self.nop_insertion_enabled = True
        self.jit_success_log_function = None
        self.jit_code_log_function = None

    def set_jit_enabled(self, is_enabled=True):
        """Enables or disables the JIT."""
        self.jit_enabled = is_enabled

    def allow_direct_calls(self, is_allowed=True):
        """Allows or disallows direct calls from jitted to jitted code."""
        self.direct_calls_allowed = is_allowed

    def use_input_function(self, is_enabled=True):
        """Configures the JIT to compile 'input' instructions as function calls."""
        self.input_function_enabled = is_enabled

    def enable_tracing(self, is_enabled=True):
        """Enables or disables tracing for jitted code."""
        self.tracing_enabled = is_enabled

    def enable_nop_insertion(self, is_enabled=True):
        """Enables or disables nop insertion for jitted code. The JIT will insert nops at loop
        back-edges. Inserting nops sacrifices performance to keep the jitted code from
        blocking the thread of execution by consuming all resources; nops give the
        Modelverse server an opportunity to interrupt the currently running code."""
        self.nop_insertion_enabled = is_enabled

    def set_jit_success_log(self, log_function=print_value):
        """Configures this JIT instance with a function that prints output to a log.
        Success and failure messages for specific functions are then sent to said log."""
        self.jit_success_log_function = log_function

    def set_jit_code_log(self, log_function=print_value):
        """Configures this JIT instance with a function that prints output to a log.
        Function definitions of jitted functions are then sent to said log."""
        self.jit_code_log_function = log_function

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point."""
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point."""
        return body_id in self.todo_entry_points or \
               body_id in self.no_jit_entry_points or \
               body_id in self.jitted_entry_points

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
        has not been marked as non-jittable. This only returns `True` if the JIT
        is enabled and the function entry point has been marked jittable, or if
        the function has already been compiled."""
        return ((self.jit_enabled and body_id in self.todo_entry_points) or
                self.has_compiled(body_id))

    def has_compiled(self, body_id):
        """Tests if the function belonging to the given body node has been compiled yet."""
        return body_id in self.jitted_entry_points

    def get_compiled_name(self, body_id):
        """Gets the name of the compiled version of the given body node in the JIT
        global state."""
        return self.jitted_entry_points[body_id]

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
        point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        if body_id in self.todo_entry_points:
            self.todo_entry_points.remove(body_id)

    def generate_name(self, infix, suggested_name=None):
        """Generates a new name or picks the suggested name if it is still
        available.

        The suggested name is used only if it is not yet taken in the JIT's
        global scope and is a valid, non-keyword Python identifier. This matters
        because the name is emitted into generated Python source. Otherwise a
        name of the form 'jit_<infix><N>' is generated, skipping names that
        are already taken.
        """
        import re
        if suggested_name is not None \
                and suggested_name not in self.jit_globals \
                and re.match(r'[A-Za-z_][A-Za-z0-9_]*\Z', suggested_name) \
                and not keyword.iskeyword(suggested_name):
            self.jit_count += 1
            return suggested_name

        # A suggested name may already have claimed one of the generated
        # 'jit_<infix><N>' names, so keep counting until a free one is found.
        while True:
            function_name = 'jit_%s%d' % (infix, self.jit_count)
            self.jit_count += 1
            if function_name not in self.jit_globals:
                return function_name

    def generate_function_name(self, suggested_name=None):
        """Generates a new function name or picks the suggested name if it is still
        available."""
        return self.generate_name('func', suggested_name)

    def register_compiled(self, body_id, compiled_function, function_name=None):
        """Registers a compiled entry point with the JIT."""
        # Get the function's name.
        function_name = self.generate_function_name(function_name)
        # Map the body id to the given parameter list.
        self.jitted_entry_points[body_id] = function_name
        self.jit_globals[function_name] = compiled_function
        if body_id in self.todo_entry_points:
            self.todo_entry_points.remove(body_id)

    def import_value(self, value, suggested_name=None):
        """Imports the given value into the JIT's global scope, with the given suggested name.
        The actual name of the value (within the JIT's global scope) is returned."""
        actual_name = self.generate_name('import', suggested_name)
        self.jit_globals[actual_name] = value
        return actual_name

    def lookup_compiled_function(self, name):
        """Looks up a compiled function by name. Returns a matching function,
        or None if no function was found."""
        if name is None:
            return None
        elif name in self.jit_globals:
            return self.jit_globals[name]
        elif self.compiled_function_lookup is not None:
            return self.compiled_function_lookup(name)
        else:
            return None

    def get_intrinsic(self, name):
        """Tries to find an intrinsic version of the function with the
        given name. Returns None if there is no such intrinsic."""
        return self.jit_intrinsics.get(name)

    def register_intrinsic(self, name, intrinsic_function):
        """Registers the given intrisic with the JIT. This will make the JIT replace calls to
        the function with the given entry point by an application of the specified function."""
        self.jit_intrinsics[name] = intrinsic_function

    def register_binary_intrinsic(self, name, operator):
        """Registers an intrinsic with the JIT that represents the given binary operation."""
        self.register_intrinsic(name, lambda a, b: tree_ir.CreateNodeWithValueInstruction(
            tree_ir.BinaryInstruction(
                tree_ir.ReadValueInstruction(a),
                operator,
                tree_ir.ReadValueInstruction(b))))

    def register_unary_intrinsic(self, name, operator):
        """Registers an intrinsic with the JIT that represents the given unary operation."""
        self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
            tree_ir.UnaryInstruction(
                operator,
                tree_ir.ReadValueInstruction(a))))

    def register_cast_intrinsic(self, name, target_type):
        """Registers an intrinsic with the JIT that represents a unary conversion operator."""
        self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
            tree_ir.CallInstruction(
                tree_ir.LoadGlobalInstruction(target_type.__name__),
                [tree_ir.ReadValueInstruction(a)])))

    def jit_signature(self, body_id):
        """Acquires the signature for the given body id node, which consists of the
        parameter variables, parameter name and a flag that tells if the given function
        is mutable. Results are cached in jitted_parameters."""
        if body_id not in self.jitted_parameters:
            signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
            signature_id = signature_id[0]
            param_set_id, is_mutable = yield [
                ("RD", [signature_id, "params"]),
                ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
            if param_set_id is None:
                self.jitted_parameters[body_id] = ([], [], is_mutable)
            else:
                param_name_ids, = yield [("RDK", [param_set_id])]
                param_names = yield [("RV", [n]) for n in param_name_ids]
                param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
                self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)

        raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])

    def jit_compile(self, user_root, body_id, suggested_name=None):
        """Tries to jit the function whose body is the node with the given id.

        This is a generator driven by the kernel's request protocol. It finishes
        by raising PrimitiveFinished with the compiled function, or it raises
        JitCompilationFailedException. If the function is currently being
        compiled (e.g. a recursive call), the finished value is the None
        placeholder that is registered for it until compilation completes.
        """
        # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
        # pylint: disable=I0011,W0122
        if body_id is None:
            raise ValueError('body_id cannot be None')
        elif body_id in self.jitted_entry_points:
            # We have already compiled this function.
            raise primitive_functions.PrimitiveFinished(
                self.jit_globals[self.jitted_entry_points[body_id]])
        elif body_id in self.no_jit_entry_points:
            # We're not allowed to jit this function or have tried and failed before.
            raise JitCompilationFailedException(
                'Cannot jit function %s at %d because it is marked non-jittable.' % (
                    '' if suggested_name is None else "'" + suggested_name + "'",
                    body_id))
        elif not self.jit_enabled:
            # We're not allowed to jit anything.
            raise JitCompilationFailedException(
                'Cannot jit function %s at %d because the JIT has been disabled.' % (
                    '' if suggested_name is None else "'" + suggested_name + "'",
                    body_id))

        # Generate a name for the function we're about to analyze, and pretend that
        # it already exists. (we need to do this for recursive functions)
        function_name = self.generate_function_name(suggested_name)
        self.jitted_entry_points[body_id] = function_name
        self.jit_globals[function_name] = None

        (parameter_ids, parameter_list, is_mutable), = yield [
            ("CALL_ARGS", [self.jit_signature, (body_id,)])]

        param_dict = dict(zip(parameter_ids, parameter_list))
        body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
        dependencies = set([body_id])
        self.compilation_dependencies[body_id] = dependencies

        def handle_jit_exception(exception):
            # If analysis fails, then a JitCompilationFailedException will be thrown.
            del self.compilation_dependencies[body_id]
            for dep in dependencies:
                self.mark_no_jit(dep)
                if dep in self.jitted_entry_points:
                    del self.jitted_entry_points[dep]

            # BaseException.message only exists on Python 2; fall back to str()
            # so the failure can still be reported on Python 3.
            exception_message = getattr(exception, 'message', None) or str(exception)
            failure_message = "%s (function '%s' at %d)" % (
                exception_message, function_name, body_id)
            if self.jit_success_log_function is not None:
                self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
            raise JitCompilationFailedException(failure_message)

        # Try to analyze the function's body.
        yield [("TRY", [])]
        yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
        if is_mutable:
            # We can't just JIT mutable functions. That'd be dangerous.
            raise JitCompilationFailedException(
                "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)

        state = AnalysisState(
            self, body_id, user_root, body_param_dict,
            self.max_instructions)
        constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_id,)])]
        yield [("END_TRY", [])]
        del self.compilation_dependencies[body_id]

        # Write a prologue and prepend it to the generated function body.
        prologue_statements = []
        # Create a LOCALS_NODE_NAME node, and connect it to the user root.
        prologue_statements.append(
            tree_ir.create_new_local_node(
                LOCALS_NODE_NAME,
                tree_ir.LoadIndexInstruction(
                    tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
                    tree_ir.LiteralInstruction('user_root')),
                LOCALS_EDGE_NAME))
        for (key, val) in param_dict.items():
            arg_ptr = tree_ir.create_new_local_node(
                body_param_dict[key],
                tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))
            prologue_statements.append(arg_ptr)
            prologue_statements.append(
                tree_ir.CreateDictionaryEdgeInstruction(
                    tree_ir.LoadLocalInstruction(body_param_dict[key]),
                    tree_ir.LiteralInstruction('value'),
                    tree_ir.LoadLocalInstruction(val)))

        constructed_body = tree_ir.create_block(
            *(prologue_statements + [constructed_body]))

        # Optimize the function's body.
        constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]

        # Shield temporaries from the GC.
        constructed_body = tree_ir.protect_temporaries_from_gc(
            constructed_body, tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))

        # Wrap the IR in a function definition, give it a unique name.
        constructed_function = tree_ir.DefineFunctionInstruction(
            function_name,
            parameter_list + ['**' + KWARGS_PARAMETER_NAME],
            constructed_body)
        # Convert the function definition to Python code, and compile it.
        exec(str(constructed_function), self.jit_globals)
        # Extract the compiled function from the JIT global state.
        compiled_function = self.jit_globals[function_name]

        if self.jit_success_log_function is not None:
            self.jit_success_log_function(
                "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))

        if self.jit_code_log_function is not None:
            self.jit_code_log_function(constructed_function)

        raise primitive_functions.PrimitiveFinished(compiled_function)
- class AnalysisState(object):
- """The state of a bytecode analysis call graph."""
def __init__(self, jit, body_id, user_root, local_mapping, max_instructions=None):
    """Set up the analysis state for a single function body.

    Args:
        jit: the JIT driving the compilation. The compiled function's name is
            read from its jitted_entry_points for body_id.
        body_id: the id of the function body node being analyzed.
        user_root: the user root, stored unchanged.
        local_mapping: maps variable node ids to Python local names. It is
            stored as-is (not copied), so get_local_name extends the
            caller's dict.
        max_instructions: the instruction limit checked by analyze(), or
            None for no limit.
    """
    # Inputs, kept as given.
    self.jit = jit
    self.body_id = body_id
    self.user_root = user_root
    self.local_mapping = local_mapping
    self.max_instructions = max_instructions
    self.function_name = jit.jitted_entry_points[body_id]

    # Bookkeeping that grows while the body is analyzed.
    self.analyzed_instructions = set()
    self.function_vars = set()
    self.local_vars = set()
    self.enclosing_loop_instruction = None
def get_local_name(self, local_id):
    """Return the Python local name for the variable node ``local_id``.

    Ids without a mapping get the name 'local<id>', which is remembered in
    local_mapping for later lookups.
    """
    try:
        return self.local_mapping[local_id]
    except KeyError:
        fresh_name = 'local%d' % local_id
        self.local_mapping[local_id] = fresh_name
        return fresh_name
def register_local_var(self, local_id):
    """Record that the variable node ``local_id`` is used as a local.

    Raises JitCompilationFailedException if the same variable was already
    registered as a function variable.
    """
    if local_id not in self.function_vars:
        self.local_vars.add(local_id)
        return
    raise JitCompilationFailedException(
        "Local is used as target of function call.")
def register_function_var(self, local_id):
    """Record that the variable node ``local_id`` is used as a function.

    Raises JitCompilationFailedException if the same variable was already
    registered as a local variable.
    """
    if local_id not in self.local_vars:
        self.function_vars.add(local_id)
        return
    raise JitCompilationFailedException(
        "Local is used as target of function call.")
def retrieve_user_root(self):
    """Build an instruction that stores kwargs['user_root'] in the local 'user_root'."""
    user_root_value = tree_ir.LoadIndexInstruction(
        tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
        tree_ir.LiteralInstruction('user_root'))
    return tree_ir.StoreLocalInstruction('user_root', user_root_value)
def load_kernel(self):
    """Build an instruction that evaluates to kwargs['mvk'], the Modelverse kernel."""
    kwargs_local = tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)
    kernel_key = tree_ir.LiteralInstruction('mvk')
    return tree_ir.LoadIndexInstruction(kwargs_local, kernel_key)
def analyze(self, instruction_id):
    """Tries to build an intermediate representation from the instruction with the
    given id, followed by every instruction reachable through its 'next' chain.

    Raises JitCompilationFailedException if the instruction was analyzed
    before (the instruction graph is not a tree), if analyzing it would take
    the number of analyzed instructions past max_instructions, or if there is
    no analyzer for its type.
    """
    # Check the analyzed_instructions set for instruction_id to avoid
    # infinite loops.
    if instruction_id in self.analyzed_instructions:
        raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
    elif (self.max_instructions is not None and
          len(self.analyzed_instructions) >= self.max_instructions):
        # Analyzing this instruction would push the count past the limit;
        # '>' here used to allow max_instructions + 1 instructions.
        raise JitCompilationFailedException('Maximum number of instructions exceeded.')

    self.analyzed_instructions.add(instruction_id)
    instruction_val, = yield [("RV", [instruction_id])]
    instruction_val = instruction_val["value"]

    if instruction_val not in self.instruction_analyzers:
        raise JitCompilationFailedException(
            "Unknown instruction type: '%s'" % (instruction_val))

    # Read the tracing flag once, so 'debug_info' is always bound whenever it
    # is used below.
    tracing_enabled = self.jit.tracing_enabled
    if tracing_enabled:
        # Retrieve the debug information for this instruction, if any.
        debug_info, = yield [("RD", [instruction_id, "__debug"])]
        if debug_info is not None:
            debug_info, = yield [("RV", [debug_info])]

    # Analyze the instruction itself.
    outer_result, = yield [
        ("CALL_ARGS", [self.instruction_analyzers[instruction_val], (self, instruction_id)])]
    if tracing_enabled:
        outer_result = tree_ir.with_debug_info_trace(outer_result, debug_info, self.function_name)

    # Check if the instruction has a 'next' instruction.
    next_instr, = yield [("RD", [instruction_id, "next"])]
    if next_instr is None:
        raise primitive_functions.PrimitiveFinished(outer_result)

    next_result, = yield [("CALL_ARGS", [self.analyze, (next_instr,)])]
    raise primitive_functions.PrimitiveFinished(
        tree_ir.CompoundInstruction(
            outer_result,
            next_result))
def analyze_all(self, instruction_ids):
    """Analyze each instruction id in order and finish with the list of IR trees."""
    trees = []
    for instruction_id in instruction_ids:
        tree, = yield [("CALL_ARGS", [self.analyze, (instruction_id,)])]
        trees.append(tree)
    raise primitive_functions.PrimitiveFinished(trees)
def analyze_return(self, instruction_id):
    """Analyze a 'return' instruction.

    The returned expression is the analyzed 'value' (or an empty instruction
    when there is none), sequenced with the deletion of the LOCALS_EDGE_NAME
    edge that connects the locals node to the user root.
    """
    retval_id, = yield [("RD", [instruction_id, 'value'])]
    if retval_id is None:
        return_value = tree_ir.EmptyInstruction()
    else:
        return_value, = yield [("CALL_ARGS", [self.analyze, (retval_id,)])]

    detach_locals = tree_ir.DeleteEdgeInstruction(
        tree_ir.LoadLocalInstruction(LOCALS_EDGE_NAME))
    raise primitive_functions.PrimitiveFinished(
        tree_ir.ReturnInstruction(
            tree_ir.CompoundInstruction(return_value, detach_locals)))
def analyze_if(self, instruction_id):
    """Analyze an 'if' instruction; a missing 'else' becomes an empty instruction.

    The condition's value is read (ReadValueInstruction) to select the branch.
    """
    cond_id, then_id, else_id = yield [
        ("RD", [instruction_id, "cond"]),
        ("RD", [instruction_id, "then"]),
        ("RD", [instruction_id, "else"])]

    has_else = else_id is not None
    branch_ids = [cond_id, then_id, else_id] if has_else else [cond_id, then_id]
    analyzed, = yield [("CALL_ARGS", [self.analyze_all, (branch_ids,)])]

    if has_else:
        cond_tree, then_tree, else_tree = analyzed
    else:
        cond_tree, then_tree = analyzed
        else_tree = tree_ir.EmptyInstruction()

    raise primitive_functions.PrimitiveFinished(
        tree_ir.SelectInstruction(
            tree_ir.ReadValueInstruction(cond_tree),
            then_tree,
            else_tree))
def analyze_while(self, instruction_id):
    """Tries to analyze the given 'while' instruction.

    The result is a loop whose body first checks the condition, breaking out
    when its value is false, and then runs the loop body. When nop insertion
    is enabled, a nop is added at the end of every iteration.
    """
    cond, body = yield [
        ("RD", [instruction_id, "cond"]),
        ("RD", [instruction_id, "body"])]
    # Analyze the condition.
    cond_r, = yield [("CALL_ARGS", [self.analyze, (cond,)])]
    # Store the old enclosing loop on the stack, and make this loop the
    # new enclosing loop.
    old_loop_instruction = self.enclosing_loop_instruction
    self.enclosing_loop_instruction = instruction_id
    body_r, = yield [("CALL_ARGS", [self.analyze, (body,)])]
    # Restore the old enclosing loop.
    self.enclosing_loop_instruction = old_loop_instruction

    if self.jit.nop_insertion_enabled:
        def create_loop_body(check, loop_body):
            # The nop at the back-edge gives the server a chance to interrupt
            # long-running loops. Use the 'loop_body' argument rather than
            # capturing body_r from the enclosing scope.
            return tree_ir.create_block(
                check,
                loop_body,
                tree_ir.NopInstruction())
    else:
        create_loop_body = tree_ir.CompoundInstruction

    # Break out of the loop as soon as the condition's value is false.
    condition_check = tree_ir.SelectInstruction(
        tree_ir.ReadValueInstruction(cond_r),
        tree_ir.EmptyInstruction(),
        tree_ir.BreakInstruction())
    raise primitive_functions.PrimitiveFinished(
        tree_ir.LoopInstruction(
            create_loop_body(condition_check, body_r)))
def analyze_constant(self, instruction_id):
    """Analyze a 'constant' instruction: the literal is the id of its 'node'."""
    constant_node, = yield [("RD", [instruction_id, "node"])]
    literal = tree_ir.LiteralInstruction(constant_node)
    raise primitive_functions.PrimitiveFinished(literal)
def analyze_output(self, instruction_id):
    """Analyze an 'output' instruction.

    The generated code appends the value to the output chain that hangs off
    user_root's 'last_output' edge. It is equivalent to:

        value = <some tree>
        last_output, last_output_link, new_last_output = (
            yield [("RD", [user_root, "last_output"]),
                   ("RDE", [user_root, "last_output"]),
                   ("CN", [])])
        _, _, _, _ = (
            yield [("CD", [last_output, "value", value]),
                   ("CD", [last_output, "next", new_last_output]),
                   ("CD", [user_root, "last_output", new_last_output]),
                   ("DE", [last_output_link])])
        yield None
    """
    value_id, = yield [("RD", [instruction_id, "value"])]
    value_tree, = yield [("CALL_ARGS", [self.analyze, (value_id,)])]

    store_value = tree_ir.StoreLocalInstruction('value', value_tree)
    store_root = self.retrieve_user_root()
    store_tail = tree_ir.StoreLocalInstruction(
        'last_output',
        tree_ir.ReadDictionaryValueInstruction(
            store_root.create_load(),
            tree_ir.LiteralInstruction('last_output')))
    store_tail_link = tree_ir.StoreLocalInstruction(
        'last_output_link',
        tree_ir.ReadDictionaryEdgeInstruction(
            store_root.create_load(),
            tree_ir.LiteralInstruction('last_output')))
    store_new_tail = tree_ir.StoreLocalInstruction(
        'new_last_output',
        tree_ir.CreateNodeInstruction())

    # Put the value in the current tail, chain a fresh node after it, and make
    # the fresh node the new tail. Then drop the old 'last_output' edge.
    fill_tail = tree_ir.CreateDictionaryEdgeInstruction(
        store_tail.create_load(),
        tree_ir.LiteralInstruction('value'),
        store_value.create_load())
    chain_new_tail = tree_ir.CreateDictionaryEdgeInstruction(
        store_tail.create_load(),
        tree_ir.LiteralInstruction('next'),
        store_new_tail.create_load())
    move_tail_pointer = tree_ir.CreateDictionaryEdgeInstruction(
        store_root.create_load(),
        tree_ir.LiteralInstruction('last_output'),
        store_new_tail.create_load())
    drop_old_link = tree_ir.DeleteEdgeInstruction(store_tail_link.create_load())

    raise primitive_functions.PrimitiveFinished(
        tree_ir.create_block(
            store_value,
            store_root,
            store_tail,
            store_tail_link,
            store_new_tail,
            fill_tail,
            chain_new_tail,
            move_tail_pointer,
            drop_old_link,
            tree_ir.NopInstruction()))
def analyze_input(self, _):
    """Analyze an 'input' instruction.

    If the JIT's input function is enabled, this compiles to a call to the
    GET_INPUT_FUNCTION_NAME global. Otherwise the generated code is
    equivalent to:

        value = None
        while True:
            _input = yield [("RD", [user_root, "input"])]
            value = yield [("RD", [_input, "value"])]

            if value is None:
                kwargs['mvk'].success = False # to avoid blocking
                yield None # nop/interrupt
            else:
                break

        _next = yield [("RD", [_input, "next"])]
        yield [("CD", [user_root, "input", _next])]
        yield [("CE", [jit_locals, value])]
        yield [("DN", [_input])]

    and the resulting instruction evaluates to the value that was read.
    """
    if self.jit.input_function_enabled:
        raise primitive_functions.PrimitiveFinished(
            tree_ir.create_jit_call(
                tree_ir.LoadGlobalInstruction(GET_INPUT_FUNCTION_NAME),
                [],
                tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))

    # NOTE: the anonymous locals below are created, stored and loaded in a
    # fixed order; keep it, in case tree_ir names temporaries in creation order.
    store_root = self.retrieve_user_root()
    store_input = tree_ir.StoreLocalInstruction(
        None,
        tree_ir.ReadDictionaryValueInstruction(
            store_root.create_load(),
            tree_ir.LiteralInstruction('input')))
    store_value = tree_ir.StoreLocalInstruction(
        None,
        tree_ir.ReadDictionaryValueInstruction(
            store_input.create_load(),
            tree_ir.LiteralInstruction('value')))

    clear_value = store_value.create_store(tree_ir.LiteralInstruction(None))

    # Poll until the input carries a value. While it does not, clear the
    # kernel's 'success' flag and emit a nop.
    no_value_yet = tree_ir.BinaryInstruction(
        store_value.create_load(),
        'is',
        tree_ir.LiteralInstruction(None))
    back_off = tree_ir.create_block(
        tree_ir.StoreMemberInstruction(
            self.load_kernel(),
            'success',
            tree_ir.LiteralInstruction(False)),
        tree_ir.NopInstruction())
    poll_loop = tree_ir.LoopInstruction(
        tree_ir.create_block(
            store_input,
            store_value,
            tree_ir.SelectInstruction(
                no_value_yet,
                back_off,
                tree_ir.BreakInstruction())))

    # Consume the input: advance user_root's 'input' edge to the next entry,
    # connect the value to the locals node, and delete the consumed node.
    advance_input = tree_ir.CreateDictionaryEdgeInstruction(
        store_root.create_load(),
        tree_ir.LiteralInstruction('input'),
        tree_ir.ReadDictionaryValueInstruction(
            store_input.create_load(),
            tree_ir.LiteralInstruction('next')))
    attach_value = tree_ir.CreateEdgeInstruction(
        tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME),
        store_value.create_load())
    delete_input = tree_ir.DeleteNodeInstruction(store_input.create_load())

    raise primitive_functions.PrimitiveFinished(
        tree_ir.CompoundInstruction(
            tree_ir.create_block(
                store_root,
                clear_value,
                poll_loop,
                advance_input,
                attach_value,
                delete_input),
            store_value.create_load()))
def analyze_resolve(self, instruction_id):
    """Analyze a 'resolve' instruction, which reads a variable.

    A variable node without a name is always a local. For a named variable,
    the generated code is roughly:

        if 'local_var' in locals():
            tmp = local_var
        else:
            _globals, = yield [("RD", [user_root, "globals"])]
            global_var, = yield [("RD", [_globals, var_name])]

            if global_var is None:
                raise Exception("Not found as global: %s" % (var_name))

            tmp = global_var
    """
    var_id, = yield [("RD", [instruction_id, "var"])]
    var_name, = yield [("RV", [var_id])]
    local_name = self.get_local_name(var_id)

    if var_name is None:
        raise primitive_functions.PrimitiveFinished(
            tree_ir.LoadLocalInstruction(local_name))

    store_root = self.retrieve_user_root()
    globals_node = tree_ir.ReadDictionaryValueInstruction(
        store_root.create_load(),
        tree_ir.LiteralInstruction('globals'))
    store_global = tree_ir.StoreLocalInstruction(
        'global_var',
        tree_ir.ReadDictionaryValueInstruction(
            globals_node,
            tree_ir.LiteralInstruction(var_name)))

    # Raise at run time if the global does not exist.
    global_missing = tree_ir.BinaryInstruction(
        store_global.create_load(),
        'is',
        tree_ir.LiteralInstruction(None))
    raise_not_found = tree_ir.RaiseInstruction(
        tree_ir.CallInstruction(
            tree_ir.LoadGlobalInstruction('Exception'),
            [tree_ir.LiteralInstruction(
                "Not found as global: %s" % var_name)
            ]))
    check_found = tree_ir.SelectInstruction(
        global_missing,
        raise_not_found,
        tree_ir.EmptyInstruction())

    local_exists = tree_ir.LocalExistsInstruction(local_name)
    load_local = tree_ir.LoadLocalInstruction(local_name)
    load_global = tree_ir.CompoundInstruction(
        tree_ir.create_block(
            store_root,
            store_global,
            check_found),
        store_global.create_load())
    raise primitive_functions.PrimitiveFinished(
        tree_ir.SelectInstruction(local_exists, load_local, load_global))
def analyze_declare(self, instruction_id):
    """Analyzes a 'declare' instruction, which introduces a local variable.

    The variable is registered as a local of the function being compiled.
    The generated code creates its node only if no local with that name
    exists yet at run time:

        if '<local_name>' not in locals():
            <local_name>, = yield [("CN", [])]
            yield [("CE", [LOCALS_NODE_NAME, <local_name>])]
    """
    var_id, = yield [("RD", [instruction_id, "var"])]
    self.register_local_var(var_id)
    local_name = self.get_local_name(var_id)

    already_declared = tree_ir.LocalExistsInstruction(local_name)
    keep_existing = tree_ir.EmptyInstruction()
    create_local = tree_ir.create_new_local_node(
        local_name,
        tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))
    raise primitive_functions.PrimitiveFinished(
        tree_ir.SelectInstruction(already_declared, keep_existing, create_local))
def analyze_global(self, instruction_id):
    """Analyzes a 'global' instruction, which declares a global variable.

    The generated code looks the name up in the user's 'globals'
    dictionary. If there is no entry yet, it creates a new node and stores
    it under that name. The whole expression evaluates to the global's node:

        _globals, = yield [("RD", [user_root, "globals"])]
        global_var = yield [("RD", [_globals, var_name])]
        if global_var is None:
            global_var, = yield [("CN", [])]
            yield [("CD", [_globals, var_name, global_var])]
        tmp = global_var
    """
    var_id, = yield [("RD", [instruction_id, "var"])]
    var_name, = yield [("RV", [var_id])]

    user_root = self.retrieve_user_root()
    globals_store = tree_ir.StoreLocalInstruction(
        '_globals',
        tree_ir.ReadDictionaryValueInstruction(
            user_root.create_load(), tree_ir.LiteralInstruction('globals')))
    var_store = tree_ir.StoreLocalInstruction(
        'global_var',
        tree_ir.ReadDictionaryValueInstruction(
            globals_store.create_load(), tree_ir.LiteralInstruction(var_name)))

    # Create and register the global only when the lookup found nothing.
    is_undefined = tree_ir.BinaryInstruction(
        var_store.create_load(), 'is', tree_ir.LiteralInstruction(None))
    define_global = tree_ir.create_block(
        var_store.create_store(tree_ir.CreateNodeInstruction()),
        tree_ir.CreateDictionaryEdgeInstruction(
            globals_store.create_load(),
            tree_ir.LiteralInstruction(var_name),
            var_store.create_load()))
    ensure_defined = tree_ir.SelectInstruction(
        is_undefined, define_global, tree_ir.EmptyInstruction())

    setup = tree_ir.create_block(
        user_root, globals_store, var_store, ensure_defined)
    raise primitive_functions.PrimitiveFinished(
        tree_ir.CompoundInstruction(setup, var_store.create_load()))
def analyze_assign(self, instruction_id):
    """Analyzes an 'assign' instruction, which stores a value in a variable.

    The generated code swaps the variable's 'value' dictionary edge. It
    reads the old edge first, creates the new edge, and then deletes the
    old one:

        value_link = yield [("RDE", [variable, "value"])]
        _, _ = yield [("CD", [variable, "value", value]),
                      ("DE", [value_link])]
    """
    var_id, value_id = yield [("RD", [instruction_id, "var"]),
                              ("RD", [instruction_id, "value"])]
    (var_tree, value_tree), = yield [
        ("CALL_ARGS", [self.analyze_all, ([var_id, value_id],)])]

    # Stash both operands in (automatically named) locals so each is
    # evaluated exactly once.
    variable = tree_ir.StoreLocalInstruction(None, var_tree)
    value = tree_ir.StoreLocalInstruction(None, value_tree)
    old_link = tree_ir.StoreLocalInstruction(
        'value_link',
        tree_ir.ReadDictionaryEdgeInstruction(
            variable.create_load(), tree_ir.LiteralInstruction('value')))
    new_link = tree_ir.CreateDictionaryEdgeInstruction(
        variable.create_load(),
        tree_ir.LiteralInstruction('value'),
        value.create_load())
    drop_old_link = tree_ir.DeleteEdgeInstruction(old_link.create_load())

    raise primitive_functions.PrimitiveFinished(
        tree_ir.create_block(variable, value, old_link, new_link, drop_old_link))
def analyze_access(self, instruction_id):
    """Analyzes an 'access' instruction, which reads a variable's value.

    Evaluates to whatever the variable's 'value' dictionary key points to,
    i.e. the equivalent of:

        value, = yield [("RD", [variable, "value"])]
    """
    var_id, = yield [("RD", [instruction_id, "var"])]
    variable_tree, = yield [("CALL_ARGS", [self.analyze, (var_id,)])]
    value_key = tree_ir.LiteralInstruction('value')
    raise primitive_functions.PrimitiveFinished(
        tree_ir.ReadDictionaryValueInstruction(variable_tree, value_key))
def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
    """Analyzes a call whose callee is known at compile time.

    Arguments:
        callee_id: the function node being called.
        callee_name: the callee's global name, or None when the callee was
            given as a constant node.
        first_parameter_id: the first node of the argument list.

    Records that the function being compiled depends on the callee's body.
    Then it either applies an intrinsic, or makes sure the callee is
    compiled (compiling it now, or registering an already-known compiled
    function) and emits a JIT call to it.
    """
    self.register_function_var(callee_id)
    body_id, = yield [("RD", [callee_id, jit_runtime.FUNCTION_BODY_KEY])]

    # Record that this function depends on the callee's body.
    dependents = self.jit.compilation_dependencies
    if body_id in dependents:
        dependents[body_id].add(self.body_id)

    intrinsic = self.jit.get_intrinsic(callee_name)
    if intrinsic is None:
        known_func = self.jit.lookup_compiled_function(callee_name)
        if known_func is not None:
            self.jit.register_compiled(body_id, known_func, callee_name)
        else:
            # Not compiled yet: compile the callee now.
            yield [("CALL_ARGS", [self.jit.jit_compile, (self.user_root, body_id, callee_name)])]

        compiled_func_name = self.jit.get_compiled_name(body_id)

        # Corner case: for 'call(constant(9), ...)' callee_name is None, since
        # 'constant(9)' carries no name. The compiled name of the body may
        # still map to an intrinsic, and an intrinsic is preferred over a call.
        intrinsic = self.jit.get_intrinsic(compiled_func_name)

    named_args, = yield [("CALL_ARGS", [self.analyze_arguments, (first_parameter_id,)])]

    if intrinsic is not None:
        call_tree = apply_intrinsic(intrinsic, named_args)
    else:
        call_tree = tree_ir.create_jit_call(
            tree_ir.LoadGlobalInstruction(compiled_func_name),
            named_args,
            tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME))
    raise primitive_functions.PrimitiveFinished(call_tree)
def analyze_arguments(self, first_argument_id):
    """Analyzes an argument list that starts at `first_argument_id`.

    Follows the 'next_param' chain and finishes with a list of
    (parameter name, analyzed argument tree) pairs in list order. The list
    is empty when `first_argument_id` is None.
    """
    pairs = []
    param_node = first_argument_id
    while param_node is not None:
        name_id, = yield [("RD", [param_node, "name"])]
        name, = yield [("RV", [name_id])]
        value_id, = yield [("RD", [param_node, "value"])]
        value_tree, = yield [("CALL_ARGS", [self.analyze, (value_id,)])]
        pairs.append((name, value_tree))
        param_node, = yield [("RD", [param_node, "next_param"])]
    raise primitive_functions.PrimitiveFinished(pairs)
def analyze_indirect_call(self, func_id, first_arg_id):
    """Analyzes a call whose callee is only known at run time.

    The generated code passes the callee and a dictionary of the named
    arguments to the global named by CALL_FUNCTION_NAME, roughly:

        __call_function(function_id, {name: value, ...}, **kwargs)
    """
    callee_tree, = yield [("CALL_ARGS", [self.analyze, (func_id,)])]
    named_args, = yield [("CALL_ARGS", [self.analyze_arguments, (first_arg_id,)])]

    arg_entries = []
    for arg_name, arg_tree in named_args:
        arg_entries.append((tree_ir.LiteralInstruction(arg_name), arg_tree))
    args_dict = tree_ir.DictionaryLiteralInstruction(arg_entries)

    call_args = [('function_id', callee_tree), ('named_arguments', args_dict)]
    raise primitive_functions.PrimitiveFinished(
        tree_ir.create_jit_call(
            tree_ir.LoadGlobalInstruction(CALL_FUNCTION_NAME),
            call_args,
            tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
def try_analyze_direct_call(self, func_id, first_param_id):
    """Tries to analyze the given 'call' instruction as a direct call.

    Two callee shapes qualify:
      * 'access(resolve(var))', where `var` is named, exists in the user's
        'globals' dictionary, and currently has a 'value'; and
      * 'constant(node)', where `node` is used as the callee directly.

    Raises:
        JitCompilationFailedException: when direct calls are disabled, or
            when the callee has neither of the shapes above. The caller then
            falls back to an indirect call.
    """
    if not self.jit.direct_calls_allowed:
        raise JitCompilationFailedException('Direct calls are not allowed by the JIT.')

    # Figure out what the 'func' instruction's type is.
    func_instruction_op, = yield [("RV", [func_id])]
    op_name = func_instruction_op['value']
    if op_name == 'access':
        # 'access(resolve(var))' instructions are translated to direct calls.
        access_value_id, = yield [("RD", [func_id, "var"])]
        access_value_op, = yield [("RV", [access_value_id])]
        if access_value_op['value'] == 'resolve':
            resolved_var_id, = yield [("RD", [access_value_id, "var"])]
            resolved_var_name, = yield [("RV", [resolved_var_id])]

            # Unnamed variables are locals, whose value is unknown at compile
            # time. Skip the globals lookup for them instead of querying the
            # globals dictionary with a None key.
            if resolved_var_name is not None:
                _globals, = yield [("RD", [self.user_root, "globals"])]
                global_var, = yield [("RD", [_globals, resolved_var_name])]
                # An undeclared global has no node; don't read 'value' from None.
                if global_var is not None:
                    global_val, = yield [("RD", [global_var, "value"])]
                    if global_val is not None:
                        result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
                            global_val, resolved_var_name, first_param_id)])]
                        raise primitive_functions.PrimitiveFinished(result)
    elif op_name == 'constant':
        # 'const(func_id)' instructions are also translated to direct calls.
        function_val_id, = yield [("RD", [func_id, "node"])]
        result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
            function_val_id, None, first_param_id)])]
        raise primitive_functions.PrimitiveFinished(result)

    raise JitCompilationFailedException(
        "Cannot JIT function calls that target an unknown value as direct calls.")
def analyze_call(self, instruction_id):
    """Analyzes a 'call' instruction.

    A direct call is attempted first. If that attempt raises
    JitCompilationFailedException, the call is compiled as an indirect call
    instead.
    """
    func_id, first_param_id = yield [("RD", [instruction_id, "func"]),
                                     ("RD", [instruction_id, "params"])]

    def fall_back_to_indirect_call(exception):
        """CATCH handler: compiles the call as an indirect call instead."""
        indirect, = yield [("CALL", [self.analyze_indirect_call(func_id, first_param_id)])]
        raise primitive_functions.PrimitiveFinished(indirect)

    # The TRY / CATCH / END_TRY requests bracket the direct-call attempt so
    # that a JitCompilationFailedException is routed to the handler above.
    yield [("TRY", [])]
    yield [("CATCH", [JitCompilationFailedException, fall_back_to_indirect_call])]
    direct, = yield [("CALL_ARGS", [self.try_analyze_direct_call, (func_id, first_param_id)])]
    yield [("END_TRY", [])]
    raise primitive_functions.PrimitiveFinished(direct)
def analyze_break(self, instruction_id):
    """Analyzes a 'break' instruction.

    Only a 'break' that targets the enclosing loop instruction can be
    compiled.

    Raises:
        JitCompilationFailedException: if the 'break' targets any other
            'while' instruction.
    """
    target_loop_id, = yield [("RD", [instruction_id, "while"])]
    if target_loop_id != self.enclosing_loop_instruction:
        raise JitCompilationFailedException(
            "Multilevel 'break' is not supported by the baseline JIT.")
    raise primitive_functions.PrimitiveFinished(tree_ir.BreakInstruction())
def analyze_continue(self, instruction_id):
    """Analyzes a 'continue' instruction.

    Only a 'continue' that targets the enclosing loop instruction can be
    compiled.

    Raises:
        JitCompilationFailedException: if the 'continue' targets any other
            'while' instruction.
    """
    target_loop_id, = yield [("RD", [instruction_id, "while"])]
    if target_loop_id != self.enclosing_loop_instruction:
        raise JitCompilationFailedException(
            "Multilevel 'continue' is not supported by the baseline JIT.")
    raise primitive_functions.PrimitiveFinished(tree_ir.ContinueInstruction())
# Maps each instruction name to the analyzer function that compiles it.
instruction_analyzers = {
    # Control flow.
    'if': analyze_if,
    'while': analyze_while,
    'return': analyze_return,
    # Constants and variables.
    'constant': analyze_constant,
    'resolve': analyze_resolve,
    'declare': analyze_declare,
    'global': analyze_global,
    'assign': analyze_assign,
    'access': analyze_access,
    # I/O.
    'output': analyze_output,
    'input': analyze_input,
    # Calls and loop jumps.
    'call': analyze_call,
    'break': analyze_break,
    'continue': analyze_continue,
}
|