Add block support to all SPILUK algorithms #2064

Merged: 43 commits, Jan 11, 2024
Commits
0e51d17
Interface for block iluk
jgfouca Nov 20, 2023
682a6ad
Progress. Test hooked up
jgfouca Nov 20, 2023
f6b44e1
Progress on test refactoring
jgfouca Nov 20, 2023
d31e190
More test reorg
jgfouca Nov 21, 2023
edcab95
Fix test
jgfouca Nov 21, 2023
6572087
Refactor spiluk numeric a bit with a struct wrapper
jgfouca Nov 21, 2023
f0143a2
Add good logging
jgfouca Nov 22, 2023
8242a8c
progress
jgfouca Nov 24, 2023
2b8c77b
Fix block test
jgfouca Nov 24, 2023
eb65716
Progress but potential dead end
jgfouca Nov 27, 2023
fd2e8d7
Giving up on this approach for now
jgfouca Nov 29, 2023
d24cb74
progress
jgfouca Nov 30, 2023
b36601d
Make verbose
jgfouca Nov 30, 2023
d0e5cb5
Progress
jgfouca Dec 1, 2023
57892db
Progress
jgfouca Dec 1, 2023
efad029
RP working?
jgfouca Dec 4, 2023
7b6a9a0
Progress on TP alg
jgfouca Dec 4, 2023
4b555ec
Bug fix
jgfouca Dec 4, 2023
85bdb4f
Progress on template stuff
jgfouca Dec 5, 2023
5a3f804
Progress on block TP
jgfouca Dec 5, 2023
458d86a
Progress
jgfouca Dec 6, 2023
68ce07c
Get rid of all the static_casts
jgfouca Dec 6, 2023
f6b1aff
More cleanup. Streams now support blocks
jgfouca Dec 6, 2023
8cf6fde
Tests not passing
jgfouca Dec 6, 2023
4764a79
Serial tests all working, both algs, blocked
jgfouca Dec 6, 2023
3442207
Remove output coming from spiluk test
jgfouca Dec 6, 2023
7bd8ace
Final fixes for CPU
jgfouca Dec 7, 2023
cf025f5
Cuda req full template specification for SerialGemm::invoke
jgfouca Dec 7, 2023
7d4efa8
Don't use scratch for now
jgfouca Dec 8, 2023
244fef9
Formatting
jgfouca Dec 8, 2023
baf9318
Fix warnings
jgfouca Dec 8, 2023
29a14c9
Formatting
jgfouca Dec 8, 2023
b38cb19
Add tolerance to view checks. Use macro and remove redundant test util
jgfouca Dec 11, 2023
0cc30d1
Fix for HIP
jgfouca Dec 14, 2023
9849363
formatting
jgfouca Dec 14, 2023
648ad6f
Another test reorg to fix weirdness on solo
jgfouca Dec 22, 2023
6d9a95d
formatting
jgfouca Dec 23, 2023
f61025b
Remove unused var
jgfouca Dec 24, 2023
d17f4d4
Github feedback
jgfouca Jan 9, 2024
79d5d72
Remove test cout
jgfouca Jan 9, 2024
4c1e9ae
formatting
jgfouca Jan 9, 2024
6538958
Zero-size arrays can cause problems
jgfouca Jan 9, 2024
8e06737
Fix unused var warning
jgfouca Jan 9, 2024
26 changes: 26 additions & 0 deletions batched/dense/impl/KokkosBatched_Trsm_Serial_Impl.hpp
@@ -176,6 +176,32 @@ struct SerialTrsm<Side::Right, Uplo::Upper, Trans::NoTranspose, ArgDiag,
  }
};

template <typename ArgDiag>
struct SerialTrsm<Side::Right, Uplo::Upper, Trans::Transpose, ArgDiag,
                  Algo::Trsm::Unblocked> {
  template <typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha,
                                           const AViewType &A,
                                           const BViewType &B) {
    return SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
        ArgDiag::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(),
        A.stride_0(), A.stride_1(), B.data(), B.stride_1(), B.stride_0());
  }
};

template <typename ArgDiag>
struct SerialTrsm<Side::Right, Uplo::Upper, Trans::Transpose, ArgDiag,
                  Algo::Trsm::Blocked> {
  template <typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha,
                                           const AViewType &A,
                                           const BViewType &B) {
    return SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
        ArgDiag::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(),
        A.stride_0(), A.stride_1(), B.data(), B.stride_1(), B.stride_0());
  }
};

///
/// L/U/NT
///
30 changes: 30 additions & 0 deletions batched/dense/impl/KokkosBatched_Trsm_Team_Impl.hpp
@@ -99,6 +99,36 @@ struct TeamTrsm<MemberType, Side::Right, Uplo::Upper, Trans::NoTranspose,
  }
};

template <typename MemberType, typename ArgDiag>
struct TeamTrsm<MemberType, Side::Right, Uplo::Upper, Trans::Transpose, ArgDiag,
                Algo::Trsm::Unblocked> {
  template <typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member,
                                           const ScalarType alpha,
                                           const AViewType &A,
                                           const BViewType &B) {
    return TeamTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
        member, ArgDiag::use_unit_diag, B.extent(1), B.extent(0), alpha,
        A.data(), A.stride_0(), A.stride_1(), B.data(), B.stride_1(),
        B.stride_0());
  }
};

template <typename MemberType, typename ArgDiag>
struct TeamTrsm<MemberType, Side::Right, Uplo::Upper, Trans::Transpose, ArgDiag,
                Algo::Trsm::Blocked> {
  template <typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member,
                                           const ScalarType alpha,
                                           const AViewType &A,
                                           const BViewType &B) {
    return TeamTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
        member, ArgDiag::use_unit_diag, B.extent(1), B.extent(0), alpha,
        A.data(), A.stride_0(), A.stride_1(), B.data(), B.stride_1(),
        B.stride_0());
  }
};

///
/// L/U/NT
///
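
Reviewer note on the stride-swap idiom: the new Serial and Team specializations both pass `B.extent(1), B.extent(0)` and `B.stride_1(), B.stride_0()` to a left-side internal routine, which appears to be a way of operating on the transpose of B without copying it. Below is a minimal standalone sketch of that idiom, not KokkosBatched code. `left_lower_solve` and `right_solve_via_transpose` are hypothetical names, and the sketch assumes the internal routine is a plain forward substitution that reads only the lower triangle of A, with the same argument order as the calls in the diff. Under that assumption, keeping A's strides as stored while swapping B's extents and strides solves L * X^T = alpha * B^T, which is the same as X * L^T = alpha * B.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-in for a strided left-lower triangular solve.
// Overwrites B (m x n) with the solution Y of L * Y = alpha * B, where L is
// the lower triangle of the m x m matrix A (entries above the diagonal are
// never read). Element (i, j) of A is A[i * as0 + j * as1]; same for B.
int left_lower_solve(bool unit_diag, int m, int n, double alpha,
                     const double *A, int as0, int as1,
                     double *B, int bs0, int bs1) {
  for (int j = 0; j < n; ++j) {    // each column of B is an independent solve
    for (int i = 0; i < m; ++i) {  // forward substitution, top to bottom
      double sum = alpha * B[i * bs0 + j * bs1];
      for (int k = 0; k < i; ++k) {
        // B[k, j] for k < i already holds the solved value Y[k, j].
        sum -= A[i * as0 + k * as1] * B[k * bs0 + j * bs1];
      }
      B[i * bs0 + j * bs1] = unit_diag ? sum : sum / A[i * as0 + i * as1];
    }
  }
  return 0;
}

// Solves X * L^T = alpha * B in place, for row-major B (m x n) and A (n x n).
// Transposing both sides gives L * X^T = alpha * B^T, so B is handed to the
// left-side routine with extents (n, m) and strides (1, n), i.e. viewed as
// B^T. A keeps its stored strides (n, 1), mirroring the calls in the diff.
int right_solve_via_transpose(double alpha, const double *A, int n,
                              double *B, int m) {
  return left_lower_solve(/*unit_diag=*/false, n, m, alpha,
                          A, /*as0=*/n, /*as1=*/1,
                          B, /*bs0=*/1, /*bs1=*/n);
}
```

Swapping the strides is a zero-copy reinterpretation, so the same left-side kernel can serve right-side variants. How the `Uplo` tag of each specialization maps onto the triangle the internal routine actually reads is outside what this sketch establishes.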