Skip to content

Commit

Permalink
[Bugfix] Fix CFG aliasing error with matrix of matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 committed Dec 22, 2023
1 parent 59156b3 commit 7563ab8
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 16 deletions.
67 changes: 56 additions & 11 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,9 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) {
if (after_lower_access &&
!((data_source_ptr->is<MatrixPtrStmt>() &&
data_source_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(data_source_ptr->is<MatrixPtrStmt>() &&
data_source_ptr->as<MatrixPtrStmt>()
->origin->is<MatrixPtrStmt>()) ||
data_source_ptr->is<AllocaStmt>())) {
// After lower_access, we only analyze local variables.
continue;
Expand Down Expand Up @@ -555,6 +558,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
if (!contain_variable(live_kill, load_ptr)) {
Expand All @@ -581,6 +586,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
if (!after_lower_access ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
live_kill.insert(store_ptr);
Expand All @@ -589,32 +596,41 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
}
}

static void update_aliased_stmts(
static void recursive_update_aliased_elements(
const std::unordered_map<Stmt *, std::vector<Stmt *>>
&tensor_to_matrix_ptrs_map,
const std::unordered_map<Stmt *, Stmt *> &matrix_ptr_to_tensor_map,
std::unordered_map<Stmt *, CFGNode::UseDefineStatus> &container,
Stmt *key,
bool to_erase) {
if (tensor_to_matrix_ptrs_map.find(key) != tensor_to_matrix_ptrs_map.end()) {
auto scalars_address = tensor_to_matrix_ptrs_map.at(key);
const auto &elements_address = tensor_to_matrix_ptrs_map.at(key);
// Update aliased MatrixPtrStmt for TensorType<>*
for (auto scalar_address : scalars_address) {
for (const auto &element_address : elements_address) {
if (to_erase) {
if (container.find(scalar_address) != container.end()) {
TI_ASSERT(container[scalar_address] ==
CFGNode::UseDefineStatus::NONE);
container.erase(scalar_address);
if (container.find(element_address) != container.end()) {
container.erase(element_address);
}
} else {
container[scalar_address] = CFGNode::UseDefineStatus::NONE;
container[element_address] = CFGNode::UseDefineStatus::NONE;
if (element_address->ret_type.ptr_removed()->is<TensorType>()) {
container[element_address] = CFGNode::UseDefineStatus::FULL;
}
}

// Recursively update aliased addresses
recursive_update_aliased_elements(tensor_to_matrix_ptrs_map, container,
element_address, to_erase);
}
}
}

// Update aliased TensorType<>* for MatrixPtrStmt
static void recursive_update_aliased_parent(
const std::unordered_map<Stmt *, Stmt *> &matrix_ptr_to_tensor_map,
std::unordered_map<Stmt *, CFGNode::UseDefineStatus> &container,
Stmt *key,
bool to_erase) {
if (matrix_ptr_to_tensor_map.find(key) != matrix_ptr_to_tensor_map.end()) {
auto tensor_address = matrix_ptr_to_tensor_map.at(key);
const auto &tensor_address = matrix_ptr_to_tensor_map.at(key);
// Whether or not to_erase is set, tensor_address is only partially defined
// or used
if (to_erase) {
Expand All @@ -624,9 +640,29 @@ static void update_aliased_stmts(
} else {
container[tensor_address] = CFGNode::UseDefineStatus::PARTIAL;
}

// Recursively update aliased addresses
recursive_update_aliased_parent(matrix_ptr_to_tensor_map, container,
tensor_address, to_erase);
}
}

// Propagates an insertion or erasure of "key" to every entry of "container"
// that aliases it.
//
// Two walks are performed, both starting at "key":
//   1. towards the elements: follows tensor_to_matrix_ptrs_map from "key" to
//      the addresses recorded for it, recursively;
//   2. towards the parents: follows matrix_ptr_to_tensor_map from "key" to the
//      address it was derived from, recursively.
//
// "to_erase" is forwarded unchanged to both walks and selects whether aliased
// entries are removed from or written into "container". "key" itself is not
// touched here.
static void update_aliased_stmts(
    const std::unordered_map<Stmt *, std::vector<Stmt *>>
        &tensor_to_matrix_ptrs_map,
    const std::unordered_map<Stmt *, Stmt *> &matrix_ptr_to_tensor_map,
    std::unordered_map<Stmt *, CFGNode::UseDefineStatus> &container,
    Stmt *key,
    bool to_erase) {
  // Walk 1: aliased element addresses reachable from "key".
  recursive_update_aliased_elements(
      tensor_to_matrix_ptrs_map, container, key, to_erase);

  // Walk 2: aliased parent addresses that "key" points into.
  recursive_update_aliased_parent(
      matrix_ptr_to_tensor_map, container, key, to_erase);
}

// Insert or erase "key" to "container".
// In case where "key" being MatrixPtrStmt, we also update the aliased original
// address. In case where "key" is involved with TensorType, we also update the
Expand All @@ -648,6 +684,7 @@ static void update_container_with_alias(
} else {
container[key] = CFGNode::UseDefineStatus::NONE;
}
// Recursively update aliased addresses
update_aliased_stmts(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map,
container, key, to_erase);
}
Expand Down Expand Up @@ -715,6 +752,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
if (!after_lower_access ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
// !may_contain_variable(live_in_this_node, store_ptr): address is not
// loaded after this store
Expand Down Expand Up @@ -816,6 +855,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// live_load_in_this_node[addr]: tracks the
// next load to the same address
Expand Down Expand Up @@ -844,6 +885,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// Addr is used in this node, so it's live in this node
update_container_with_alias(tensor_to_matrix_ptrs_map,
Expand Down Expand Up @@ -976,6 +1019,8 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
auto stmt = nodes[i]->block->statements[j].get();
if ((stmt->is<MatrixPtrStmt>() &&
stmt->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(stmt->is<MatrixPtrStmt>() &&
stmt->as<MatrixPtrStmt>()->origin->is<MatrixPtrStmt>()) ||
(!after_lower_access &&
(stmt->is<GlobalPtrStmt>() || stmt->is<ExternalPtrStmt>() ||
stmt->is<BlockLocalPtrStmt>() || stmt->is<ThreadLocalPtrStmt>() ||
Expand Down
9 changes: 4 additions & 5 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,12 +843,11 @@ bool IndexExpression::is_local() const {
}

bool IndexExpression::is_global() const {
// Special case: Indexing into TensorType-element of ExternalPtrStmt
// or GlobalPtrStmt should be treated as global ptrs
if (var.is<IndexExpression>()) {
TI_ASSERT(var.cast<IndexExpression>()->is_matrix_field() ||
var.cast<IndexExpression>()->is_ndarray());
return true;
// Special case: Pointer chasing. For example, if we are indexing into
// tensor elements of fields / ndarrays, this index expr should be treated
// as global.
return var.cast<IndexExpression>()->is_global();
}

// Only ndarrays and fields come from outside a kernel
Expand Down
16 changes: 16 additions & 0 deletions tests/python/test_shared_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,19 @@ def test():

test()
assert (y.to_numpy()[0] == [4.0, 8.0, 12.0, 16.0]).all()


# Added together with the "[Bugfix] Fix CFG aliasing error with matrix of
# matrix" change; restricted to the CUDA arch and run with debug=True.
@test_utils.test(arch=[ti.cuda], debug=True)
def test_shared_array_matrix():
    """Store vec3 values into a SharedArray and read them back.

    Each stored vector is read back both as a whole and through a single
    component (``.z``), and both reads are checked against the written value.
    """

    @ti.kernel
    def foo():
        for x in range(10):
            # Shared array with 10 entries, each entry a ti.math.vec3.
            shared = ti.simt.block.SharedArray((10,), dtype=ti.math.vec3)
            shared[x] = ti.Vector([x + 1, x + 2, x + 3])
            # Component read: index the array, then the vector element.
            assert shared[x].z == x + 3
            # Whole-vector read of the same entry.
            assert (shared[x] == ti.Vector([x + 1, x + 2, x + 3])).all()

            # Repeat both kinds of read through print().
            print(shared[x].z)
            print(shared[x])

    foo()

0 comments on commit 7563ab8

Please sign in to comment.