cfg_ssa_construction.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. """Converts 'declare-local', 'load' and 'store' instructions into SSA form."""
  2. from collections import defaultdict
  3. import modelverse_jit.cfg_ir as cfg_ir
  4. def get_local_id(def_or_value):
  5. """Gets the node of the local resolved or declared by the given definition or value.
  6. If the given definition or value does not refer to a 'resolve-local' or
  7. 'declare-local' node, then None is returned."""
  8. value = cfg_ir.get_def_value(def_or_value)
  9. if isinstance(value, (cfg_ir.ResolveLocal, cfg_ir.DeclareLocal)):
  10. return value.variable.node_id
  11. else:
  12. return None
  13. def get_ineligible_local_ids(entry_point):
  14. """Finds the ids of all local variables which are not eligible for conversion to SSA form."""
  15. # Local variables are eligible for conversion to SSA form if their pointer node is never
  16. # leaked to the outside world. So we know that we can safely convert a local to SSA form
  17. # if 'resolve-local' values are only used by 'load' and 'store' values.
  18. ineligible_local_ids = set()
  19. def __maybe_mark_ineligible(def_or_value):
  20. local_id = get_local_id(def_or_value)
  21. if local_id is not None:
  22. ineligible_local_ids.add(value.variable.node_id)
  23. for block in cfg_ir.get_all_blocks(entry_point):
  24. for definition in block.definitions + [block.flow]:
  25. value = cfg_ir.get_def_value(definition)
  26. if isinstance(value, cfg_ir.LoadPointer):
  27. # Loading a pointer to a local is fine.
  28. pass
  29. elif isinstance(value, cfg_ir.StoreAtPointer):
  30. # Storing a value in a local is fine, too.
  31. # But be careful not to ignore a store where the stored value is a local pointer.
  32. __maybe_mark_ineligible(value.value)
  33. else:
  34. # Walk over the dependencies, and mark them all as ineligible for
  35. # local-to-SSA conversion.
  36. for dependency in value.get_all_dependencies():
  37. __maybe_mark_ineligible(dependency)
  38. return ineligible_local_ids
  39. def construct_ssa_form(entry_point):
  40. """Converts local variables into SSA form in the graph defined by the given entry point."""
  41. # Build some helper data structures.
  42. all_blocks = cfg_ir.get_all_blocks(entry_point)
  43. ineligible_locals = get_ineligible_local_ids(entry_point)
  44. predecessor_map = cfg_ir.get_all_predecessor_blocks(entry_point)
  45. # Create the SSA construction state.
  46. state = SSAConstructionState(all_blocks, ineligible_locals, predecessor_map)
  47. # Fill all blocks in the graph.
  48. for block in all_blocks:
  49. state.fill_block(block)
  50. # Update branches.
  51. for block in all_blocks:
  52. state.update_block_branches(block)
  53. # The algorithms below are based on
  54. # Simple and Efficient Construction of Static Single Assignment Form by M. Braun et al
  55. # (https://pp.info.uni-karlsruhe.de/uploads/publikationen/braun13cc.pdf).
  56. class SSAConstructionState(object):
  57. """Encapsulates state related to SSA construction."""
  58. def __init__(self, all_blocks, ineligible_locals, predecessor_map):
  59. self.all_blocks = all_blocks
  60. self.ineligible_locals = ineligible_locals
  61. self.predecessor_map = predecessor_map
  62. # `current_defs` is a local node id -> basic block -> definition map.
  63. self.current_defs = defaultdict(dict)
  64. # `incomplete_phis` is a basic block -> local node id -> block parameter def map.
  65. self.incomplete_phis = defaultdict(dict)
  66. # `extra_phi_operands` is a basic block -> block parameter def -> def map.
  67. self.extra_phi_operands = defaultdict(dict)
  68. self.processed_blocks = set()
  69. self.filled_blocks = set()
  70. self.sealed_blocks = set()
  71. def read_variable(self, block, node_id):
  72. """Reads the latest definition of the local variable with the
  73. given node id for the specified block."""
  74. if block in self.current_defs[node_id]:
  75. return self.current_defs[node_id][block]
  76. else:
  77. return self.read_variable_recursive(block, node_id)
  78. def write_variable(self, block, node_id, value):
  79. """Writes the given value to the local with the specified id in the
  80. specified block."""
  81. self.current_defs[node_id][block] = value
  82. def read_variable_recursive(self, block, node_id):
  83. """Reads the latest definition of the local variable with the
  84. given node id from one of the given block's predecessor blocks."""
  85. if block not in self.sealed_blocks:
  86. # Create an incomplete phi.
  87. val = block.append_parameter(cfg_ir.BlockParameter())
  88. self.incomplete_phis[block][node_id] = val
  89. elif len(self.predecessor_map[block]) == 1:
  90. # Optimize the common case of one predecessor: no phi needed.
  91. pred = next(iter(self.predecessor_map[block]))
  92. val = self.read_variable(pred, node_id)
  93. else:
  94. # Break potential cycles with an operandless phi.
  95. val = block.append_parameter(cfg_ir.BlockParameter())
  96. self.write_variable(block, node_id, val)
  97. val = self.add_phi_operands(node_id, val)
  98. self.write_variable(block, node_id, val)
  99. return val
  100. def add_phi_operands(self, node_id, phi_def):
  101. """Finds out which arguments branches should provide for the given block
  102. parameter definition."""
  103. # Determine operands from predecessors
  104. all_values = []
  105. for pred in self.predecessor_map[phi_def.block]:
  106. arg = self.read_variable(pred, node_id)
  107. self.extra_phi_operands[pred][phi_def] = arg
  108. all_values.append(arg)
  109. return self.try_remove_trivial_phi(phi_def, all_values)
  110. def try_remove_trivial_phi(self, phi_def, values):
  111. """Tries to remove a trivial block parameter definition."""
  112. # This is a somewhat simplified (and less powerful) version of the
  113. # algorithm in the SSA construction paper. That's kind of okay, though;
  114. # trivial phi elimination is also implemented as a separate pass in the
  115. # optimization pipeline.
  116. trivial_phi_val = cfg_ir.get_trivial_phi_value(phi_def, values)
  117. if trivial_phi_val is None:
  118. return phi_def
  119. else:
  120. phi_def.block.remove_parameter(phi_def)
  121. phi_def.redefine(trivial_phi_val)
  122. phi_def.block.prepend_definition(phi_def)
  123. return trivial_phi_val
  124. def has_sealed(self, block):
  125. """Tells if the given block has been sealed yet."""
  126. return block in self.sealed_blocks
  127. def can_seal(self, block):
  128. """Tells if the given block can be sealed right away."""
  129. # A block can be sealed if all if its predecessors have been filled.
  130. return all(
  131. [predecessor in self.filled_blocks for predecessor in self.predecessor_map[block]])
  132. def seal_all_sealable_blocks(self):
  133. """Seals all sealable blocks."""
  134. for block in self.all_blocks:
  135. if self.can_seal(block):
  136. self.seal_block(block)
  137. def seal_block(self, block):
  138. """Seals the given block."""
  139. if self.has_sealed(block):
  140. return
  141. for node_id, phi_def in self.incomplete_phis[block].items():
  142. self.add_phi_operands(node_id, phi_def)
  143. self.sealed_blocks.add(block)
  144. def has_filled(self, block):
  145. """Tells if the given block has been filled yet."""
  146. return block in self.filled_blocks
  147. def fill_block(self, block):
  148. """Visits all definitions in the given block. Locals are converted into SSA form."""
  149. if block in self.processed_blocks:
  150. return
  151. self.processed_blocks.add(block)
  152. # Try to seal the block right away if at all possible.
  153. if self.can_seal(block):
  154. self.seal_block(block)
  155. block_definitions = list(block.definitions)
  156. for definition in block_definitions:
  157. value = definition.value
  158. if cfg_ir.is_value_def(value, cfg_ir.LoadPointer):
  159. # Read the variable from the definitions dictionary.
  160. node_id = get_local_id(value.pointer)
  161. if node_id is not None and node_id not in self.ineligible_locals:
  162. definition.redefine(self.read_variable(block, node_id))
  163. elif isinstance(value, cfg_ir.StoreAtPointer):
  164. node_id = get_local_id(value.pointer)
  165. if node_id is not None and node_id not in self.ineligible_locals:
  166. # Write to the variable, and replace the definition by a 'None' literal.
  167. self.write_variable(block, node_id, value.value)
  168. definition.redefine(cfg_ir.Literal(None))
  169. elif isinstance(value, cfg_ir.DeclareLocal):
  170. node_id = value.variable.node_id
  171. if node_id not in self.ineligible_locals:
  172. definition.redefine(cfg_ir.Literal(None))
  173. # Mark the block as filled.
  174. self.filled_blocks.add(block)
  175. # Seal all sealable blocks.
  176. self.seal_all_sealable_blocks()
  177. # Fill successor blocks.
  178. for branch in block.flow.branches():
  179. self.fill_block(branch.block)
  180. def update_block_branches(self, block):
  181. """Appends arguments to the given block's flow instruction's branches, if necessary."""
  182. for branch in block.flow.branches():
  183. # Find all pairs phis which are defined in the branch target block.
  184. applicable_pairs = [
  185. (phi_def, operand_def)
  186. for phi_def, operand_def in self.extra_phi_operands[block]
  187. if phi_def.block == branch.block]
  188. if len(applicable_pairs) == 0:
  189. # We might as well early-out here.
  190. continue
  191. # Sort the pairs by block parameter index.
  192. sorted_pairs = sorted(
  193. applicable_pairs,
  194. key=lambda (phi_def, _): phi_def.block.parameters.index(phi_def))
  195. # Append arguments to the branch.
  196. for _, arg in sorted_pairs:
  197. branch.arguments.append(arg)