[TKW] Bug: multiple issues with input/output determination #375

Open
GMNGeoffrey opened this issue Jan 10, 2025 · 0 comments

The current logic for inputs and outputs has a number of issues:

def determine_input_output_buffers(self, graph: fx.Graph):
    # Extract all placeholder nodes.
    placeholder_nodes = filter_fx_graph(graph, is_placeholder)

    def only_read_dependencies(node):
        return all([isinstance(get_custom(x), Read) for x in node.users.keys()])

    def only_write_dependencies(node):
        if len(node.users) == 0:
            return False
        return all([isinstance(get_custom(x), Write) for x in node.users.keys()])

    for node in placeholder_nodes:
        index = None
        for i, binding in enumerate(self.bindings):
            if binding.reference[1] == node:
                index = i
                break
        if index == None:
            continue

        # TODO: Match KernelBufferUsage to what bufferType that is expected on IREE.
        usage = KernelBufferUsage.INPUT
        if only_read_dependencies(node):
            usage = KernelBufferUsage.INPUT

        if only_write_dependencies(node):
            usage = KernelBufferUsage.OUTPUT

        # Create new Memory type with the correct usage
        memory_type = self.bindings[index].kernel_buffer_type
        self.bindings[index].kernel_buffer_type = Memory[
            (
                *memory_type.symbolic_shape,
                memory_type.address_space,
                memory_type.dtype,
                usage,
            )
        ]
    return

(amongst other places)

  1. If a buffer has any user other than a write (for example, it is both read and written), it gets labeled an input, which marks it as read-only, so it comes out unmodified.
  2. Writes inside of reductions are not considered when classifying an output.
  3. Everything designated an output gets moved to the end of the signature, so when the compiled kernel is called the arguments are all scrambled.
  4. Arguments that are unused somehow get stripped out during compilation, so again the arguments are all scrambled when the compiled kernel is called.
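
To make problem 1 concrete, here is a minimal, self-contained sketch of the classification rule from the snippet above. It does not use the real TKW types: `Usage` is a hypothetical stand-in for `KernelBufferUsage`, and each placeholder's users are modeled as plain strings (`"read"`, `"write"`) rather than fx nodes. `classify_any_write` is only one possible alternative rule, not a proposed fix.

```python
from enum import Enum


class Usage(Enum):
    """Hypothetical stand-in for KernelBufferUsage."""
    INPUT = "input"
    OUTPUT = "output"


def classify_as_in_issue(user_kinds):
    """Mirror the quoted logic: OUTPUT only if there are users and all are writes."""
    usage = Usage.INPUT
    if len(user_kinds) > 0 and all(kind == "write" for kind in user_kinds):
        usage = Usage.OUTPUT
    return usage


def classify_any_write(user_kinds):
    """Alternative rule (an assumption): a single write is enough for OUTPUT."""
    return Usage.OUTPUT if "write" in user_kinds else Usage.INPUT


# A buffer that is read and then written back, e.g. an in-place update.
read_modify_write = ["read", "write"]

current = classify_as_in_issue(read_modify_write)   # labeled as an input
alternative = classify_any_write(read_modify_write)  # labeled as an output
```

Under the current rule the read-and-written buffer ends up as an input, which matches the behavior described in item 1.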

So far, the easiest workaround I've found is to just mark everything as an output. I'm not sure what other implications that has besides passing up optimizations that depend on things being read-only. I think this also still doesn't fix problem 4.
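
One plausible reading of why this workaround sidesteps problem 3 can be sketched with a toy model. It assumes that "moved to the end of the signature" behaves like a stable partition of the declared arguments, which the issue does not confirm; `outputs_last` and the argument names are hypothetical and are not part of TKW.

```python
def outputs_last(args, is_output):
    """Stable partition: inputs first, then outputs, each in declared order."""
    inputs = [a for a in args if not is_output[a]]
    outputs = [a for a in args if is_output[a]]
    return inputs + outputs


declared = ["a", "c", "b"]  # order in which the kernel author wrote the arguments

# Mixed classification: "c" is an output, so it moves behind "b".
mixed = {"a": False, "c": True, "b": False}
reordered = outputs_last(declared, mixed)

# Workaround: everything is an output, so the partition changes nothing.
all_outputs = {name: True for name in declared}
unchanged = outputs_last(declared, all_outputs)
```

In this model a caller passing arguments in declared order only matches the compiled signature when the partition is a no-op, which is the case when every argument carries the same usage. It says nothing about unused arguments being dropped (problem 4).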
