|
@@ -1,5 +1,6 @@
|
|
|
import math
|
|
|
import keyword
|
|
|
+import time
|
|
|
from collections import defaultdict
|
|
|
import modelverse_kernel.primitives as primitive_functions
|
|
|
import modelverse_jit.bytecode_parser as bytecode_parser
|
|
@@ -155,6 +156,7 @@ class ModelverseJit(object):
|
|
|
self.thunks_enabled = True
|
|
|
self.jit_success_log_function = None
|
|
|
self.jit_code_log_function = None
|
|
|
+ self.jit_timing_log = None
|
|
|
self.compile_function_body = compile_function_body_baseline
|
|
|
|
|
|
def set_jit_enabled(self, is_enabled=True):
|
|
@@ -210,6 +212,11 @@ class ModelverseJit(object):
|
|
|
"""Sets the function that the JIT uses to compile function bodies."""
|
|
|
self.compile_function_body = compile_function_body
|
|
|
|
|
|
+ def set_jit_timing_log(self, log_function=print_value):
|
|
|
+ """Configures this JIT instance with a function that prints output to a log.
|
|
|
+ The time it takes to compile functions is then sent to this log."""
|
|
|
+ self.jit_timing_log = log_function
|
|
|
+
|
|
|
def mark_entry_point(self, body_id):
|
|
|
"""Marks the node with the given identifier as a function entry point."""
|
|
|
if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
|
|
@@ -481,6 +488,9 @@ class ModelverseJit(object):
|
|
|
def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
|
|
|
"""Replaces the function with the given name by compiling the bytecode at the given
|
|
|
body id."""
|
|
|
+ if self.jit_timing_log is not None:
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
if compile_function_body is None:
|
|
|
compile_function_body = self.compile_function_body
|
|
|
|
|
@@ -530,6 +540,11 @@ class ModelverseJit(object):
|
|
|
self.jit_success_log_function(
|
|
|
"JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
|
|
|
|
|
|
+ if self.jit_timing_log is not None:
|
|
|
+ end_time = time.time()
|
|
|
+ compile_time = end_time - start_time
|
|
|
+ self.jit_timing_log('Compile time for %s:%f' % (function_name, compile_time))
|
|
|
+
|
|
|
raise primitive_functions.PrimitiveFinished(compiled_function)
|
|
|
|
|
|
def get_source_map_name(self, function_name):
|