Skip to content

Commit

Permalink
Merge branch 'main' into llu/fix_index_1dtma_warpspecialization
Browse files Browse the repository at this point in the history
  • Loading branch information
liqiangxl authored Feb 13, 2025
2 parents 37ddfd1 + 94ec5b3 commit d99a60f
Show file tree
Hide file tree
Showing 16 changed files with 705 additions and 153 deletions.
59 changes: 59 additions & 0 deletions csrc/bfs.h
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,65 @@ class BFSWithPermissiveDependence
}
};

// Unlike the default BFS behavior, Val is considered ready to
// visit only if all of definitions or uses are visited. The default
// BFS only requires one definition is visited.
template <
typename ExprT,
typename ValT,
typename DefinitionT,
typename UsesT,
typename InputsT,
typename OutputsT>
class BFSWithStrictDependence
: public BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT> {
public:
using NodeType =
typename BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>::
NodeType;

BFSWithStrictDependence(
DefinitionT definition,
UsesT uses,
InputsT inputs,
OutputsT outputs,
std::vector<NodeType> from,
std::vector<NodeType> to,
bool require_all_to_visited = true,
Direction allowed_direction = Direction::Undefined)
: BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>(
definition,
uses,
inputs,
outputs,
std::move(from),
std::move(to),
require_all_to_visited,
allowed_direction) {}

std::optional<std::pair<Direction, std::vector<NodeType>>> isReady(
const ValT& v) const override {
decltype(auto) uses = this->uses_(v);
if (!uses.empty() &&
std::all_of(uses.begin(), uses.end(), [&](const ExprT& use_e) -> bool {
return this->isDependencySatisfied(use_e);
})) {
return std::make_pair(
Direction::Backward, std::vector<NodeType>{uses.begin(), uses.end()});
}
decltype(auto) def = this->definition_(v);
if (!def.empty() &&
std::all_of(def.begin(), def.end(), [&](const ExprT& def_e) -> bool {
return this->isDependencySatisfied(def_e);
})) {
return std::make_pair(
Direction::Forward, std::vector<NodeType>{def.begin(), def.end()});
}

return std::nullopt;
}
};

