Skip to content

Commit

Permalink
[Backport to 17] add initial support for CooperativeMatrixConstructCh…
Browse files Browse the repository at this point in the history
  • Loading branch information
VyacheslavLevytskyy authored and MrSidims committed Apr 10, 2024
1 parent 0873793 commit 4dbe874
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
1 change: 1 addition & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3446,6 +3446,7 @@ class SPIRVCooperativeMatrixCheckedInstructionsINTELInstBase
SPIRV##x##INTEL;
_SPIRV_OP(CooperativeMatrixLoadChecked, true, 9, true, 7)
_SPIRV_OP(CooperativeMatrixStoreChecked, false, 8, true, 8)
_SPIRV_OP(CooperativeMatrixConstructChecked, true, 8)
#undef _SPIRV_OP

class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase {
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ _SPIRV_OP_INTERNAL(CooperativeMatrixLoadCheckedINTEL,
internal::OpCooperativeMatrixLoadCheckedINTEL)
_SPIRV_OP_INTERNAL(CooperativeMatrixStoreCheckedINTEL,
internal::OpCooperativeMatrixStoreCheckedINTEL)
_SPIRV_OP_INTERNAL(CooperativeMatrixConstructCheckedINTEL,
internal::OpCooperativeMatrixConstructCheckedINTEL)
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
_SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ enum InternalOp {
IOpTypeJointMatrixINTELv2 = 6184,
IOpCooperativeMatrixLoadCheckedINTEL = 6193,
IOpCooperativeMatrixStoreCheckedINTEL = 6194,
IOpCooperativeMatrixConstructCheckedINTEL = 6195,
IOpJointMatrixWorkItemLengthINTEL = 6410,
IOpComplexFMulINTEL = 6415,
IOpComplexFDivINTEL = 6416,
Expand Down Expand Up @@ -195,6 +196,7 @@ _SPIRV_OP(Op, CooperativeMatrixPrefetchINTEL)
_SPIRV_OP(Capability, CooperativeMatrixCheckedInstructionsINTEL)
_SPIRV_OP(Op, CooperativeMatrixLoadCheckedINTEL)
_SPIRV_OP(Op, CooperativeMatrixStoreCheckedINTEL)
_SPIRV_OP(Op, CooperativeMatrixConstructCheckedINTEL)

_SPIRV_OP(Capability, HWThreadQueryINTEL)
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy1:]] [[#Int32Ty]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const2]]
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy2:]] [[#Int8Ty]] [[#Const3]] [[#Const12]] [[#Const48]] [[#Const0]]
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixConstructCheckedINTEL [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy2]] [[#Load1:]]
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy3]]
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
; CHECK-SPIRV: CooperativeMatrixStoreCheckedINTEL

; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructi(i32 0)
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTELiilli(i32 4, i32 4, i64 12, i64 12, i32 %_arg_Initvalue)
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTELPU3AS4ciiillli(ptr addrspace(4) %[[MatrixPtr:[%0-9a-z.]+]], i32 0, i32 0, i32 0, i64 12, i64 48, i64 %_arg_K, i32 1)
; CHECK-LLVM: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS144__spirv_CooperativeMatrixKHR__char_3_12_48_0(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0)
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTELPU3AS4ciiilll
Expand All @@ -52,7 +52,7 @@ $_ZTSZZ15matrix_multiply = comdat any
@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

; Function Attrs: convergent norecurse
define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB6, ptr addrspace(1) noundef align 4 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K) local_unnamed_addr #0 comdat {
define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB6, ptr addrspace(1) noundef align 4 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K, i32 noundef %_arg_Initvalue) local_unnamed_addr #0 comdat {
entry:
%sub_c.sroa.0.i = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
%ref.tmp29.sroa.0.i = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
Expand All @@ -77,7 +77,7 @@ entry:
%cmp.i58.i = icmp ult i64 %5, 2147483648
%sub5.i = sub nsw i64 %2, %5
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %sub_c.sroa.0.i)
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 noundef 0) #4
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef 4, i32 noundef 4, i64 noundef 12, i64 noundef 12, i32 noundef %_arg_Initvalue) #4
store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %call.i.i, ptr %sub_c.sroa.0.i, align 8
%mul.i = mul nsw i64 %sub.i, 12
%div2452.i = lshr i64 %sub5.i, 4
Expand Down Expand Up @@ -133,7 +133,7 @@ _ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6
}

; Function Attrs: convergent
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef, i32 noundef, i64 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2

declare dso_local spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef)

Expand Down

0 comments on commit 4dbe874

Please sign in to comment.