[Flow] Fix dispatch naming for dynamic shaped fusions (#19439)
Currently, all ops with dynamic shapes are assigned the same estimated
cost when naming dispatches. This means that in cases like an
elementwise op fused with a matmul, both ops are assigned the same
priority, and because of traversal order the dispatch ends up taking
the name of the elementwise op.

This patch hacks around the issue by treating all dynamic shapes as
moderately sized static shapes. If more issues come up in the future,
we can look at adding tensor size range analysis that gives us upper
bounds for the dynamic shapes.
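
To make the effect concrete, below is a minimal standalone sketch, not the
actual IREE code: kDynamic stands in for ShapedType::kDynamic, and the
saturating-multiply guard from the patch is omitted because these values stay
far below INT64_MAX. It shows why the matmul-like op now outranks the fused
elementwise op when the dispatch is named.

// Minimal standalone sketch (not the actual IREE code): reimplements the
// cost heuristic with the 1024 placeholder to show why the matmul-like op
// now outranks the fused elementwise op when naming the dispatch.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic.
static constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

static int64_t estimatedCost(const std::vector<int64_t> &domain) {
  int64_t product = 1;
  for (int64_t size : domain) {
    // Treat a dynamic dim as a moderately sized static dim.
    product *= (size == kDynamic) ? 1024 : size;
  }
  return product;
}

int main() {
  // Matmul-like op from the new test below: iteration domain (16, ?, 8).
  std::cout << estimatedCost({16, kDynamic, 8}) << "\n"; // 131072
  // Fused elementwise op: iteration domain (16, ?).
  std::cout << estimatedCost({16, kDynamic}) << "\n";    // 16384
  // The larger estimate wins, so the dispatch is named after the matmul-like
  // op (dispatch_matmul_like_16xDx8_f32 in the test) rather than the
  // elementwise op.
  return 0;
}

With the old behavior, both domains received the same estimate, so traversal
order decided the name.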

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
Co-authored-by: Ian Wood <ianwood2024@u.northwestern.edu>
qedawkins and IanWood1 authored Jan 9, 2025
1 parent 9055c9d commit a7bac5d
Showing 2 changed files with 50 additions and 1 deletion.
@@ -32,13 +32,26 @@ static constexpr int64_t kMaxCost = INT64_MAX;
 
 namespace {
 
 // This op estimates the cost of a list of perfectly nested loop ranges simply
 // as the product of ranges. Note that this does not take into account the cost
 // of the body of the op whose domain this computes.
 static int64_t costOfDomain(ArrayRef<int64_t> domain) {
   int64_t product = 1;
   for (int64_t size : domain) {
+    int64_t multiplier = size;
+    if (ShapedType::isDynamic(size)) {
+      // HACK: Use a placeholder value for dynamic sizes. In practice, because
+      // we tend to require that iteration spaces of linalg ops line up for
+      // fusion to occur, more dynamic dims => a larger iteration domain.
+      // TODO: Query the upper bound of the dynamic size range instead.
+      multiplier = 1024;
+    }
+
+    // Perform saturating multiplication
+    if (product > kMaxCost / multiplier) {
+      return kMaxCost;
+    }
-    product *= size;
+    product *= multiplier;
   }
   return product;
 }
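
A quick note on the guard added above: for positive operands,
product * multiplier exceeds kMaxCost exactly when product > kMaxCost / multiplier,
so testing the division first clamps the cost before the signed multiplication
could overflow (which would be undefined behavior). A small sketch of the
idea, assuming multiplier >= 1 (a zero-sized dim would need its own check):

// Standalone sketch (not part of the patch) of the saturating multiply used
// in costOfDomain above. Assumes multiplier >= 1.
#include <cstdint>
#include <limits>

static constexpr int64_t kMaxCost = std::numeric_limits<int64_t>::max();

static int64_t saturatingMul(int64_t product, int64_t multiplier) {
  // product * multiplier > kMaxCost  <=>  product > kMaxCost / multiplier
  // for positive operands, so test via division to avoid signed overflow.
  if (product > kMaxCost / multiplier) {
    return kMaxCost; // Clamp to the "maximum cost" sentinel instead.
  }
  return product * multiplier;
}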
@@ -669,3 +669,39 @@ flow.executable private @ex {
     }
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map3 = affine_map<(d0, d1) -> (d0, d1)>
+
+flow.executable private @ex {
+  // CHECK: flow.executable.export public @dispatch_matmul_like_16xDx8_f32
+  flow.executable.export public @dispatch
+  builtin.module {
+    func.func @dispatch(%arg0: !flow.dispatch.tensor<readwrite:tensor<16x?xf32>>, %arg1: index) {
+      %0 = tensor.empty() : tensor<16x8xf32>
+      %1 = tensor.empty(%arg1) : tensor<8x?xf32>
+      %init = flow.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [16, %arg1], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<16x?xf32>>{%arg1} -> tensor<16x?xf32>
+      %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+          ins(%0, %1 : tensor<16x8xf32>, tensor<8x?xf32>) outs(%init : tensor<16x?xf32>) {
+      ^bb0(%in: f32, %in_0: f32, %out: f32):
+        %3 = arith.mulf %in, %in_0 : f32
+        %4 = arith.addf %out, %3 : f32
+        linalg.yield %4 : f32
+      } -> tensor<16x?xf32>
+      %3 = linalg.generic {
+          indexing_maps = [#map3, #map3],
+          iterator_types = ["parallel", "parallel"]
+        } ins(%2 : tensor<16x?xf32>) outs(%2 : tensor<16x?xf32>) {
+      ^bb0(%in: f32, %out: f32):
+        %4 = math.rsqrt %in : f32
+        linalg.yield %4 : f32
+      } -> tensor<16x?xf32>
+      flow.dispatch.tensor.store %3, %arg0, offsets = [0, 0], sizes = [16, %arg1], strides = [1, 1] : tensor<16x?xf32> -> !flow.dispatch.tensor<readwrite:tensor<16x?xf32>>{%arg1}
+      return
+    }
+  }
+}