// Find the shortest path from the from vals to the to
// vals. Dependency between vals and exprs must be satisfied.
// It is an error if no valid path is found unless
Expand Down
5 changes: 5 additions & 0 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ ValGraph& IdModel::buildExactGraph() {
for (auto expr : tv_exprs_) {
TensorView* c_tv = ir_utils::getTvOutput(expr);

NVF_ERROR(
c_tv != nullptr,
"Expected to have a TensorView output: ",
expr->toString());

auto all_tv_outputs = ir_utils::filterByType<TensorView>(expr->outputs());

// Map siblings, as all other tv output domains must match the first tv
Expand Down
4 changes: 4 additions & 0 deletions csrc/id_model/id_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ class IdModel : public PolymorphicBase {
return tvs_;
}

const std::vector<Expr*>& tvExprs() const {
return tv_exprs_;
}

Fusion* fusion() const {
return fusion_;
}
Expand Down
5 changes: 1 addition & 4 deletions csrc/id_model/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,10 +748,7 @@ TensorIndexer::TensorIndexer(IdModel& id_model) : id_model_(id_model) {
buildLoopIndexMap();

if (isDebugDumpEnabled(DebugDumpOption::IndexingVerbose)) {
std::ofstream ofs("indexing_traversal_graph.dot", std::ofstream::trunc);
auto dot_string = traversalGraph().toGraphvizDotGraph();
ofs << dot_string;
ofs.close();
traversalGraph().dumpGraphvizDotGraph("indexing_traversal_graph.dot");
}
}

Expand Down
143 changes: 97 additions & 46 deletions csrc/id_model/loop_promotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <id_model/to_string.h>
#include <ir/utils.h>
#include <iter_visitor.h>
#include <logical_domain_map.h>
#include <options.h>
#include <val_graph_visitor.h>

namespace nvfuser {
Expand Down Expand Up @@ -106,7 +108,7 @@ namespace {
// finding the promotion ID is a trivial probelm. Only the
// loop groups of the loop domains need to be checked as loop
// promotion does not matter for the other domains.
bool isLoopGraphAlmostUniform(const IdModel& id_model) {
bool isLoopGraphUniform(const IdModel& id_model) {
for (const auto tv : id_model.tvs()) {
if (tv->isFusionInput()) {
continue;
Expand All @@ -116,47 +118,103 @@ bool isLoopGraphAlmostUniform(const IdModel& id_model) {
id_model.idGraph(IdMappingMode::LOOP).toGroup(loop_id);
const auto all_exact_groups =
id_model.idGraph(IdMappingMode::EXACT).toGroups(*loop_group);
if (all_exact_groups.size() == 1) {
continue;
if (all_exact_groups.size() > 1) {
return false;
}
}
}

// Even when multiple exact groups are found, if there's only
// one concrete group and all the others are broadcast, it's
// obvious that the concrete group represents the promotion.
bool concrete_group_found = false;
for (const auto& exact_group : all_exact_groups) {
if (!exact_group->front()->as<IterDomain>()->isBroadcast()) {
if (concrete_group_found) {
// multiple concrete groups
return false;
}
concrete_group_found = true;
return true;
}

} // namespace

ValGroups LoopPromotionMapBuilder::getInputGroupsOfExactGraph(
const ValGraph& exact_graph) const {
std::unordered_set<IterDomain*> non_input_ids;

for (auto tv_expr : id_model_.tvExprs()) {
for (const auto producer :
ir_utils::filterByType<TensorView>(tv_expr->inputs())) {
for (const auto consumer :
ir_utils::filterByType<TensorView>(tv_expr->outputs())) {
auto p2c = PairwiseLogicalDomainMap(producer, consumer)
.mapBroadcast(false)
.mapProducerToConsumer();
for (const auto& [p_id, c_id] : p2c) {
non_input_ids.insert(c_id);
}
}
}
}

return true;
ValGroups input_groups;
for (const auto tv : id_model_.tvs()) {
for (const auto maybe_root_id : tv->getMaybeRootDomain()) {
if (!non_input_ids.count(maybe_root_id)) {
input_groups.pushBack(exact_graph.toGroup(maybe_root_id));
}
}
}

// Remove redundancy. There may be dependencies between inputs. For
// example:
//
// Fusion inputs:
// T0: [i0, i1]
// T1: [i2]
//
// T2 = reshape(T0, {i0, i1}, {i0*i1});
// T3 = add(T2, T1)
//
// In this case, i2 forms an input group but is redundant as there
// are i0 and i1. In fact, traversing from {i0, i1, i2} would miss
// the expr between {i0, i1} and {i2}.

ValGroups input_groups_to_keep;
for (auto it = input_groups.begin(); it != input_groups.end(); ++it) {
const ValGroup& input = *it;

ValGroups other_inputs = input_groups_to_keep;
other_inputs.pushBack(it + 1, input_groups.end());
if (ValGraphBFS::getExprGroupsBetween(
exact_graph,
other_inputs,
{input},
/*require_all_to_visited=*/false,
Direction::Forward)
.second) {
// This input group is redundant with respect
continue;
} else {
input_groups_to_keep.pushBack(input);
}
}

return input_groups_to_keep;
}

} // namespace
ValGroups LoopPromotionMapBuilder::getInputGroupsOfIELGraph(
const ValGraph& iel_graph) const {
const auto exact_input_groups =
getInputGroupsOfExactGraph(idGraph(IdMappingMode::EXACT));

ValGroups iel_input_groups;
for (const ValGroup& exact_input_group : exact_input_groups) {
iel_input_groups.pushBack(iel_graph.toGroups(*exact_input_group));
}

return iel_input_groups;
}

std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::build() {
// Some quick shortcut conditions to skip the full loop promotion
// analysis. These are not comprehensive. Should add more conditions
// if necessary.
if (!force_full_loop_promotion_analysis_ &&
isLoopGraphAlmostUniform(id_model_)) {
if (!force_full_loop_promotion_analysis_ && isLoopGraphUniform(id_model_)) {
return buildWithNoBroadcast();
}

// Cyclic exact graph is not supported. Specifically,
// computeCoveredGroups would fail as it uses ValGraphStmtSort.
NVF_ERROR(
!isCyclic(idGraph(IdMappingMode::EXACT)),
"Cyclic exact graph is not supported: ",
idGraph(IdMappingMode::EXACT).toString());

// Make an intersection of the exact and loop map. This will group together
// entries in each loop group that are exact with each other. This provides a
// better graph to do promotion and replays.
Expand Down Expand Up @@ -611,7 +669,8 @@ void LoopPromotionMapBuilder::propagatePromotionsInIELGraph(
const std::unordered_map<ValGroup, IterDomain*>& loop_graph_promotion_map) {
// In order to make this traversal work, the traversal order must be
// topologically sorted.
ValGraphStmtSort iel_stmt_sort(iel_graph);
ValGraphStmtSort iel_stmt_sort(
iel_graph, getInputGroupsOfIELGraph(iel_graph));

for (const ExprGroup& iel_expr : iel_stmt_sort.exprs()) {
NVF_ERROR(!iel_expr->empty());
Expand Down Expand Up @@ -713,21 +772,20 @@ void LoopPromotionMapBuilder::propagatePromotionsInIELGraph(
iel_graph, iel_promotion_map, idGraph(IdMappingMode::LOOP), {});
}

namespace {

// Returns for each ValGroup in provided IdGraph what the input ValGroups are
// traversing on definitions. Ignoring broadcast ValGroups and resetting inputs
// at RFactor ValGroups.
std::unordered_map<ValGroup, ValGroups> computeCoveredGroups(
const ValGraph& graph) {
std::unordered_map<ValGroup, ValGroups> LoopPromotionMapBuilder::
computeCoveredGroups(const ValGraph& graph) const {
// Map from an exact iter domain group, to all the exact iter domain groups it
// covers
std::unordered_map<ValGroup, ValGroups> covered_ids;

ValGroups input_groups = getInputGroupsOfExactGraph(graph);

for (const ValGroup& id_group : graph.disjointValSets().disjointSets()) {
// Initialize inputs
const ExprGroups& id_group_defs = graph.getDefinitions(id_group);
if (id_group_defs.empty()) {
if (input_groups.has(id_group)) {
covered_ids[id_group] = {id_group};
}

Expand All @@ -740,9 +798,9 @@ std::unordered_map<ValGroup, ValGroups> computeCoveredGroups(
}
}

ValGraphStmtSort exact_stmt_sort(graph);
ValGraphStmtSort stmt_sort(graph, input_groups);

for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) {
for (const ExprGroup& exact_expr : stmt_sort.exprs()) {
std::vector<ValGroup> input_groups = graph.inputGroups(exact_expr);

ValGroups covered;
Expand All @@ -763,8 +821,6 @@ std::unordered_map<ValGroup, ValGroups> computeCoveredGroups(
return covered_ids;
}

}; // namespace

std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::
projectIELPromotionToLoopGraph(
const ValGraph& iel_graph,
Expand Down Expand Up @@ -857,7 +913,10 @@ IterDomain* LoopPromotionMapBuilder::findPromotionOfLoopGroup(
ValGroups loop_group_covered_ids;
for (const ValGroup& exact_group : exact_groups) {
auto covered_it = exact_covered_ids.find(exact_group);
NVF_ERROR(covered_it != exact_covered_ids.end());
NVF_ERROR(
covered_it != exact_covered_ids.end(),
"No covered group info for ",
nvfuser::toString(exact_group));
loop_group_covered_ids.pushBack(covered_it->second);
}

Expand Down Expand Up @@ -988,21 +1047,13 @@ std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::
(int64_t)StmtSort::getExprsTo({loop_id->extent()}).size();
auto this_is_const = loop_id->extent()->isConstInt();

// A group is allowed to have one single exact group of concrete
// IDs with a broadcast group.
if (promotion == nullptr ||
(promotion->isBroadcast() && !loop_id->isBroadcast())) {
if (promotion == nullptr) {
is_const = this_is_const;
promotion = loop_id;
num_exprs = this_num_exprs;
continue;
}

// Ignore broadcast if a concrete ID is already found
if (!promotion->isBroadcast() && loop_id->isBroadcast()) {
continue;
}

// If new ID is non-const while the current promotion is const,
// or if both IDs are const or non-const and the number of
// expressions is not smaller, keep the current promotion
Expand Down
39 changes: 39 additions & 0 deletions csrc/id_model/loop_promotion.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,45 @@ class LoopPromotionMapBuilder {
LoopPromotionMapBuilderCallback* callback = nullptr,
bool force_full_loop_promotion_analysis = false);

std::unordered_map<ValGroup, ValGroups> computeCoveredGroups(
const ValGraph& graph) const;

// Given an Exact graph, get val groups that should be used as
// starting groups when propagating promotion info. For non-cyclic
// graphs, this should be equivalent to what ValGraph::getTerminatingInputs()
// returns. For cyclic graphs, there may be no terminating inputs
// due to a cyclic dependency, so getTerminatingInputs() may return
// just nothing.
//
// Instead, we first find input iter domains, which are (maybe) root
// iter domains that have no corresponding producer iter domains as
// defined by PairwiseLogicalDomainMap. Any exact groups that
// include any of the input iter domains are considered input
// groups.
//
// For example, given a graph like shown below:
//
// i0 -> i1 -> i2 -> i3
// ^ |
// +----------+
//
// Here, i0 represents a Val group that contains IDs of fusion input
// tensors.
//
// ValGraph::getTerminatingInputs would return nothing as there's no
// terminating input. However, when this is used in
// computeCoveredGroups, what we need to do is to propagate the
// informatiom of the IDs of the fusion inputs, i.e., i0, so the
// propagation should start from i0, then i1, i2 and i3, ignoring
// the back edge to i0.
ValGroups getInputGroupsOfExactGraph(const ValGraph& exact_graph) const;

// Similar to getInputGroupsOfExactGraph but for an IEL graph.
// We first get the inputs of the Exact graph. For the
// IEL propagation, any IEL group that has an ID that is included
// in any of the input groups of the exact graph is used as an input.
ValGroups getInputGroupsOfIELGraph(const ValGraph& iel_graph) const;

std::unordered_map<ValGroup, IterDomain*> build();

// Shortcut to build a map of promotion IDs without doing the full
Expand Down
Loading

0 comments on commit d99a60f

Please sign in to comment.