cfg_data_structures.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """Defines optimizations that replace Modelverse data structures by Python data structures."""
  2. import modelverse_jit.cfg_ir as cfg_ir
  3. SET_DEF_REWRITE_RULES = {
  4. 'dict_keys':
  5. lambda def_def:
  6. def_def.redefine(
  7. cfg_ir.DirectFunctionCall(
  8. cfg_ir.REVERSE_LIST_MACRO_NAME,
  9. [('seq',
  10. def_def.insert_before(
  11. cfg_ir.DirectFunctionCall(
  12. cfg_ir.READ_DICT_KEYS_MACRO_NAME,
  13. cfg_ir.get_def_value(def_def).argument_list,
  14. cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION,
  15. has_side_effects=False)))],
  16. cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION,
  17. has_side_effects=False))
  18. }
  19. def __redefine_as_list_len(use_def, def_def):
  20. use_def.redefine(
  21. cfg_ir.CreateNode(
  22. use_def.insert_before(
  23. cfg_ir.DirectFunctionCall(
  24. 'len', [('seq', def_def)],
  25. cfg_ir.SIMPLE_POSITIONAL_CALLING_CONVENTION,
  26. has_side_effects=False))))
  27. SET_USE_REWRITE_RULES = {
  28. ('set_pop', 'a'):
  29. lambda use_def, def_def:
  30. use_def.redefine(
  31. cfg_ir.DirectFunctionCall(
  32. 'pop', [('self', def_def)],
  33. cfg_ir.SELF_POSITIONAL_CALLING_CONVENTION)),
  34. ('list_len', 'a'): __redefine_as_list_len,
  35. ('read_nr_out', 'a'): __redefine_as_list_len
  36. }
  37. def get_call_def_rewriter(definition, def_rewrite_rules):
  38. """Gets an appropriate rewrite rule from the given dictionary of call rewrite rules."""
  39. if cfg_ir.is_call(definition, calling_convention=cfg_ir.JIT_CALLING_CONVENTION):
  40. call = cfg_ir.get_def_value(definition)
  41. if call.target_name in def_rewrite_rules:
  42. return def_rewrite_rules[call.target_name]
  43. return None
  44. def get_call_use_rewriter(use_definition, def_definition, use_rewrite_rules):
  45. """Gets an appropriate rewrite rule from the given dictionary of call rewrite rules."""
  46. if cfg_ir.is_call(use_definition, calling_convention=cfg_ir.JIT_CALLING_CONVENTION):
  47. call = cfg_ir.get_def_value(use_definition)
  48. for arg_name, arg_def in call.argument_list:
  49. if arg_def == def_definition:
  50. key = (call.target_name, arg_name)
  51. if key in use_rewrite_rules:
  52. return use_rewrite_rules[key]
  53. return None
  54. def apply_rewrite_rules(
  55. entry_point,
  56. get_def_rewriter,
  57. get_use_rewriter):
  58. """Applies the given definition and use rewrite rules to all definitions and uses where a
  59. rewrite rule can be found for both the definitions and the uses."""
  60. # pylint: disable=I0011,W0108
  61. cfg_ir.match_and_rewrite(
  62. entry_point,
  63. lambda def_def:
  64. get_def_rewriter(def_def) is not None,
  65. lambda use_def, def_def:
  66. get_use_rewriter(use_def, def_def) is not None,
  67. lambda def_def:
  68. get_def_rewriter(def_def)(def_def),
  69. lambda use_def, def_def:
  70. get_use_rewriter(use_def, def_def)(use_def, def_def))
  71. def optimize_data_structures(entry_point):
  72. """Optimizes data structures in the graph defined by the given entry point."""
  73. apply_rewrite_rules(
  74. entry_point,
  75. lambda def_def:
  76. get_call_def_rewriter(def_def, SET_DEF_REWRITE_RULES),
  77. lambda use_def, def_def:
  78. get_call_use_rewriter(use_def, def_def, SET_USE_REWRITE_RULES))