Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Fix CFG aliasing error with matrix of matrix #8445

Merged
merged 1 commit into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 56 additions & 11 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,9 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) {
if (after_lower_access &&
!((data_source_ptr->is<MatrixPtrStmt>() &&
data_source_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(data_source_ptr->is<MatrixPtrStmt>() &&
data_source_ptr->as<MatrixPtrStmt>()
->origin->is<MatrixPtrStmt>()) ||
data_source_ptr->is<AllocaStmt>())) {
// After lower_access, we only analyze local variables.
continue;
Expand Down Expand Up @@ -555,6 +558,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
if (!contain_variable(live_kill, load_ptr)) {
Expand All @@ -581,6 +586,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
if (!after_lower_access ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
live_kill.insert(store_ptr);
Expand All @@ -589,32 +596,41 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
}
}

static void update_aliased_stmts(
static void recursive_update_aliased_elements(
const std::unordered_map<Stmt *, std::vector<Stmt *>>
&tensor_to_matrix_ptrs_map,
const std::unordered_map<Stmt *, Stmt *> &matrix_ptr_to_tensor_map,
std::unordered_map<Stmt *, CFGNode::UseDefineStatus> &container,
Stmt *key,
bool to_erase) {
if (tensor_to_matrix_ptrs_map.find(key) != tensor_to_matrix_ptrs_map.end()) {
auto scalars_address = tensor_to_matrix_ptrs_map.at(key);
const auto &elements_address = tensor_to_matrix_ptrs_map.at(key);
// Update aliased MatrixPtrStmt for TensorType<>*
for (auto scalar_address : scalars_address) {
for (const auto &element_address : elements_address) {
if (to_erase) {
if (container.find(scalar_address) != container.end()) {
TI_ASSERT(container[scalar_address] ==
CFGNode::UseDefineStatus::NONE);
container.erase(scalar_address);
if (container.find(element_address) != container.end()) {
container.erase(element_address);
}
} else {
container[scalar_address] = CFGNode::UseDefineStatus::NONE;
container[element_address] = CFGNode::UseDefineStatus::NONE;
if (element_address->ret_type.ptr_removed()->is<TensorType>()) {
container[element_address] = CFGNode::UseDefineStatus::FULL;
}
}

// Recursively update aliased addresses
recursive_update_aliased_elements(tensor_to_matrix_ptrs_map, container,
element_address, to_erase);
}
}
}

// Update aliased TensorType<>* for MatrixPtrStmt
static void recursive_update_aliased_parent(
const std::unordered_map<Stmt *, Stmt *> &matrix_ptr_to_tensor_map,
std::unordered_map<Stmt *, CFGNode::UseDefineStatus> &container,
Stmt *key,
bool to_erase) {
if (matrix_ptr_to_tensor_map.find(key) != matrix_ptr_to_tensor_map.end()) {
auto tensor_address = matrix_ptr_to_tensor_map.at(key);
const auto &tensor_address = matrix_ptr_to_tensor_map.at(key);
// no matter to_erase or not, the tensor_address is only partially defined
// or used
if (to_erase) {
Expand All @@ -624,9 +640,29 @@ static void update_aliased_stmts(
} else {
container[tensor_address] = CFGNode::UseDefineStatus::PARTIAL;
}

// Recursively update aliased addresses
recursive_update_aliased_parent(matrix_ptr_to_tensor_map, container,
tensor_address, to_erase);
}
}

static void update_aliased_stmts(
const std::unordered_map<Stmt *, std::vector<Stmt *>>
&tensor_to_matrix_ptrs_map,
const std::unordered_map<Stmt *, Stmt *> &matrix_ptr_to_tensor_map,
std::unordered_map<Stmt *, CFGNode::UseDefineStatus> &container,
Stmt *key,
bool to_erase) {
// Update aliased MatrixPtrStmt for TensorType<>*
recursive_update_aliased_elements(tensor_to_matrix_ptrs_map, container, key,
to_erase);

// Update aliased TensorType<>* for MatrixPtrStmt
recursive_update_aliased_parent(matrix_ptr_to_tensor_map, container, key,
to_erase);
}

// Insert or erase "key" to "container".
// In case where "key" being MatrixPtrStmt, we also update the aliased original
// address. In case where "key" is involved with TensorType, we also update the
Expand All @@ -648,6 +684,7 @@ static void update_container_with_alias(
} else {
container[key] = CFGNode::UseDefineStatus::NONE;
}
// Recursively update aliased addresses
update_aliased_stmts(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map,
container, key, to_erase);
}
Expand Down Expand Up @@ -715,6 +752,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
if (!after_lower_access ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
// !may_contain_variable(live_in_this_node, store_ptr): address is not
// loaded after this store
Expand Down Expand Up @@ -816,6 +855,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// live_load_in_this_node[addr]: tracks the
// next load to the same address
Expand Down Expand Up @@ -844,6 +885,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// Addr is used in this node, so it's live in this node
update_container_with_alias(tensor_to_matrix_ptrs_map,
Expand Down Expand Up @@ -976,6 +1019,8 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
auto stmt = nodes[i]->block->statements[j].get();
if ((stmt->is<MatrixPtrStmt>() &&
stmt->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(stmt->is<MatrixPtrStmt>() &&
stmt->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(!after_lower_access &&
(stmt->is<GlobalPtrStmt>() || stmt->is<ExternalPtrStmt>() ||
stmt->is<BlockLocalPtrStmt>() || stmt->is<ThreadLocalPtrStmt>() ||
Expand Down
9 changes: 4 additions & 5 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,12 +843,11 @@ bool IndexExpression::is_local() const {
}

bool IndexExpression::is_global() const {
// Special case: Indexing into TensorType-element of ExternalPtrStmt
// or GlobalPtrStmt should be treated as global ptrs
if (var.is<IndexExpression>()) {
TI_ASSERT(var.cast<IndexExpression>()->is_matrix_field() ||
var.cast<IndexExpression>()->is_ndarray());
return true;
// Special case: Pointer chasing. For example, if we are indexing into
// tensor elements of fields / ndarrays, this index expr should be treated
// as global.
return var.cast<IndexExpression>()->is_global();
}

// Only Ndarray and Field comes outside from a kernel
Expand Down
16 changes: 16 additions & 0 deletions tests/python/test_shared_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,19 @@ def test():

test()
assert (y.to_numpy()[0] == [4.0, 8.0, 12.0, 16.0]).all()


@test_utils.test(arch=[ti.cuda], debug=True)
def test_shared_array_matrix():
@ti.kernel
def foo():
for x in range(10):
shared = ti.simt.block.SharedArray((10,), dtype=ti.math.vec3)
shared[x] = ti.Vector([x + 1, x + 2, x + 3])
assert shared[x].z == x + 3
assert (shared[x] == ti.Vector([x + 1, x + 2, x + 3])).all()

print(shared[x].z)
print(shared[x])

foo()
Loading