From 7561da691cb3d1802328e17c6713481f82a9c3fd Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 9 Nov 2019 22:58:54 +0100 Subject: [PATCH 01/81] Implement wrappers for WMMA LLVM intrinsics --- src/device/cuda.jl | 1 + src/device/cuda/wmma.jl | 242 ++++++++++++++++++++++++++++++++++++++++ test/device/wmma.jl | 135 ++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 379 insertions(+) create mode 100644 src/device/cuda/wmma.jl create mode 100644 test/device/wmma.jl diff --git a/src/device/cuda.jl b/src/device/cuda.jl index 13833fd6..4e2772ad 100644 --- a/src/device/cuda.jl +++ b/src/device/cuda.jl @@ -11,6 +11,7 @@ include("cuda/assertion.jl") include("cuda/memory_dynamic.jl") include("cuda/atomics.jl") include("cuda/misc.jl") +include("cuda/wmma.jl") # functionality from libdevice # diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl new file mode 100644 index 00000000..35832f82 --- /dev/null +++ b/src/device/cuda/wmma.jl @@ -0,0 +1,242 @@ +################################################################################ +# CONSTANTS +################################################################################ + +# Maps PTX types to LLVM types +map_ptx_to_llvm = Dict( + "f16" => "<2 x half>", + "f32" => "float" + ) + +# Maps PTX types to the LLVM type that llvmcall expects +map_ptx_to_llvmcall = Dict( + "f16" => "<2 x i16>", + "f32" => "float" + ) + +# Maps PTX types to Julia types +map_ptx_to_jl = Dict( + "f16" => NTuple{2, VecElement{Float16}}, + "f32" => Float32 + ) + +# Maps matrix & PTX types to fragment sizes +map_frag_sizes = Dict( + "a.f16" => 8, + "b.f16" => 8, + "c.f16" => 4, + "c.f32" => 8, + "d.f16" => 4, + "d.f32" => 8 + ) + +################################################################################ +# HELPER FUNCTIONS +################################################################################ + +macro gen_ir(template, count, delim="\n") + return quote + join([$(esc(template)) for $(esc(:i)) in 0:$(esc(count))-1], $(esc(delim))) + end +end + +function join_nonempty(args...) + delim = args[end] + arr = [args[1:end-1]...] + + return join(arr[arr .!= ""], delim) +end + +get_llvm_ty(matrix, ptx_el_type) = map_ptx_to_llvm[ptx_el_type] + +get_llvmcall_ty(matrix, ptx_el_type) = map_ptx_to_llvmcall[ptx_el_type] + +get_jl_ty(matrix, ptx_el_type) = map_ptx_to_jl[ptx_el_type] + +get_frag_sz(matrix, ptx_el_type) = map_frag_sizes["$matrix.$ptx_el_type"] + +################################################################################ +# MATRIX LOAD +################################################################################ + +for mat in ["a", "b", "c"], + layout in ["col", "row"], + shape in ["m16n16k16"], + addr_space in ["", "shared", "global"], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # TODO: Non-stride versions? + + # Float32 is only supported for C + if (elem_type == "f32") && (mat != "c") + continue + end + + # Name of the Julia wrapper function + func_name = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) + + # Name of the LLVM intrinsic + llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "load", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") + + # Determine types for this (matrix, elem_type) combination + sz = get_frag_sz(mat, elem_type) + llvm_ty = get_llvm_ty(mat, elem_type) + struct_ty = "{ $(@gen_ir(llvm_ty, sz, ", ")) }" + lc_ty = get_llvmcall_ty(mat, elem_type) + jl_ty = get_jl_ty(mat, elem_type) + + # Generate LLVM IR + ir = ("declare $struct_ty $llvm_intr(i8*, i32)", + " + %src_ptr = inttoptr i64 %0 to i8* + + %ret.llvm = call $struct_ty $llvm_intr(i8* %src_ptr, i32 %1) + + $(@gen_ir("%ret.llvm.$i = extractvalue $struct_ty %ret.llvm, $i", sz)) + + $(@gen_ir("%ret.jl.$i = bitcast $llvm_ty %ret.llvm.$i to $lc_ty", sz)) + + $(@gen_ir("%ret.aggr.$i = insertvalue [$sz x $lc_ty] $(i == 0 ? "undef" : "%ret.aggr.$(i-1)"), $lc_ty %ret.jl.$i, $i", sz)) + + ret [$sz x $lc_ty] %ret.aggr.$(sz-1) + ") + + @eval $func_name(src_addr, stride) = Base.llvmcall($ir, + NTuple{$sz, $jl_ty}, + Tuple{Int64, Int32}, + convert(Int64, src_addr), + convert(Int32, stride)) + + @eval export $func_name +end + +################################################################################ +# MATRIX STORE +################################################################################ + +for mat in ["d"], + layout in ["col", "row"], + shape in ["m16n16k16"], + addr_space in ["", "shared", "global"], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # TODO: Non-stride versions? + + # Name of the Julia wrapper function + func_name = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) + + # Name of the LLVM intrinsic + llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "store", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") + + # Determine types for this (matrix, elem_type) combination + sz = get_frag_sz(mat, elem_type) + llvm_ty = get_llvm_ty(mat, elem_type) + lc_ty = get_llvmcall_ty(mat, elem_type) + jl_ty = get_jl_ty(mat, elem_type) + + # Generate LLVM IR + ir = ("declare void $llvm_intr(i8*, $(@gen_ir("$llvm_ty", sz, ", ")), i32)", + " + %dst_ptr = inttoptr i64 %0 to i8* + + $(@gen_ir("%data.jl.$i = extractvalue [$sz x $lc_ty] %1, $i", sz)) + + $(@gen_ir("%data.llvm.$i = bitcast $lc_ty %data.jl.$i to $llvm_ty", sz)) + + call void $llvm_intr(i8* %dst_ptr, $(@gen_ir("$llvm_ty %data.llvm.$i", sz, ", ")) , i32 %2) + ret void + ") + + @eval $func_name(dst_addr, data, stride) = Base.llvmcall($ir, + Nothing, + Tuple{Int64, NTuple{$sz, $jl_ty}, Int32}, + convert(Int64, dst_addr), + convert(NTuple{$sz, $jl_ty}, data), + convert(Int32, stride)) + + @eval export $func_name +end + +################################################################################ +# MATRIX MULTIPLY ACCUMULATE +################################################################################ + +for a_layout in ["col", "row"], + b_layout in ["col", "row"], + shape in ["m16n16k16"], + d_elem_type in ["f16", "f32"], + c_elem_type in ["f16", "f32"], + b_elem_type in ["f16"], + a_elem_type in ["f16"] + + # Name of the Julia wrapper function + func_name = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) + + # Name of the LLVM intrinsic + llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "mma", "sync", a_layout, b_layout, shape, d_elem_type, c_elem_type, ".") + + # Determine types for the (matrix, elem_type) combinations for matrix A + a_sz = get_frag_sz("a", a_elem_type) + a_llvm_ty = get_llvm_ty("a", a_elem_type) + a_lc_ty = get_llvmcall_ty("a", a_elem_type) + a_jl_ty = get_jl_ty("a", a_elem_type) + + # Determine types for the (matrix, elem_type) combinations for matrix B + b_sz = get_frag_sz("b", b_elem_type) + b_llvm_ty = get_llvm_ty("b", b_elem_type) + b_lc_ty = get_llvmcall_ty("b", b_elem_type) + b_jl_ty = get_jl_ty("b", b_elem_type) + + # Determine types for the (matrix, elem_type) combinations for matrix C + c_sz = get_frag_sz("c", c_elem_type) + c_llvm_ty = get_llvm_ty("c", c_elem_type) + c_lc_ty = get_llvmcall_ty("c", c_elem_type) + c_jl_ty = get_jl_ty("c", c_elem_type) + + # Determine types for the (matrix, elem_type) combinations for matrix D + d_sz = get_frag_sz("d", d_elem_type) + d_llvm_ty = get_llvm_ty("d", d_elem_type) + d_lc_ty = get_llvmcall_ty("d", d_elem_type) + d_jl_ty = get_jl_ty("d", d_elem_type) + d_struct_ty = "{ $(@gen_ir(d_llvm_ty, d_sz, ", ")) }" + + # Create the argument string to the IR call + args = join([ + @gen_ir("$a_llvm_ty %a.llvm.$i", a_sz, ", "), + @gen_ir("$b_llvm_ty %b.llvm.$i", b_sz, ", "), + @gen_ir("$c_llvm_ty %c.llvm.$i", c_sz, ", ")] + , ", ") + + # Generate LLVM IR + ir = ("declare $d_struct_ty $llvm_intr($args)", + " + $(@gen_ir("%a.jl.$i = extractvalue [$a_sz x $a_lc_ty] %0, $i", a_sz)) + $(@gen_ir("%b.jl.$i = extractvalue [$b_sz x $b_lc_ty] %1, $i", b_sz)) + $(@gen_ir("%c.jl.$i = extractvalue [$c_sz x $c_lc_ty] %2, $i", c_sz)) + + $(@gen_ir("%a.llvm.$i = bitcast $a_lc_ty %a.jl.$i to $a_llvm_ty", a_sz)) + $(@gen_ir("%b.llvm.$i = bitcast $b_lc_ty %b.jl.$i to $b_llvm_ty", b_sz)) + $(@gen_ir("%c.llvm.$i = bitcast $c_lc_ty %c.jl.$i to $c_llvm_ty", c_sz)) + + %d.llvm = call $d_struct_ty $llvm_intr($args) + + $(@gen_ir("%d.llvm.$i = extractvalue $d_struct_ty %d.llvm, $i", d_sz)) + + $(@gen_ir("%d.jl.$i = bitcast $d_llvm_ty %d.llvm.$i to $d_lc_ty", d_sz)) + + $(@gen_ir("%d.aggr.$i = insertvalue [$d_sz x $d_lc_ty] $(i == 0 ? "undef" : "%d.aggr.$(i-1)"), $d_lc_ty %d.jl.$i, $i", d_sz)) + + ret [$d_sz x $d_lc_ty] %d.aggr.$(d_sz-1) + ") + + @eval $func_name(a, b, c) = Base.llvmcall($ir, + NTuple{$d_sz, $d_jl_ty}, + Tuple{NTuple{$a_sz, $a_jl_ty}, NTuple{$b_sz, $b_jl_ty}, NTuple{$c_sz, $c_jl_ty}}, + convert(NTuple{$a_sz, $a_jl_ty}, a), + convert(NTuple{$b_sz, $b_jl_ty}, b), + convert(NTuple{$c_sz, $c_jl_ty}, c)) + + @eval export $func_name +end diff --git a/test/device/wmma.jl b/test/device/wmma.jl new file mode 100644 index 00000000..2e083bff --- /dev/null +++ b/test/device/wmma.jl @@ -0,0 +1,135 @@ +@testset "WMMA" begin + +################################################################################ + + @testset "LLVM intrinsics" begin + + @testset "llvm_wmma_load" begin + @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["a", "b", "c"], + layout in ["row", "col"], + shape in ["m16n16k16"], + addr_space in [""], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # TODO: Test address space? + + # Float32 is only supported for C + if (elem_type == "f32") && (mat != "c") + continue + end + + # Type-dependent variables + array_ty = elem_type == "f16" ? Float16 : Float32 + expected = elem_type == "f16" ? (VecElement{Float16}(42), VecElement{Float16}(42)) : Float32(42) + + # Get the function name + func = getfield(Main, Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)_stride_$(elem_type)")) + + input = 42 * ones(array_ty, (16, 16)) + input_dev = CuArray(input) + result = Array{Bool}(undef, 1) + result_dev = CuArray(result) + + function kernel(input_dev, result_dev) + data = func(pointer(input_dev), 16) + result_dev[1] = all(val -> val == expected, data) + return + end + + @cuda threads=32 kernel(input_dev, result_dev) + @test all(Array(result_dev)) + end + end + + @testset "llvm_wmma_store" begin + @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], + layout in ["row", "col"], + shape in ["m16n16k16"], + addr_space in [""], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # TODO: Test address space? + + # Type-dependent variables + array_ty = elem_type == "f16" ? Float16 : Float32 + data = elem_type == "f16" ? + ( + (VecElement{Float16}(42), VecElement{Float16}(42)), + (VecElement{Float16}(42), VecElement{Float16}(42)), + (VecElement{Float16}(42), VecElement{Float16}(42)), + (VecElement{Float16}(42), VecElement{Float16}(42)) + ) : (42, 42, 42, 42, 42, 42, 42, 42) + + # Get the function name + func = getfield(Main, Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)_stride_$(elem_type)")) + + output = Array{array_ty}(undef, (16, 16)) + output_dev = CuArray(output) + + function kernel(output_dev) + func(pointer(output_dev), data, 16) + return + end + + @cuda threads=32 kernel(output_dev) + @test all(Array(output_dev) .== 42.0) + end + end + + @testset "llvm_wmma_mma" begin + @testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"], + b_layout in ["row", "col"], + shape in ["m16n16k16"], + d_elem_type in ["f16", "f32"], + c_elem_type in ["f16", "f32"] + + # Type-dependent variables + d_ty = d_elem_type == "f16" ? Float16 : Float32 + c_ty = c_elem_type == "f16" ? Float16 : Float32 + + # Get the function names + lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_m16n16k16_stride_f16")) + ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_m16n16k16_stride_f16")) + ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_m16n16k16_stride_$(c_elem_type)")) + mma_func = getfield(Main, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m16n16k16_$(d_elem_type)_$(c_elem_type)")) + std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_m16n16k16_stride_$(d_elem_type)")) + + # Generate input matrices + a = rand(Float16, (16, 16)) + a_dev = CuArray(a) + b = rand(Float16, (16, 16)) + b_dev = CuArray(b) + c = rand(c_ty, (16, 16)) + c_dev = CuArray(c) + + # Reserve space for result + d = Array{d_ty}(undef, (16, 16)) + d_dev = CuArray(d) + + # Matrix MAC kernel (D = A * B + C) + function kernel(a_dev, b_dev, c_dev, d_dev) + a_frag = lda_func(pointer(a_dev), 16) + b_frag = ldb_func(pointer(b_dev), 16) + c_frag = ldc_func(pointer(c_dev), 16) + + d_frag = mma_func(a_frag, b_frag, c_frag) + + std_func(pointer(d_dev), d_frag, 16) + return + end + + @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) + + new_a = (a_layout == "col" ? a : transpose(a)) + new_b = (b_layout == "col" ? b : transpose(b)) + + @test new_a * new_b + c ≈ Array(d_dev) rtol=0.01 + end + end + end + +################################################################################ + +end diff --git a/test/runtests.jl b/test/runtests.jl index ccedf891..18f933b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,6 +70,7 @@ else include("device/pointer.jl") include("device/array.jl") include("device/cuda.jl") + include("device/wmma.jl") include("examples.jl") end From 8f4f2d1eb7107744d8b80d00d893a67ff2c93d34 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 12:16:23 +0100 Subject: [PATCH 02/81] Implement basic CUDA-style API --- src/device/cuda.jl | 1 + src/device/cuda/wmma-highlevel.jl | 96 +++++++++++++++++++++++++++++++ test/device/wmma-highlevel.jl | 35 +++++++++++ test/runtests.jl | 23 ++++---- 4 files changed, 144 insertions(+), 11 deletions(-) create mode 100644 src/device/cuda/wmma-highlevel.jl create mode 100644 test/device/wmma-highlevel.jl diff --git a/src/device/cuda.jl b/src/device/cuda.jl index 4e2772ad..a93bc453 100644 --- a/src/device/cuda.jl +++ b/src/device/cuda.jl @@ -12,6 +12,7 @@ include("cuda/memory_dynamic.jl") include("cuda/atomics.jl") include("cuda/misc.jl") include("cuda/wmma.jl") +include("cuda/wmma-highlevel.jl") # functionality from libdevice # diff --git a/src/device/cuda/wmma-highlevel.jl b/src/device/cuda/wmma-highlevel.jl new file mode 100644 index 00000000..eb735641 --- /dev/null +++ b/src/device/cuda/wmma-highlevel.jl @@ -0,0 +1,96 @@ +################################################################################ +# WMMA FRAGMENT +################################################################################ + +export wmma_row_major, wmma_col_major, wmma_unspecified + +abstract type wmma_fragment_layout end +struct wmma_row_major <: wmma_fragment_layout end +struct wmma_col_major <: wmma_fragment_layout end +struct wmma_unspecified <: wmma_fragment_layout end + + +export wmma_matrix_a, wmma_matrix_b, wmma_accumulator + +abstract type wmma_fragment_use end +struct wmma_matrix_a <: wmma_fragment_use end +struct wmma_matrix_b <: wmma_fragment_use end +struct wmma_accumulator <: wmma_fragment_use end + + +export wmma_fragment + +struct wmma_fragment{M, N, K, FS, T, L <: wmma_fragment_layout, U <: wmma_fragment_use} + x::NTuple{FS, T} +end + +################################################################################ +# WMMA CONFIGURATION +################################################################################ + +export wmma_config +struct wmma_config{M, N, K} end + +################################################################################ +# WMMA LOAD +################################################################################ + +export wmma_load_a, wmma_load_b, wmma_load_c + +function wmma_load_a(addr::DevicePtr{Float16, AS.Global}, + stride::Number, + layout::Type{wmma_col_major}, + config::Type{wmma_config{16, 16, 16}}) + x = llvm_wmma_load_a_col_m16n16k16_stride_f16(addr, stride) + return wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_a}(x) +end + +function wmma_load_b(addr::DevicePtr{Float16, AS.Global}, + stride::Number, + layout::Type{wmma_col_major}, + config::Type{wmma_config{16, 16, 16}}) + x = llvm_wmma_load_b_col_m16n16k16_stride_f16(addr, stride) + return wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_b}(x) +end + +function wmma_load_c(addr::DevicePtr{Float16, AS.Global}, + stride::Number, + layout::Type{wmma_col_major}, + config::Type{wmma_config{16, 16, 16}}) + x = llvm_wmma_load_c_col_m16n16k16_stride_f16(addr, stride) + return wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}(x) +end + +################################################################################ +# WMMA MMA +################################################################################ + +export wmma_mma + +function wmma_mma(a::wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_a}, + b::wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_b}, + c::wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}) + x = llvm_wmma_mma_col_col_m16n16k16_f16_f16(a.x, b.x, c.x) + return wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}(x) +end + +################################################################################ +# WMMA STORE +################################################################################ + +export wmma_store_d + +function wmma_store_d(addr::DevicePtr{Float16, AS.Global}, + d::wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}, + stride::Number, + layout::Type{wmma_col_major}, + config::Type{wmma_config{16, 16, 16}}) + llvm_wmma_store_d_col_m16n16k16_stride_f16(addr, d.x, stride) + return nothing +end + +################################################################################ +# WMMA FILL FRAGMENT +################################################################################ + +# TODO diff --git a/test/device/wmma-highlevel.jl b/test/device/wmma-highlevel.jl new file mode 100644 index 00000000..21b885f0 --- /dev/null +++ b/test/device/wmma-highlevel.jl @@ -0,0 +1,35 @@ +@testset "WMMA" begin + @testset "CUDA C-style API" begin + + @testset "One specific case" begin + a = rand(Float16, (16, 16)) + b = rand(Float16, (16, 16)) + c = rand(Float16, (16, 16)) + d = Array{Float16}(undef, (16, 16)) + + a_dev = CuArray(a) + b_dev = CuArray(b) + c_dev = CuArray(c) + d_dev = CuArray(d) + + function kernel(a_dev, b_dev, c_dev, d_dev) + conf = wmma_config{16, 16, 16} + + a_frag = wmma_load_a(pointer(a_dev), 16, wmma_col_major, conf) + b_frag = wmma_load_b(pointer(b_dev), 16, wmma_col_major, conf) + c_frag = wmma_load_c(pointer(c_dev), 16, wmma_col_major, conf) + + d_frag = wmma_mma(a_frag, b_frag, c_frag) + + wmma_store_d(pointer(d_dev), d_frag, 16, wmma_col_major, conf) + + return + end + + @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) + d = Array(d_dev) + @test a * b + c ≈ d rtol=0.01 + end + + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 18f933b9..954f22f4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,9 +55,9 @@ if length(devices()) > 0 end cap = CUDAnative.current_capability() -include("base.jl") -include("pointer.jl") -include("codegen.jl") +#= include("base.jl") =# +#= include("pointer.jl") =# +#= include("codegen.jl") =# if dev === nothing @warn("No CUDA-capable devices available; skipping on-device tests.") @@ -65,14 +65,15 @@ else if capability(dev) < v"2.0" @warn("native execution not supported on SM < 2.0") else - include("device/codegen.jl") - include("device/execution.jl") - include("device/pointer.jl") - include("device/array.jl") - include("device/cuda.jl") - include("device/wmma.jl") - - include("examples.jl") + #= include("device/codegen.jl") =# + #= include("device/execution.jl") =# + #= include("device/pointer.jl") =# + #= include("device/array.jl") =# + #= include("device/cuda.jl") =# + #= include("device/wmma.jl") =# + include("device/wmma-highlevel.jl") + + #= include("examples.jl") =# end end From 23d95521ae243e9f69b5860ae23dc368c6d431dd Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 12:47:05 +0100 Subject: [PATCH 03/81] Generalise load for matrix --- src/device/cuda/wmma-highlevel.jl | 60 ++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/src/device/cuda/wmma-highlevel.jl b/src/device/cuda/wmma-highlevel.jl index eb735641..67318135 100644 --- a/src/device/cuda/wmma-highlevel.jl +++ b/src/device/cuda/wmma-highlevel.jl @@ -31,34 +31,54 @@ end export wmma_config struct wmma_config{M, N, K} end +################################################################################ +# CONSTANTS +################################################################################ + +map_matrix_to_use = Dict( + "a" => wmma_matrix_a, + "b" => wmma_matrix_b, + "c" => wmma_accumulator, + "d" => wmma_accumulator + ) + ################################################################################ # WMMA LOAD ################################################################################ export wmma_load_a, wmma_load_b, wmma_load_c -function wmma_load_a(addr::DevicePtr{Float16, AS.Global}, - stride::Number, - layout::Type{wmma_col_major}, - config::Type{wmma_config{16, 16, 16}}) - x = llvm_wmma_load_a_col_m16n16k16_stride_f16(addr, stride) - return wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_a}(x) -end +for mat in ["a", "b", "c"] + layout = "col" + shape = "m16n16k16" + addr_space = "" + elem_type = "f16" -function wmma_load_b(addr::DevicePtr{Float16, AS.Global}, - stride::Number, - layout::Type{wmma_col_major}, - config::Type{wmma_config{16, 16, 16}}) - x = llvm_wmma_load_b_col_m16n16k16_stride_f16(addr, stride) - return wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_b}(x) -end + # Name of Julia function + func_name = Symbol("wmma_load_$mat") -function wmma_load_c(addr::DevicePtr{Float16, AS.Global}, - stride::Number, - layout::Type{wmma_col_major}, - config::Type{wmma_config{16, 16, 16}}) - x = llvm_wmma_load_c_col_m16n16k16_stride_f16(addr, stride) - return wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}(x) + # Name of the Julia wrapper + wrapper = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)_stride_$(elem_type)") + + # Get fragment size + frag_sz = get_frag_sz(mat, elem_type) + + # Get Julia element type + julia_type = get_jl_ty(mat, elem_type) + + # Get matrix use type + matrix_use = map_matrix_to_use[mat] + + # Get layout type + layout_ty = (mat == "c") ? wmma_unspecified : (layout == "col") ? wmma_col_major : wmma_row_major + + @eval function $func_name(addr::DevicePtr{Float16, AS.Global}, + stride::Number, + layout::Type{wmma_col_major}, + config::Type{wmma_config{16, 16, 16}}) + x = $wrapper(addr, stride) + return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_ty, $matrix_use}(x) + end end ################################################################################ From faae545a2b2e028bcbaec487396225cf367200b4 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 9 Nov 2019 22:58:54 +0100 Subject: [PATCH 04/81] Implement wrappers for WMMA LLVM intrinsics --- src/device/cuda.jl | 1 + src/device/cuda/wmma.jl | 246 ++++++++++++++++++++++++++++++++++++++++ test/device/wmma.jl | 135 ++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 383 insertions(+) create mode 100644 src/device/cuda/wmma.jl create mode 100644 test/device/wmma.jl diff --git a/src/device/cuda.jl b/src/device/cuda.jl index 13833fd6..4e2772ad 100644 --- a/src/device/cuda.jl +++ b/src/device/cuda.jl @@ -11,6 +11,7 @@ include("cuda/assertion.jl") include("cuda/memory_dynamic.jl") include("cuda/atomics.jl") include("cuda/misc.jl") +include("cuda/wmma.jl") # functionality from libdevice # diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl new file mode 100644 index 00000000..06d4df0a --- /dev/null +++ b/src/device/cuda/wmma.jl @@ -0,0 +1,246 @@ +################################################################################ +# CONSTANTS +################################################################################ + +# Maps PTX types to LLVM types +map_ptx_to_llvm = Dict( + "f16" => "<2 x half>", + "f32" => "float" + ) + +# Maps PTX types to the LLVM type that llvmcall expects +map_ptx_to_llvmcall = Dict( + "f16" => "<2 x i16>", + "f32" => "float" + ) + +# Maps PTX types to Julia types +map_ptx_to_jl = Dict( + "f16" => NTuple{2, VecElement{Float16}}, + "f32" => Float32 + ) + +# Maps matrix & PTX types to fragment sizes +map_frag_sizes = Dict( + "a.f16" => 8, + "b.f16" => 8, + "c.f16" => 4, + "c.f32" => 8, + "d.f16" => 4, + "d.f32" => 8 + ) + +################################################################################ +# HELPER FUNCTIONS +################################################################################ + +macro gen_ir(template, count, delim="\n") + return quote + join([$(esc(template)) for $(esc(:i)) in 0:$(esc(count))-1], $(esc(delim))) + end +end + +function join_nonempty(args...) + delim = args[end] + arr = [args[1:end-1]...] + + return join(arr[arr .!= ""], delim) +end + +get_llvm_ty(matrix, ptx_el_type) = map_ptx_to_llvm[ptx_el_type] + +get_llvmcall_ty(matrix, ptx_el_type) = map_ptx_to_llvmcall[ptx_el_type] + +get_jl_ty(matrix, ptx_el_type) = map_ptx_to_jl[ptx_el_type] + +get_frag_sz(matrix, ptx_el_type) = map_frag_sizes["$matrix.$ptx_el_type"] + +################################################################################ +# LOW LEVEL API +################################################################################ + +# ----------- +# Matrix load +# ----------- + +for mat in ["a", "b", "c"], + layout in ["col", "row"], + shape in ["m16n16k16"], + addr_space in ["", "shared", "global"], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # TODO: Non-stride versions? + + # Float32 is only supported for C + if (elem_type == "f32") && (mat != "c") + continue + end + + # Name of the Julia wrapper function + func_name = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) + + # Name of the LLVM intrinsic + llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "load", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") + + # Determine types for this (matrix, elem_type) combination + sz = get_frag_sz(mat, elem_type) + llvm_ty = get_llvm_ty(mat, elem_type) + struct_ty = "{ $(@gen_ir(llvm_ty, sz, ", ")) }" + lc_ty = get_llvmcall_ty(mat, elem_type) + jl_ty = get_jl_ty(mat, elem_type) + + # Generate LLVM IR + ir = ("declare $struct_ty $llvm_intr(i8*, i32)", + " + %src_ptr = inttoptr i64 %0 to i8* + + %ret.llvm = call $struct_ty $llvm_intr(i8* %src_ptr, i32 %1) + + $(@gen_ir("%ret.llvm.$i = extractvalue $struct_ty %ret.llvm, $i", sz)) + + $(@gen_ir("%ret.jl.$i = bitcast $llvm_ty %ret.llvm.$i to $lc_ty", sz)) + + $(@gen_ir("%ret.aggr.$i = insertvalue [$sz x $lc_ty] $(i == 0 ? "undef" : "%ret.aggr.$(i-1)"), $lc_ty %ret.jl.$i, $i", sz)) + + ret [$sz x $lc_ty] %ret.aggr.$(sz-1) + ") + + @eval $func_name(src_addr, stride) = Base.llvmcall($ir, + NTuple{$sz, $jl_ty}, + Tuple{Int64, Int32}, + convert(Int64, src_addr), + convert(Int32, stride)) + + @eval export $func_name +end + +# ------------ +# Matrix store +# ------------ + +for mat in ["d"], + layout in ["col", "row"], + shape in ["m16n16k16"], + addr_space in ["", "shared", "global"], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # TODO: Non-stride versions? + + # Name of the Julia wrapper function + func_name = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) + + # Name of the LLVM intrinsic + llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "store", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") + + # Determine types for this (matrix, elem_type) combination + sz = get_frag_sz(mat, elem_type) + llvm_ty = get_llvm_ty(mat, elem_type) + lc_ty = get_llvmcall_ty(mat, elem_type) + jl_ty = get_jl_ty(mat, elem_type) + + # Generate LLVM IR + ir = ("declare void $llvm_intr(i8*, $(@gen_ir("$llvm_ty", sz, ", ")), i32)", + " + %dst_ptr = inttoptr i64 %0 to i8* + + $(@gen_ir("%data.jl.$i = extractvalue [$sz x $lc_ty] %1, $i", sz)) + + $(@gen_ir("%data.llvm.$i = bitcast $lc_ty %data.jl.$i to $llvm_ty", sz)) + + call void $llvm_intr(i8* %dst_ptr, $(@gen_ir("$llvm_ty %data.llvm.$i", sz, ", ")) , i32 %2) + ret void + ") + + @eval $func_name(dst_addr, data, stride) = Base.llvmcall($ir, + Nothing, + Tuple{Int64, NTuple{$sz, $jl_ty}, Int32}, + convert(Int64, dst_addr), + convert(NTuple{$sz, $jl_ty}, data), + convert(Int32, stride)) + + @eval export $func_name +end + +# -------------------------- +# Matrix multiply accumulate +# -------------------------- + +for a_layout in ["col", "row"], + b_layout in ["col", "row"], + shape in ["m16n16k16"], + d_elem_type in ["f16", "f32"], + c_elem_type in ["f16", "f32"], + b_elem_type in ["f16"], + a_elem_type in ["f16"] + + # Name of the Julia wrapper function + func_name = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) + + # Name of the LLVM intrinsic + llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "mma", "sync", a_layout, b_layout, shape, d_elem_type, c_elem_type, ".") + + # Determine types for the (matrix, elem_type) combinations for matrix A + a_sz = get_frag_sz("a", a_elem_type) + a_llvm_ty = get_llvm_ty("a", a_elem_type) + a_lc_ty = get_llvmcall_ty("a", a_elem_type) + a_jl_ty = get_jl_ty("a", a_elem_type) + + # Determine types for the (matrix, elem_type) combinations for matrix B + b_sz = get_frag_sz("b", b_elem_type) + b_llvm_ty = get_llvm_ty("b", b_elem_type) + b_lc_ty = get_llvmcall_ty("b", b_elem_type) + b_jl_ty = get_jl_ty("b", b_elem_type) + + # Determine types for the (matrix, elem_type) combinations for matrix C + c_sz = get_frag_sz("c", c_elem_type) + c_llvm_ty = get_llvm_ty("c", c_elem_type) + c_lc_ty = get_llvmcall_ty("c", c_elem_type) + c_jl_ty = get_jl_ty("c", c_elem_type) + + # Determine types for the (matrix, elem_type) combinations for matrix D + d_sz = get_frag_sz("d", d_elem_type) + d_llvm_ty = get_llvm_ty("d", d_elem_type) + d_lc_ty = get_llvmcall_ty("d", d_elem_type) + d_jl_ty = get_jl_ty("d", d_elem_type) + d_struct_ty = "{ $(@gen_ir(d_llvm_ty, d_sz, ", ")) }" + + # Create the argument string to the IR call + args = join([ + @gen_ir("$a_llvm_ty %a.llvm.$i", a_sz, ", "), + @gen_ir("$b_llvm_ty %b.llvm.$i", b_sz, ", "), + @gen_ir("$c_llvm_ty %c.llvm.$i", c_sz, ", ")] + , ", ") + + # Generate LLVM IR + ir = ("declare $d_struct_ty $llvm_intr($args)", + " + $(@gen_ir("%a.jl.$i = extractvalue [$a_sz x $a_lc_ty] %0, $i", a_sz)) + $(@gen_ir("%b.jl.$i = extractvalue [$b_sz x $b_lc_ty] %1, $i", b_sz)) + $(@gen_ir("%c.jl.$i = extractvalue [$c_sz x $c_lc_ty] %2, $i", c_sz)) + + $(@gen_ir("%a.llvm.$i = bitcast $a_lc_ty %a.jl.$i to $a_llvm_ty", a_sz)) + $(@gen_ir("%b.llvm.$i = bitcast $b_lc_ty %b.jl.$i to $b_llvm_ty", b_sz)) + $(@gen_ir("%c.llvm.$i = bitcast $c_lc_ty %c.jl.$i to $c_llvm_ty", c_sz)) + + %d.llvm = call $d_struct_ty $llvm_intr($args) + + $(@gen_ir("%d.llvm.$i = extractvalue $d_struct_ty %d.llvm, $i", d_sz)) + + $(@gen_ir("%d.jl.$i = bitcast $d_llvm_ty %d.llvm.$i to $d_lc_ty", d_sz)) + + $(@gen_ir("%d.aggr.$i = insertvalue [$d_sz x $d_lc_ty] $(i == 0 ? "undef" : "%d.aggr.$(i-1)"), $d_lc_ty %d.jl.$i, $i", d_sz)) + + ret [$d_sz x $d_lc_ty] %d.aggr.$(d_sz-1) + ") + + @eval $func_name(a, b, c) = Base.llvmcall($ir, + NTuple{$d_sz, $d_jl_ty}, + Tuple{NTuple{$a_sz, $a_jl_ty}, NTuple{$b_sz, $b_jl_ty}, NTuple{$c_sz, $c_jl_ty}}, + convert(NTuple{$a_sz, $a_jl_ty}, a), + convert(NTuple{$b_sz, $b_jl_ty}, b), + convert(NTuple{$c_sz, $c_jl_ty}, c)) + + @eval export $func_name +end diff --git a/test/device/wmma.jl b/test/device/wmma.jl new file mode 100644 index 00000000..2e083bff --- /dev/null +++ b/test/device/wmma.jl @@ -0,0 +1,135 @@ +@testset "WMMA" begin + +################################################################################ + + @testset "LLVM intrinsics" begin + + @testset "llvm_wmma_load" begin + @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["a", "b", "c"], + layout in ["row", "col"], + shape in ["m16n16k16"], + addr_space in [""], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # TODO: Test address space? + + # Float32 is only supported for C + if (elem_type == "f32") && (mat != "c") + continue + end + + # Type-dependent variables + array_ty = elem_type == "f16" ? Float16 : Float32 + expected = elem_type == "f16" ? (VecElement{Float16}(42), VecElement{Float16}(42)) : Float32(42) + + # Get the function name + func = getfield(Main, Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)_stride_$(elem_type)")) + + input = 42 * ones(array_ty, (16, 16)) + input_dev = CuArray(input) + result = Array{Bool}(undef, 1) + result_dev = CuArray(result) + + function kernel(input_dev, result_dev) + data = func(pointer(input_dev), 16) + result_dev[1] = all(val -> val == expected, data) + return + end + + @cuda threads=32 kernel(input_dev, result_dev) + @test all(Array(result_dev)) + end + end + + @testset "llvm_wmma_store" begin + @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], + layout in ["row", "col"], + shape in ["m16n16k16"], + addr_space in [""], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # TODO: Test address space? + + # Type-dependent variables + array_ty = elem_type == "f16" ? Float16 : Float32 + data = elem_type == "f16" ? + ( + (VecElement{Float16}(42), VecElement{Float16}(42)), + (VecElement{Float16}(42), VecElement{Float16}(42)), + (VecElement{Float16}(42), VecElement{Float16}(42)), + (VecElement{Float16}(42), VecElement{Float16}(42)) + ) : (42, 42, 42, 42, 42, 42, 42, 42) + + # Get the function name + func = getfield(Main, Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)_stride_$(elem_type)")) + + output = Array{array_ty}(undef, (16, 16)) + output_dev = CuArray(output) + + function kernel(output_dev) + func(pointer(output_dev), data, 16) + return + end + + @cuda threads=32 kernel(output_dev) + @test all(Array(output_dev) .== 42.0) + end + end + + @testset "llvm_wmma_mma" begin + @testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"], + b_layout in ["row", "col"], + shape in ["m16n16k16"], + d_elem_type in ["f16", "f32"], + c_elem_type in ["f16", "f32"] + + # Type-dependent variables + d_ty = d_elem_type == "f16" ? Float16 : Float32 + c_ty = c_elem_type == "f16" ? Float16 : Float32 + + # Get the function names + lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_m16n16k16_stride_f16")) + ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_m16n16k16_stride_f16")) + ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_m16n16k16_stride_$(c_elem_type)")) + mma_func = getfield(Main, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m16n16k16_$(d_elem_type)_$(c_elem_type)")) + std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_m16n16k16_stride_$(d_elem_type)")) + + # Generate input matrices + a = rand(Float16, (16, 16)) + a_dev = CuArray(a) + b = rand(Float16, (16, 16)) + b_dev = CuArray(b) + c = rand(c_ty, (16, 16)) + c_dev = CuArray(c) + + # Reserve space for result + d = Array{d_ty}(undef, (16, 16)) + d_dev = CuArray(d) + + # Matrix MAC kernel (D = A * B + C) + function kernel(a_dev, b_dev, c_dev, d_dev) + a_frag = lda_func(pointer(a_dev), 16) + b_frag = ldb_func(pointer(b_dev), 16) + c_frag = ldc_func(pointer(c_dev), 16) + + d_frag = mma_func(a_frag, b_frag, c_frag) + + std_func(pointer(d_dev), d_frag, 16) + return + end + + @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) + + new_a = (a_layout == "col" ? a : transpose(a)) + new_b = (b_layout == "col" ? b : transpose(b)) + + @test new_a * new_b + c ≈ Array(d_dev) rtol=0.01 + end + end + end + +################################################################################ + +end diff --git a/test/runtests.jl b/test/runtests.jl index ccedf891..18f933b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,6 +70,7 @@ else include("device/pointer.jl") include("device/array.jl") include("device/cuda.jl") + include("device/wmma.jl") include("examples.jl") end From 844f28e7618eba2d7676374ea1fddde0e706df30 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 13:02:46 +0100 Subject: [PATCH 05/81] Move high-level API to same file --- src/device/cuda.jl | 1 - src/device/cuda/wmma-highlevel.jl | 116 ---------------------------- src/device/cuda/wmma.jl | 121 ++++++++++++++++++++++++++++++ test/device/wmma-highlevel.jl | 35 --------- test/device/wmma.jl | 35 +++++++++ test/runtests.jl | 3 +- 6 files changed, 157 insertions(+), 154 deletions(-) delete mode 100644 src/device/cuda/wmma-highlevel.jl delete mode 100644 test/device/wmma-highlevel.jl diff --git a/src/device/cuda.jl b/src/device/cuda.jl index a93bc453..4e2772ad 100644 --- a/src/device/cuda.jl +++ b/src/device/cuda.jl @@ -12,7 +12,6 @@ include("cuda/memory_dynamic.jl") include("cuda/atomics.jl") include("cuda/misc.jl") include("cuda/wmma.jl") -include("cuda/wmma-highlevel.jl") # functionality from libdevice # diff --git a/src/device/cuda/wmma-highlevel.jl b/src/device/cuda/wmma-highlevel.jl deleted file mode 100644 index 67318135..00000000 --- a/src/device/cuda/wmma-highlevel.jl +++ /dev/null @@ -1,116 +0,0 @@ -################################################################################ -# WMMA FRAGMENT -################################################################################ - -export wmma_row_major, wmma_col_major, wmma_unspecified - -abstract type wmma_fragment_layout end -struct wmma_row_major <: wmma_fragment_layout end -struct wmma_col_major <: wmma_fragment_layout end -struct wmma_unspecified <: wmma_fragment_layout end - - -export wmma_matrix_a, wmma_matrix_b, wmma_accumulator - -abstract type wmma_fragment_use end -struct wmma_matrix_a <: wmma_fragment_use end -struct wmma_matrix_b <: wmma_fragment_use end -struct wmma_accumulator <: wmma_fragment_use end - - -export wmma_fragment - -struct wmma_fragment{M, N, K, FS, T, L <: wmma_fragment_layout, U <: wmma_fragment_use} - x::NTuple{FS, T} -end - -################################################################################ -# WMMA CONFIGURATION -################################################################################ - -export wmma_config -struct wmma_config{M, N, K} end - -################################################################################ -# CONSTANTS -################################################################################ - -map_matrix_to_use = Dict( - "a" => wmma_matrix_a, - "b" => wmma_matrix_b, - "c" => wmma_accumulator, - "d" => wmma_accumulator - ) - -################################################################################ -# WMMA LOAD -################################################################################ - -export wmma_load_a, wmma_load_b, wmma_load_c - -for mat in ["a", "b", "c"] - layout = "col" - shape = "m16n16k16" - addr_space = "" - elem_type = "f16" - - # Name of Julia function - func_name = Symbol("wmma_load_$mat") - - # Name of the Julia wrapper - wrapper = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)_stride_$(elem_type)") - - # Get fragment size - frag_sz = get_frag_sz(mat, elem_type) - - # Get Julia element type - julia_type = get_jl_ty(mat, elem_type) - - # Get matrix use type - matrix_use = map_matrix_to_use[mat] - - # Get layout type - layout_ty = (mat == "c") ? wmma_unspecified : (layout == "col") ? wmma_col_major : wmma_row_major - - @eval function $func_name(addr::DevicePtr{Float16, AS.Global}, - stride::Number, - layout::Type{wmma_col_major}, - config::Type{wmma_config{16, 16, 16}}) - x = $wrapper(addr, stride) - return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_ty, $matrix_use}(x) - end -end - -################################################################################ -# WMMA MMA -################################################################################ - -export wmma_mma - -function wmma_mma(a::wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_a}, - b::wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_b}, - c::wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}) - x = llvm_wmma_mma_col_col_m16n16k16_f16_f16(a.x, b.x, c.x) - return wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}(x) -end - -################################################################################ -# WMMA STORE -################################################################################ - -export wmma_store_d - -function wmma_store_d(addr::DevicePtr{Float16, AS.Global}, - d::wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}, - stride::Number, - layout::Type{wmma_col_major}, - config::Type{wmma_config{16, 16, 16}}) - llvm_wmma_store_d_col_m16n16k16_stride_f16(addr, d.x, stride) - return nothing -end - -################################################################################ -# WMMA FILL FRAGMENT -################################################################################ - -# TODO diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 06d4df0a..9929ed1f 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -244,3 +244,124 @@ for a_layout in ["col", "row"], @eval export $func_name end + +################################################################################ +# HIGH LEVEL (CUDA-STYLE API) +################################################################################ + +# ------------- +# WMMA fragment +# ------------- + +export wmma_row_major, wmma_col_major, wmma_unspecified + +abstract type wmma_fragment_layout end +struct wmma_row_major <: wmma_fragment_layout end +struct wmma_col_major <: wmma_fragment_layout end +struct wmma_unspecified <: wmma_fragment_layout end + + +export wmma_matrix_a, wmma_matrix_b, wmma_accumulator + +abstract type wmma_fragment_use end +struct wmma_matrix_a <: wmma_fragment_use end +struct wmma_matrix_b <: wmma_fragment_use end +struct wmma_accumulator <: wmma_fragment_use end + +map_matrix_to_use = Dict( + "a" => wmma_matrix_a, + "b" => wmma_matrix_b, + "c" => wmma_accumulator, + "d" => wmma_accumulator + ) + + +export wmma_fragment + +struct wmma_fragment{M, N, K, FS, T, L <: wmma_fragment_layout, U <: wmma_fragment_use} + x::NTuple{FS, T} +end + +# ------------------ +# WMMA configuration +# ------------------ + +export wmma_config +struct wmma_config{M, N, K} end + + +# --------- +# WMMA load +# --------- + +export wmma_load_a, wmma_load_b, wmma_load_c + +for mat in ["a", "b", "c"] + layout = "col" + shape = "m16n16k16" + addr_space = "" + elem_type = "f16" + + # Name of Julia function + func_name = Symbol("wmma_load_$mat") + + # Name of the Julia wrapper + wrapper = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)_stride_$(elem_type)") + + # Get fragment size + frag_sz = get_frag_sz(mat, elem_type) + + # Get Julia element type + julia_type = get_jl_ty(mat, elem_type) + + # Get matrix use type + matrix_use = map_matrix_to_use[mat] + + # Get layout type + layout_ty = (mat == "c") ? wmma_unspecified : (layout == "col") ? wmma_col_major : wmma_row_major + + @eval function $func_name(addr::DevicePtr{Float16, AS.Global}, + stride::Number, + layout::Type{wmma_col_major}, + config::Type{wmma_config{16, 16, 16}}) + x = $wrapper(addr, stride) + return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_ty, $matrix_use}(x) + end +end + + +# ------------------------ +# WMMA multiply-accumulate +# ------------------------ + +export wmma_mma + +function wmma_mma(a::wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_a}, + b::wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_b}, + c::wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}) + x = llvm_wmma_mma_col_col_m16n16k16_f16_f16(a.x, b.x, c.x) + return wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}(x) +end + + +# ---------- +# WMMA store +# ---------- + +export wmma_store_d + +function wmma_store_d(addr::DevicePtr{Float16, AS.Global}, + d::wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}, + stride::Number, + layout::Type{wmma_col_major}, + config::Type{wmma_config{16, 16, 16}}) + llvm_wmma_store_d_col_m16n16k16_stride_f16(addr, d.x, stride) + return nothing +end + + +# ------------------ +# WMMA fill fragment +# ------------------ + +# TODO diff --git a/test/device/wmma-highlevel.jl b/test/device/wmma-highlevel.jl deleted file mode 100644 index 21b885f0..00000000 --- a/test/device/wmma-highlevel.jl +++ /dev/null @@ -1,35 +0,0 @@ -@testset "WMMA" begin - @testset "CUDA C-style API" begin - - @testset "One specific case" begin - a = rand(Float16, (16, 16)) - b = rand(Float16, (16, 16)) - c = rand(Float16, (16, 16)) - d = Array{Float16}(undef, (16, 16)) - - a_dev = CuArray(a) - b_dev = CuArray(b) - c_dev = CuArray(c) - d_dev = CuArray(d) - - function kernel(a_dev, b_dev, c_dev, d_dev) - conf = wmma_config{16, 16, 16} - - a_frag = wmma_load_a(pointer(a_dev), 16, wmma_col_major, conf) - b_frag = wmma_load_b(pointer(b_dev), 16, wmma_col_major, conf) - c_frag = wmma_load_c(pointer(c_dev), 16, wmma_col_major, conf) - - d_frag = wmma_mma(a_frag, b_frag, c_frag) - - wmma_store_d(pointer(d_dev), d_frag, 16, wmma_col_major, conf) - - return - end - - @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) - d = Array(d_dev) - @test a * b + c ≈ d rtol=0.01 - end - - end -end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 2e083bff..f977dbcc 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -132,4 +132,39 @@ ################################################################################ + @testset "CUDA C-style API" begin + + @testset "One specific case" begin + a = rand(Float16, (16, 16)) + b = rand(Float16, (16, 16)) + c = rand(Float16, (16, 16)) + d = Array{Float16}(undef, (16, 16)) + + a_dev = CuArray(a) + b_dev = CuArray(b) + c_dev = CuArray(c) + d_dev = CuArray(d) + + function kernel(a_dev, b_dev, c_dev, d_dev) + conf = wmma_config{16, 16, 16} + + a_frag = wmma_load_a(pointer(a_dev), 16, wmma_col_major, conf) + b_frag = wmma_load_b(pointer(b_dev), 16, wmma_col_major, conf) + c_frag = wmma_load_c(pointer(c_dev), 16, wmma_col_major, conf) + + d_frag = wmma_mma(a_frag, b_frag, c_frag) + + wmma_store_d(pointer(d_dev), d_frag, 16, wmma_col_major, conf) + + return + end + + @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) + d = Array(d_dev) + @test a * b + c ≈ d rtol=0.01 + end + + end + +################################################################################ end diff --git a/test/runtests.jl b/test/runtests.jl index 954f22f4..1d70f530 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,8 +70,7 @@ else #= include("device/pointer.jl") =# #= include("device/array.jl") =# #= include("device/cuda.jl") =# - #= include("device/wmma.jl") =# - include("device/wmma-highlevel.jl") + include("device/wmma.jl") #= include("examples.jl") =# end From 91d6ee7993e21738ba41a509446b3a45c0d01dad Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 13:39:39 +0100 Subject: [PATCH 06/81] Finish load --- src/device/cuda/wmma.jl | 66 ++++++++++++++++++++++++++++++----------- test/device/wmma.jl | 2 +- 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 9929ed1f..aeb8a15b 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -268,13 +268,6 @@ struct wmma_matrix_a <: wmma_fragment_use end struct wmma_matrix_b <: wmma_fragment_use end struct wmma_accumulator <: wmma_fragment_use end -map_matrix_to_use = Dict( - "a" => wmma_matrix_a, - "b" => wmma_matrix_b, - "c" => wmma_accumulator, - "d" => wmma_accumulator - ) - export wmma_fragment @@ -289,6 +282,29 @@ end export wmma_config struct wmma_config{M, N, K} end +# --------- +# Constants +# --------- + +map_matrix_to_use = Dict( + "a" => wmma_matrix_a, + "b" => wmma_matrix_b, + "c" => wmma_accumulator, + "d" => wmma_accumulator + ) + +map_address_space_to_ty = Dict( + "" => AS.Generic, + "shared" => AS.Shared, + "global" => AS.Global + ) + +# ---------------- +# Helper functions +# ---------------- + +get_matrix_use(mat) = map_matrix_to_use[mat] +get_address_space(as) = map_address_space_to_ty[as] # --------- # WMMA load @@ -296,17 +312,24 @@ struct wmma_config{M, N, K} end export wmma_load_a, wmma_load_b, wmma_load_c -for mat in ["a", "b", "c"] - layout = "col" - shape = "m16n16k16" - addr_space = "" - elem_type = "f16" +for mat in ["a", "b", "c"], + layout in ["col", "row"], + shape in ["m16n16k16"], + addr_space in ["", "shared", "global"], + stride in ["stride"], + elem_type in ["f16", "f32"] + + + # Float32 is only supported for C + if (elem_type == "f32") && (mat != "c") + continue + end # Name of Julia function func_name = Symbol("wmma_load_$mat") # Name of the Julia wrapper - wrapper = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)_stride_$(elem_type)") + wrapper = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) # Get fragment size frag_sz = get_frag_sz(mat, elem_type) @@ -315,17 +338,24 @@ for mat in ["a", "b", "c"] julia_type = get_jl_ty(mat, elem_type) # Get matrix use type - matrix_use = map_matrix_to_use[mat] + matrix_use = get_matrix_use(mat) # Get layout type - layout_ty = (mat == "c") ? wmma_unspecified : (layout == "col") ? wmma_col_major : wmma_row_major + layout_ty = (layout == "col") ? wmma_col_major : wmma_row_major + layout_ret_ty = (mat == "c") ? wmma_unspecified : layout_ty + + # Get pointer type + ptr_ty = (elem_type == "f32") ? Float32 : Float16 + + # Get address space type + as_ty = get_address_space(addr_space) - @eval function $func_name(addr::DevicePtr{Float16, AS.Global}, + @eval function $func_name(addr::DevicePtr{$ptr_ty, $as_ty}, stride::Number, - layout::Type{wmma_col_major}, + layout::Type{$layout_ty}, config::Type{wmma_config{16, 16, 16}}) x = $wrapper(addr, stride) - return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_ty, $matrix_use}(x) + return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_ret_ty, $matrix_use}(x) end end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index f977dbcc..2d9be728 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -134,7 +134,7 @@ @testset "CUDA C-style API" begin - @testset "One specific case" begin + @testset "Matrix multiply-accumulate" begin a = rand(Float16, (16, 16)) b = rand(Float16, (16, 16)) c = rand(Float16, (16, 16)) From db17cc66654b7bd7f04223010d97e6def58e09f5 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 13:45:28 +0100 Subject: [PATCH 07/81] Wrapper for store --- src/device/cuda/wmma.jl | 52 ++++++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index aeb8a15b..e20ec85d 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -342,7 +342,7 @@ for mat in ["a", "b", "c"], # Get layout type layout_ty = (layout == "col") ? wmma_col_major : wmma_row_major - layout_ret_ty = (mat == "c") ? wmma_unspecified : layout_ty + layout_frag_ty = (mat == "c") ? wmma_unspecified : layout_ty # Get pointer type ptr_ty = (elem_type == "f32") ? Float32 : Float16 @@ -355,7 +355,7 @@ for mat in ["a", "b", "c"], layout::Type{$layout_ty}, config::Type{wmma_config{16, 16, 16}}) x = $wrapper(addr, stride) - return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_ret_ty, $matrix_use}(x) + return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_frag_ty, $matrix_use}(x) end end @@ -380,13 +380,47 @@ end export wmma_store_d -function wmma_store_d(addr::DevicePtr{Float16, AS.Global}, - d::wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}, - stride::Number, - layout::Type{wmma_col_major}, - config::Type{wmma_config{16, 16, 16}}) - llvm_wmma_store_d_col_m16n16k16_stride_f16(addr, d.x, stride) - return nothing +for mat in ["d"], + layout in ["col", "row"], + shape in ["m16n16k16"], + addr_space in ["", "shared", "global"], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # Name of Julia function + func_name = Symbol("wmma_store_$mat") + + # Name of the Julia wrapper + wrapper = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) + + # Get fragment size + frag_sz = get_frag_sz(mat, elem_type) + + # Get Julia element type + julia_type = get_jl_ty(mat, elem_type) + + # Get matrix use type + matrix_use = get_matrix_use(mat) + + # Get layout type + layout_ty = (layout == "col") ? wmma_col_major : wmma_row_major + layout_frag_ty = wmma_unspecified + + # Get pointer type + ptr_ty = (elem_type == "f32") ? Float32 : Float16 + + # Get address space type + as_ty = get_address_space(addr_space) + + @eval function $func_name(addr::DevicePtr{$ptr_ty, $as_ty}, + d::wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_frag_ty, $matrix_use}, + stride::Number, + layout::Type{$layout_ty}, + config::Type{wmma_config{16, 16, 16}}) + $wrapper(addr, d.x, stride) + return nothing + end + end From 53657fe7efb9ad94a2e91eff1861804bd6865885 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 14:10:47 +0100 Subject: [PATCH 08/81] Generalise MMA --- src/device/cuda/wmma.jl | 45 ++++++++++++++++++++++++++++++++++++----- test/device/wmma.jl | 2 +- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index e20ec85d..8b2001d1 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -366,11 +366,46 @@ end export wmma_mma -function wmma_mma(a::wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_a}, - b::wmma_fragment{16, 16, 16, 8, NTuple{2, VecElement{Float16}}, wmma_col_major, wmma_matrix_b}, - c::wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}) - x = llvm_wmma_mma_col_col_m16n16k16_f16_f16(a.x, b.x, c.x) - return wmma_fragment{16, 16, 16, 4, NTuple{2, VecElement{Float16}}, wmma_unspecified, wmma_accumulator}(x) +for a_layout in ["col", "row"], + b_layout in ["col", "row"], + shape in ["m16n16k16"], + d_elem_type in ["f16", "f32"], + c_elem_type in ["f16", "f32"], + b_elem_type in ["f16"], + a_elem_type in ["f16"] + + # Name of the Julia wrapper + wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) + + # Information about a + a_frag_sz = get_frag_sz("a", a_elem_type) + a_julia_type = get_jl_ty("a", a_elem_type) + a_layout_ty = (a_layout == "col") ? wmma_col_major : wmma_row_major + + # Information about b + b_frag_sz = get_frag_sz("b", b_elem_type) + b_julia_type = get_jl_ty("b", b_elem_type) + b_layout_ty = (b_layout == "col") ? wmma_col_major : wmma_row_major + + # Information about c + c_frag_sz = get_frag_sz("c", c_elem_type) + c_julia_type = get_jl_ty("c", c_elem_type) + + # Information about d + d_frag_sz = get_frag_sz("d", d_elem_type) + d_julia_type = get_jl_ty("d", d_elem_type) + + # We need some way to select if we want d to be 16 or 32-bit floating point + # during dispatch. + dispatch_ty = (d_elem_type == "f16") ? Float16 : Float32 + + @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_frag_sz, $a_julia_type, $a_layout_ty, wmma_matrix_a}, + b::wmma_fragment{16, 16, 16, $b_frag_sz, $b_julia_type, $b_layout_ty, wmma_matrix_b}, + c::wmma_fragment{16, 16, 16, $c_frag_sz, $c_julia_type, wmma_unspecified, wmma_accumulator}, + d_type::Type{$dispatch_ty}) + x = $wrapper(a.x, b.x, c.x) + return wmma_fragment{16, 16, 16, $d_frag_sz, $d_julia_type, wmma_unspecified, wmma_accumulator}(x) + end end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 2d9be728..0b1c656f 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -152,7 +152,7 @@ b_frag = wmma_load_b(pointer(b_dev), 16, wmma_col_major, conf) c_frag = wmma_load_c(pointer(c_dev), 16, wmma_col_major, conf) - d_frag = wmma_mma(a_frag, b_frag, c_frag) + d_frag = wmma_mma(a_frag, b_frag, c_frag, Float16) wmma_store_d(pointer(d_dev), d_frag, 16, wmma_col_major, conf) From 7a0b1dc1fc49bdb92b697e428c05f312b1de09b8 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 18:16:28 +0100 Subject: [PATCH 09/81] Generalise high level test --- test/device/wmma.jl | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 0b1c656f..dcafc548 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -134,34 +134,46 @@ @testset "CUDA C-style API" begin - @testset "Matrix multiply-accumulate" begin + @testset "MAC: A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [wmma_col_major, wmma_row_major], + b_layout in [wmma_col_major, wmma_row_major], + c_layout in [wmma_col_major, wmma_row_major], + d_layout in [wmma_col_major, wmma_row_major], + c_type in [Float16, Float32], + d_type in [Float16, Float32] + a = rand(Float16, (16, 16)) b = rand(Float16, (16, 16)) - c = rand(Float16, (16, 16)) - d = Array{Float16}(undef, (16, 16)) + c = rand(c_type, (16, 16)) + d = Array{d_type}(undef, (16, 16)) a_dev = CuArray(a) b_dev = CuArray(b) c_dev = CuArray(c) d_dev = CuArray(d) - function kernel(a_dev, b_dev, c_dev, d_dev) + @eval function kernel(a_dev, b_dev, c_dev, d_dev) conf = wmma_config{16, 16, 16} - a_frag = wmma_load_a(pointer(a_dev), 16, wmma_col_major, conf) - b_frag = wmma_load_b(pointer(b_dev), 16, wmma_col_major, conf) - c_frag = wmma_load_c(pointer(c_dev), 16, wmma_col_major, conf) + a_frag = wmma_load_a(pointer(a_dev), 16, $a_layout, conf) + b_frag = wmma_load_b(pointer(b_dev), 16, $b_layout, conf) + c_frag = wmma_load_c(pointer(c_dev), 16, $c_layout, conf) - d_frag = wmma_mma(a_frag, b_frag, c_frag, Float16) + d_frag = wmma_mma(a_frag, b_frag, c_frag, $d_type) - wmma_store_d(pointer(d_dev), d_frag, 16, wmma_col_major, conf) + wmma_store_d(pointer(d_dev), d_frag, 16, $d_layout, conf) return end @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) d = Array(d_dev) - @test a * b + c ≈ d rtol=0.01 + + new_a = (a_layout == wmma_col_major) ? a : transpose(a) + new_b = (b_layout == wmma_col_major) ? b : transpose(b) + new_c = (c_layout == wmma_col_major) ? c : transpose(c) + new_d = (d_layout == wmma_col_major) ? d : transpose(d) + + @test new_a * new_b + new_c ≈ new_d rtol=0.01 end end From d0e490c3e63e9191ce457a2934e95ac97927597e Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 18:29:46 +0100 Subject: [PATCH 10/81] Move d type to config --- src/device/cuda/wmma.jl | 8 ++++---- test/device/wmma.jl | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 8b2001d1..84f6fb47 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -280,7 +280,7 @@ end # ------------------ export wmma_config -struct wmma_config{M, N, K} end +struct wmma_config{M, N, K, d_type} end # --------- # Constants @@ -353,7 +353,7 @@ for mat in ["a", "b", "c"], @eval function $func_name(addr::DevicePtr{$ptr_ty, $as_ty}, stride::Number, layout::Type{$layout_ty}, - config::Type{wmma_config{16, 16, 16}}) + config::Type{wmma_config{16, 16, 16, d_type}}) where d_type x = $wrapper(addr, stride) return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_frag_ty, $matrix_use}(x) end @@ -402,7 +402,7 @@ for a_layout in ["col", "row"], @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_frag_sz, $a_julia_type, $a_layout_ty, wmma_matrix_a}, b::wmma_fragment{16, 16, 16, $b_frag_sz, $b_julia_type, $b_layout_ty, wmma_matrix_b}, c::wmma_fragment{16, 16, 16, $c_frag_sz, $c_julia_type, wmma_unspecified, wmma_accumulator}, - d_type::Type{$dispatch_ty}) + conf::Type{wmma_config{16, 16, 16, $dispatch_ty}}) x = $wrapper(a.x, b.x, c.x) return wmma_fragment{16, 16, 16, $d_frag_sz, $d_julia_type, wmma_unspecified, wmma_accumulator}(x) end @@ -451,7 +451,7 @@ for mat in ["d"], d::wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_frag_ty, $matrix_use}, stride::Number, layout::Type{$layout_ty}, - config::Type{wmma_config{16, 16, 16}}) + config::Type{wmma_config{16, 16, 16, d_type}}) where d_type $wrapper(addr, d.x, stride) return nothing end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index dcafc548..78b3e777 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -152,13 +152,13 @@ d_dev = CuArray(d) @eval function kernel(a_dev, b_dev, c_dev, d_dev) - conf = wmma_config{16, 16, 16} + conf = wmma_config{16, 16, 16, $d_type} a_frag = wmma_load_a(pointer(a_dev), 16, $a_layout, conf) b_frag = wmma_load_b(pointer(b_dev), 16, $b_layout, conf) c_frag = wmma_load_c(pointer(c_dev), 16, $c_layout, conf) - d_frag = wmma_mma(a_frag, b_frag, c_frag, $d_type) + d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) wmma_store_d(pointer(d_dev), d_frag, 16, $d_layout, conf) From 7ec4877dca1f45832e2c10edb854af6d3d1d3074 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 18:59:48 +0100 Subject: [PATCH 11/81] Add fill_fragment function --- src/device/cuda/wmma.jl | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 84f6fb47..99ada7c0 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -463,4 +463,34 @@ end # WMMA fill fragment # ------------------ -# TODO +export wmma_fill_a, wmma_fill_b, wmma_fill_c + +for mat in ["a", "b", "c"], + elem_type in ["f16", "f32"] + + # Float32 is only supported for C + if (elem_type == "f32") && (mat != "c") + continue + end + + # Name of the Julia function + func_name = Symbol("wmma_fill_$mat") + + # Get fragment size + frag_sz = get_frag_sz(mat, elem_type) + + # Value type + val_type = (elem_type == "f16") ? Float16 : Float32 + + # Returned tuple + if elem_type == "f16" + tuple = :(ntuple(i -> ntuple(j -> VecElement{Float16}(value), 2), $frag_sz)) + else + tuple = :(ntuple(i -> value, $frag_sz)) + end + + @eval function $func_name(value::$val_type) + x = $tuple + return x + end +end From cf3ba1916930f9a3e2217cadbe147c0a44e3cc11 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 19:33:13 +0100 Subject: [PATCH 12/81] Add tests for multiply --- src/device/cuda/wmma.jl | 14 ++++++-------- test/device/wmma.jl | 18 ++++++++++++++---- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 99ada7c0..978f0fad 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -463,22 +463,20 @@ end # WMMA fill fragment # ------------------ -export wmma_fill_a, wmma_fill_b, wmma_fill_c +export wmma_fill_c -for mat in ["a", "b", "c"], +for mat in ["c"], elem_type in ["f16", "f32"] - # Float32 is only supported for C - if (elem_type == "f32") && (mat != "c") - continue - end - # Name of the Julia function func_name = Symbol("wmma_fill_$mat") # Get fragment size frag_sz = get_frag_sz(mat, elem_type) + # Get Julia type + julia_type = get_jl_ty(mat, elem_type) + # Value type val_type = (elem_type == "f16") ? Float16 : Float32 @@ -491,6 +489,6 @@ for mat in ["a", "b", "c"], @eval function $func_name(value::$val_type) x = $tuple - return x + return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, wmma_unspecified, wmma_accumulator}(x) end end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 78b3e777..9e23bb32 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -134,12 +134,13 @@ @testset "CUDA C-style API" begin - @testset "MAC: A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [wmma_col_major, wmma_row_major], + @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [wmma_col_major, wmma_row_major], b_layout in [wmma_col_major, wmma_row_major], c_layout in [wmma_col_major, wmma_row_major], d_layout in [wmma_col_major, wmma_row_major], c_type in [Float16, Float32], - d_type in [Float16, Float32] + d_type in [Float16, Float32], + do_mac in [true, false] a = rand(Float16, (16, 16)) b = rand(Float16, (16, 16)) @@ -156,7 +157,12 @@ a_frag = wmma_load_a(pointer(a_dev), 16, $a_layout, conf) b_frag = wmma_load_b(pointer(b_dev), 16, $b_layout, conf) - c_frag = wmma_load_c(pointer(c_dev), 16, $c_layout, conf) + + if $do_mac + c_frag = wmma_load_c(pointer(c_dev), 16, $c_layout, conf) + else + c_frag = wmma_fill_c($c_type(0)) + end d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) @@ -173,7 +179,11 @@ new_c = (c_layout == wmma_col_major) ? c : transpose(c) new_d = (d_layout == wmma_col_major) ? d : transpose(d) - @test new_a * new_b + new_c ≈ new_d rtol=0.01 + if do_mac + @test new_a * new_b + new_c ≈ new_d rtol=0.01 + else + @test new_a * new_b ≈ new_d rtol=0.01 + end end end From 740559ae34862dacfd4daeb28326691901c61cdf Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 19:35:48 +0100 Subject: [PATCH 13/81] Add configuration variable to fill --- src/device/cuda/wmma.jl | 4 +++- test/device/wmma.jl | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 978f0fad..40fdef49 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -487,7 +487,9 @@ for mat in ["c"], tuple = :(ntuple(i -> value, $frag_sz)) end - @eval function $func_name(value::$val_type) + @eval function $func_name(value::$val_type, + config::Type{wmma_config{M, N, K, d_type}}) where {M, N, K, d_type} + x = $tuple return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, wmma_unspecified, wmma_accumulator}(x) end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 9e23bb32..7156537f 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -161,7 +161,7 @@ if $do_mac c_frag = wmma_load_c(pointer(c_dev), 16, $c_layout, conf) else - c_frag = wmma_fill_c($c_type(0)) + c_frag = wmma_fill_c($c_type(0), conf) end d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) From 44898c62b458f9c9dafc64f1d48c17a6eadea9a7 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 20:33:46 +0100 Subject: [PATCH 14/81] Add documentation --- src/device/cuda/wmma.jl | 132 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 131 insertions(+), 1 deletion(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 40fdef49..985c9454 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -253,11 +253,40 @@ end # WMMA fragment # ------------- -export wmma_row_major, wmma_col_major, wmma_unspecified +export wmma_fragment_layout, wmma_row_major, wmma_col_major, wmma_unspecified +""" + wmma_fragment_layout + +Abstract type that specifies the storage layout of a matrix. + +Possible values are [`wmma_row_major`](@ref), [`wmma_col_major`](@ref) and [`wmma_unspecified`](@ref). +""" abstract type wmma_fragment_layout end + +""" + wmma_row_major + +Type that represents a matrix stored in row major (C style) order. +""" struct wmma_row_major <: wmma_fragment_layout end + +""" + wmma_col_major + +Type that represents a matrix stored in column major (Julia style) order. +""" struct wmma_col_major <: wmma_fragment_layout end + +""" + wmma_unspecified + +Type that represents a matrix stored in an unspecified order. + +!!! warning + + This storage format is not valid for all WMMA operations! +""" struct wmma_unspecified <: wmma_fragment_layout end @@ -271,6 +300,13 @@ struct wmma_accumulator <: wmma_fragment_use end export wmma_fragment +""" + wmma_fragment + +Type that represents per-thread intermediate results of WMMA operations. + +You can access individual elements using the `x` member, but beware that the exact ordering of elements is unspecified. +""" struct wmma_fragment{M, N, K, FS, T, L <: wmma_fragment_layout, U <: wmma_fragment_use} x::NTuple{FS, T} end @@ -280,6 +316,25 @@ end # ------------------ export wmma_config + +""" + wmma_config{M, N, K, d_type} + +Type that contains all information for WMMA operations that cannot be inferred from the argument's types. + +WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\cdot B + C``, where ``A`` is a ``M \\times K`` matrix, +``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices. + +`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`. + +All WMMA operations take a `wmma_config` as their final argument. + +# Examples +```jldoctest +julia> config = wmma_config{16, 16, 16, Float32} +wmma_config{16,16,16,Float32} +``` +""" struct wmma_config{M, N, K, d_type} end # --------- @@ -312,6 +367,28 @@ get_address_space(as) = map_address_space_to_ty[as] export wmma_load_a, wmma_load_b, wmma_load_c +""" + wmma_load_a(addr, stride, layout, config) + wmma_load_b(addr, stride, layout, config) + wmma_load_c(addr, stride, layout, config) + +Load the matrix `a`, `b` or `c` from the memory location indicated by `addr`, and return the resulting [`wmma_fragment`](@ref). + +# Arguments +- `addr`: The address to load the matrix from. +- `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. +- `layout`: The storage layout of the matrix. Possible values are [`wmma_row_major`](@ref) and [`wmma_col_major`](@ref). +- `config`: The WMMA configuration that should be used for loading this matrix. See [`wmma_config`](@ref). + +See also: [`wmma_fragment`](@ref), [`wmma_fragment_layout`](@ref), [`wmma_config`](@ref) + +!!! warning + + All threads in a warp **MUST** execute the load operation in lockstep, and have to use exactly the same arguments. + Failure to do so will result in undefined behaviour. +""" +wmma_load_a, wmma_load_b, wmma_load_c + for mat in ["a", "b", "c"], layout in ["col", "row"], shape in ["m16n16k16"], @@ -366,6 +443,25 @@ end export wmma_mma +""" + wmma_mma(a, b, c, conf) + +Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. + +# Arguments + +- `a`: The [`wmma_fragment`](@ref) corresponding to the matrix ``A``. +- `b`: The [`wmma_fragment`](@ref) corresponding to the matrix ``B``. +- `c`: The [`wmma_fragment`](@ref) corresponding to the matrix ``C``. +- `conf`: The [`wmma_config`](@ref) that should be used in this WMMA operation. + +!!! warning + + All threads in a warp **MUST** execute the `mma` operation in lockstep, and have to use exactly the same arguments. + Failure to do so will result in undefined behaviour. +""" +wmma_mma + for a_layout in ["col", "row"], b_layout in ["col", "row"], shape in ["m16n16k16"], @@ -415,6 +511,27 @@ end export wmma_store_d +""" + wmma_store_d(addr, d, stride, layout, config) + +Store the result matrix `d` to the memory location indicated by `addr`. + +# Arguments +- `addr`: The address to store the matrix to. +- `d`: The [`wmma_fragment`](@ref) corresponding to the `d` matrix. +- `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. +- `layout`: The storage layout of the matrix. Possible values are [`wmma_row_major`](@ref) and [`wmma_col_major`](@ref). +- `config`: The WMMA configuration that should be used for storing this matrix. See [`wmma_config`](@ref). + +See also: [`wmma_fragment`](@ref), [`wmma_fragment_layout`](@ref), [`wmma_config`](@ref) + +!!! warning + + All threads in a warp **MUST** execute the `store` operation in lockstep, and have to use exactly the same arguments. + Failure to do so will result in undefined behaviour. +""" +wmma_store_d + for mat in ["d"], layout in ["col", "row"], shape in ["m16n16k16"], @@ -465,6 +582,19 @@ end export wmma_fill_c +""" + wmma_fill_c(value, config) + +Return a [`wmma_fragment`](@ref) filled with the value `value`. + +This operation is useful if you want to implement a matrix multiplication (and thus want to set ``C = O``). + +# Arguments +- `value`: The value used to fill the fragment. Can be a `Float16` or `Float32`. +- `config`: The WMMA configuration that should be used for this WMMA operation. See [`wmma_config`](@ref). +""" +wmma_fill_c + for mat in ["c"], elem_type in ["f16", "f32"] From cb753daac2900709762a97b4c19f7b9374ae6226 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 10 Nov 2019 23:39:50 +0100 Subject: [PATCH 15/81] Add documentation --- docs/make.jl | 1 + docs/src/lib/device/wmma.md | 138 ++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+) create mode 100644 docs/src/lib/device/wmma.md diff --git a/docs/make.jl b/docs/make.jl index 7a1114fa..1a61a2ba 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -19,6 +19,7 @@ makedocs( "lib/reflection.md", "Device Code" => [ "lib/device/cuda.md", + "lib/device/wmma.md", "lib/device/array.md" ] ] diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md new file mode 100644 index 00000000..96aa67ae --- /dev/null +++ b/docs/src/lib/device/wmma.md @@ -0,0 +1,138 @@ +# WMMA + +This section details CUDAnative's interface to CUDA's warp matrix multiply-accumulate (WMMA) operations. +This interface enables programmatic access to Tensor Cores, a new hardware feature in Volta that performs mixed precision matrix MAC operations. + +Access to WMMA using CUDAnative is available in two levels: low level wrappers around the LLVM intrinsics, and a higher-level API, similar to that of CUDA C. + +## Introduction of Terminology + +The WMMA operations perform a matrix multiply-accumulate. +More concretely, it calculates ``D = A \cdot B + C``, where ``A`` is a ``M \times K`` matrix, ``B`` is a ``K \times N`` matrix, and ``C`` and ``D`` are ``M \times N`` matrices. + +Note that not all values of ``M``, ``N`` and ``K`` are allowed. +The tuple ``(M, N, K)`` is often called the "shape" of the multiply accumulate operation. + +The multiply-accumulate consists of the following steps: +- Load the matrices ``A``, ``B`` and ``C`` from memory to registers using a WMMA load operation. +- Perform the matrix multiply-accumulate of ``A``, ``B`` and ``C`` to obtain ``D`` using a WMMA MMA operation. ``D`` is stored in hardware registers after this step. +- Store the result ``D`` back to memory using a WMMA store operation. + +Note that WMMA is a warp-wide operation, which means that all threads in a warp must cooperate, and execute the WMMA operations in lockstep. +Failure to do so will result in undefined behaviour. + +Each thread in a warp will hold a part of the matrix in its registers. +In WMMA parlance, this part is referred to as a "fragment". +Note that the exact mapping between matrix elements and fragment is unspecified, and subject to change in future versions. + +Finally, it is important to note that the resultant ``D`` matrix can be used as a ``C`` matrix for a subsequent multiply-accumulate. +This is useful if one needs to calculate a sum of the form ``\sum_{i=0}^{n} A_i B_i``, where ``A_i`` and ``B_i`` are matrices of the correct dimension. + +## LLVM Intrinsics + +The LLVM intrinsics are accessible by using the one-to-one Julia wrappers. +The return type of each wrapper is the Julia type that corresponds closest to the return type of the LLVM intrinsic. +For example, LLVM's `[8 x <2 x half>]` becomes `NTuple{8, NTuple{2, VecElement{Float16}}}` in Julia. +In essence, these wrappers return the SSA values returned by the LLVM intrinsic. +Currently, all intrinsics that are available in LLVM 6, PTX 6.0 and SM 70 are implemented. + +These LLVM intrinsics are then lowered to the correct PTX instructions by the LLVM NVPTX backend. +For more information about the PTX instructions, please refer to the [PTX Instruction Set Architecture Manual](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions). + +The LLVM intrinsics are subdivided in three categories: load, store and multiply-accumulate. +In what follows, each of these will be discussed. + +### Load matrix + +**Julia function:** `llvm_wmma_load_{matrix}_{layout}_{shape}_{addr_space}_stride_{elem_type}(src_addr, stride)` + +**Corresponding LLVM instrinsic:** `@llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}` + +**Arguments:** +- `src_addr`: The memory address to load from. +- `stride`: The leading dimension of the matrix, in numbers of elements. + +**Placeholders:** +- `{matrix}`: The matrix to load. Can be `a`, `b` or `c`. +- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. +- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. +- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`. +- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). Note that `f32` is only valid for the matrix ``C``. + +### Perform multiply-accumulate + +**Julia function:** `llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c)` + +**Corresponding LLVM instrinsic:** `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}` + +**Arguments:** +- `a`: The WMMA fragment corresponding to the matrix ``A``. +- `b`: The WMMA fragment corresponding to the matrix ``B``. +- `c`: The WMMA fragment corresponding to the matrix ``C``. + +**Placeholders:** +- `{a_layout}`: The storage layout for matrix ``A``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. +- `{b_layout}`: The storage layout for matrix ``B``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. +- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. +- `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). +- `{c_elem_type}`: The type of each element in the ``C`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). + +!!! warning + + Remember that the shape, type and layout of all operations (be it MMA, load or store) **MUST** match. + Otherwise, the behaviour is undefined! + +### Store matrix + +**Julia function:** `llvm_wmma_store_d_{layout}_{shape}_{addr_space}_stride_{elem_type}(dst_addr, data, stride)` + +**Corresponding LLVM intrinsic:** `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}` + +**Arguments:** +- `dst_addr`: The memory address to store to. +- `data`: The ``D`` fragment to store. +- `stride`: The leading dimension of the matrix, in numbers of elements. + +**Placeholders:** +- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. +- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. +- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`. +- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). + +### Example + +```julia +using CUDAnative +using CuArrays +using Test + +# Generate input matrices +a = rand(Float16, (16, 16)) +a_dev = CuArray(a) +b = rand(Float16, (16, 16)) +b_dev = CuArray(b) +c = rand(Float32, (16, 16)) +c_dev = CuArray(c) + +# Allocate space for result +d_dev = similar(c_dev) + +# Matrix multiply-accumulate kernel (D = A * B + C) +function kernel(a_dev, b_dev, c_dev, d_dev) + a_frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_dev), 16) + b_frag = llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_dev), 16) + c_frag = llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16) + + d_frag = llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag) + + llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), d_frag, 16) + return +end + +@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) +@test a * b + c ≈ Array(d_dev) rtol=0.01 +``` + +## CUDA C-like API + +NYI From 381e25a6413f3dc8a93cd0c9dfcfaee2d9f540ff Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Mon, 11 Nov 2019 00:25:06 +0100 Subject: [PATCH 16/81] Add initial documentation CUDA style API --- docs/src/lib/device/wmma.md | 74 ++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md index 96aa67ae..5d446439 100644 --- a/docs/src/lib/device/wmma.md +++ b/docs/src/lib/device/wmma.md @@ -135,4 +135,76 @@ end ## CUDA C-like API -NYI +TODO + +### Fragment +```@docs +CUDAnative.wmma_fragment_layout +CUDAnative.wmma_row_major +CUDAnative.wmma_col_major +CUDAnative.wmma_unspecified +CUDAnative.wmma_fragment +``` + +### WMMA configuration +```@docs +CUDAnative.wmma_config +``` + +### Load matrix +```@docs +CUDAnative.wmma_load_a +CUDAnative.wmma_load_b +CUDAnative.wmma_load_c +``` + +### Perform multiply-accumulate +```@docs +CUDAnative.wmma_mma +``` + +### Store matrix +```@docs +CUDAnative.wmma_store_d +``` + +### Fill fragment +```@docs +CUDAnative.wmma_fill_c +``` + +### Example + +```julia +using CUDAnative +using CuArrays +using Test + +a = rand(Float16, (16, 16)) +b = rand(Float16, (16, 16)) +c = rand(Float32, (16, 16)) + +a_dev = CuArray(a) +b_dev = CuArray(b) +c_dev = CuArray(c) +d_dev = similar(c_dev) + +function kernel(a_dev, b_dev, c_dev, d_dev) + conf = wmma_config{16, 16, 16, Float32} + + a_frag = wmma_load_a(pointer(a_dev), 16, wmma_col_major, conf) + b_frag = wmma_load_b(pointer(b_dev), 16, wmma_col_major, conf) + c_frag = wmma_load_c(pointer(c_dev), 16, wmma_col_major, conf) + + d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) + + wmma_store_d(pointer(d_dev), d_frag, 16, wmma_col_major, conf) + + return +end + +@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) +d = Array(d_dev) + +@test a * b + c ≈ d rtol=0.01 +``` From 3ed3a17e933acdf715791ecee286f04b4a04412a Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Mon, 11 Nov 2019 13:39:26 +0100 Subject: [PATCH 17/81] Finalise documentation --- docs/src/lib/device/wmma.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md index 5d446439..0b4e1c6e 100644 --- a/docs/src/lib/device/wmma.md +++ b/docs/src/lib/device/wmma.md @@ -135,7 +135,23 @@ end ## CUDA C-like API -TODO +The main difference between the CUDA C-like API and the lower level wrappers, is that the former enforces several constraints when working with WMMA. +For example, it ensures that the ``A`` fragment argument to the MMA instruction was obtained by a `load_a` call, and not by a `load_b` or `load_c`. +Additionally, it makes sure that the data type and storage layout of the load/store operations and the MMA operation match. + +The CUDA C-like API heavily uses Julia's dispatch mechanism. +As such, the method names are much shorter than the LLVM intrinsic wrappers, as most information is baked into the type of the arguments rather than the method name. + + +Note that, in CUDA C++, the fragment is responsible for both the storage of intermediate results and the WMMA configuration. +All CUDA C++ WMMA calls are function templates that take the resultant fragment as a by-reference argument. +As a result, the type of this argument can be used during overload resolution to select the correct WMMA instruction to call. + +In contrast, the API in Julia separates the WMMA storage ([`wmma_fragment`](@ref)) and configuration ([`wmma_config`](@ref)). +Instead of taking the resultant fragment by reference, the Julia functions just return it. +This makes the dataflow clearer, but it also means that the type of that fragment cannot be used for selection of the correct WMMA instruction. +Thus, there is still a limited amount of information that cannot be inferred from the argument types, but must nonetheless match for all WMMA operations, such as the overall shape of the MMA. +This is accomplished by a separate "WMMA configuration" (see [`wmma_config`](@ref)) that you create once, and then give as an argument to all intrinsics. ### Fragment ```@docs From 18394c1ec21d61dd88fac41db903ff36070f7f0a Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 15 Nov 2019 17:31:30 +0100 Subject: [PATCH 18/81] Implement tests for shared address space --- test/device/wmma.jl | 62 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 2e083bff..11e11cb8 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -1,5 +1,17 @@ @testset "WMMA" begin +################################################################################ + +@eval generic_to_shared(ptr) = Base.llvmcall( + " + %ptr.generic = inttoptr i64 %0 to i8* + %ptr.shared = addrspacecast i8* %ptr.generic to i8 addrspace(3)* + %ret = ptrtoint i8 addrspace(3)* %ptr.shared to i64 + ret i64 %ret", + Int64, + Tuple{Int64}, + convert(Int64, ptr)) + ################################################################################ @testset "LLVM intrinsics" begin @@ -8,12 +20,10 @@ @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["a", "b", "c"], layout in ["row", "col"], shape in ["m16n16k16"], - addr_space in [""], + addr_space in ["", "_global", "_shared"], stride in ["stride"], elem_type in ["f16", "f32"] - # TODO: Test address space? - # Float32 is only supported for C if (elem_type == "f32") && (mat != "c") continue @@ -23,17 +33,30 @@ array_ty = elem_type == "f16" ? Float16 : Float32 expected = elem_type == "f16" ? (VecElement{Float16}(42), VecElement{Float16}(42)) : Float32(42) + # Address-space dependent variables + do_shared_test = (addr_space == "_shared") + # Get the function name - func = getfield(Main, Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)_stride_$(elem_type)")) + func = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") input = 42 * ones(array_ty, (16, 16)) input_dev = CuArray(input) result = Array{Bool}(undef, 1) result_dev = CuArray(result) - function kernel(input_dev, result_dev) - data = func(pointer(input_dev), 16) - result_dev[1] = all(val -> val == expected, data) + @eval @inbounds function kernel(input_dev, result_dev) + if $do_shared_test + input_shared = @cuStaticSharedMem($array_ty, 256) + fill!(input_shared, 42) + + data = $func(generic_to_shared(input_shared.ptr), 16) + + result_dev[1] = all(val -> val == $expected, data) + else + data = $func(pointer(input_dev), 16) + result_dev[1] = all(val -> val == $expected, data) + end + return end @@ -46,12 +69,10 @@ @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], layout in ["row", "col"], shape in ["m16n16k16"], - addr_space in [""], + addr_space in ["", "_global", "_shared"], stride in ["stride"], elem_type in ["f16", "f32"] - # TODO: Test address space? - # Type-dependent variables array_ty = elem_type == "f16" ? Float16 : Float32 data = elem_type == "f16" ? @@ -63,13 +84,28 @@ ) : (42, 42, 42, 42, 42, 42, 42, 42) # Get the function name - func = getfield(Main, Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)_stride_$(elem_type)")) + func = Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") + + # Address-space dependent variables + do_shared_test = (addr_space == "_shared") output = Array{array_ty}(undef, (16, 16)) output_dev = CuArray(output) - function kernel(output_dev) - func(pointer(output_dev), data, 16) + + @eval function kernel(output_dev) + if $do_shared_test + shared_mem = @cuStaticSharedMem($array_ty, 256) + ptr = generic_to_shared(pointer(shared_mem)) + $func(ptr, $data, 16) + + for i = 1:256 + @inbounds output_dev[i] = shared_mem[i] + end + else + $func(pointer(output_dev), $data, 16) + end + return end From 0b8fff8474e2d25c1ec7b66b4b4594b1f4081c01 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 23 Nov 2019 14:25:30 +0100 Subject: [PATCH 19/81] Change default shared memory alignment --- src/device/cuda/memory_shared.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/device/cuda/memory_shared.jl b/src/device/cuda/memory_shared.jl index a40f6807..6223dbd5 100644 --- a/src/device/cuda/memory_shared.jl +++ b/src/device/cuda/memory_shared.jl @@ -80,8 +80,9 @@ end initializer!(gv, null(gv_typ)) end # by requesting a larger-than-datatype alignment, we might be able to vectorize. - # we pick 16 bytes since this is the largest transaction size as supported by PTX. - alignment!(gv, Base.max(16, datatype_align(T))) + # we pick 32 bytes here, since WMMA instructions require 32-byte alignment. + # TODO: Make the alignment configurable + alignment!(gv, Base.max(32, datatype_align(T))) # generate IR Builder(JuliaContext()) do builder From 410a12b49d52fa32cedcbf203728926e3d50a4b0 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Mon, 25 Nov 2019 09:14:45 +0100 Subject: [PATCH 20/81] Change equality test --- test/device/wmma.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 11e11cb8..5c0c1a86 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -161,7 +161,7 @@ new_a = (a_layout == "col" ? a : transpose(a)) new_b = (b_layout == "col" ? b : transpose(b)) - @test new_a * new_b + c ≈ Array(d_dev) rtol=0.01 + @test all(isapprox.(new_a * new_b + c, Array(d_dev); rtol=sqrt(eps(Float16)))) end end end From ab54af06d61febf908f8b1e28967fe9f12f59ff8 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Mon, 25 Nov 2019 09:25:20 +0100 Subject: [PATCH 21/81] Change equality test --- test/device/wmma.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 5eddb922..4a40061a 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -216,9 +216,9 @@ new_d = (d_layout == wmma_col_major) ? d : transpose(d) if do_mac - @test new_a * new_b + new_c ≈ new_d rtol=0.01 + @test all(isapprox.(new_a * new_b + new_c, new_d; rtol=sqrt(eps(Float16)))) else - @test new_a * new_b ≈ new_d rtol=0.01 + @test all(isapprox.(new_a * new_b, new_d; rtol=sqrt(eps(Float16)))) end end From f44388a0a447fc2e31a48309bdb0c1b23a770d34 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 26 Nov 2019 22:10:24 +0100 Subject: [PATCH 22/81] Change load to ccall --- src/device/cuda/wmma.jl | 11 +-- test/device/wmma.jl | 200 ++++++++++++++++++++-------------------- test/runtests.jl | 18 ++-- 3 files changed, 114 insertions(+), 115 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 06d4df0a..05588152 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -81,7 +81,7 @@ for mat in ["a", "b", "c"], func_name = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) # Name of the LLVM intrinsic - llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "load", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") + llvm_intr = join_nonempty("llvm", "nvvm", "wmma", "load", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") # Determine types for this (matrix, elem_type) combination sz = get_frag_sz(mat, elem_type) @@ -106,12 +106,11 @@ for mat in ["a", "b", "c"], ret [$sz x $lc_ty] %ret.aggr.$(sz-1) ") - @eval $func_name(src_addr, stride) = Base.llvmcall($ir, - NTuple{$sz, $jl_ty}, - Tuple{Int64, Int32}, - convert(Int64, src_addr), - convert(Int32, stride)) + base_type = elem_type == "f16" ? Float16 : Float32 + + ccall_name = "extern $llvm_intr" + @eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $jl_ty}, (Ref{$base_type}, Int32), src_addr, stride) @eval export $func_name end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 5c0c1a86..c9164f49 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -20,7 +20,7 @@ @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["a", "b", "c"], layout in ["row", "col"], shape in ["m16n16k16"], - addr_space in ["", "_global", "_shared"], + addr_space in ["", "_global" #=, "_shared" =#], stride in ["stride"], elem_type in ["f16", "f32"] @@ -65,105 +65,105 @@ end end - @testset "llvm_wmma_store" begin - @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], - layout in ["row", "col"], - shape in ["m16n16k16"], - addr_space in ["", "_global", "_shared"], - stride in ["stride"], - elem_type in ["f16", "f32"] - - # Type-dependent variables - array_ty = elem_type == "f16" ? Float16 : Float32 - data = elem_type == "f16" ? - ( - (VecElement{Float16}(42), VecElement{Float16}(42)), - (VecElement{Float16}(42), VecElement{Float16}(42)), - (VecElement{Float16}(42), VecElement{Float16}(42)), - (VecElement{Float16}(42), VecElement{Float16}(42)) - ) : (42, 42, 42, 42, 42, 42, 42, 42) - - # Get the function name - func = Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") - - # Address-space dependent variables - do_shared_test = (addr_space == "_shared") - - output = Array{array_ty}(undef, (16, 16)) - output_dev = CuArray(output) - - - @eval function kernel(output_dev) - if $do_shared_test - shared_mem = @cuStaticSharedMem($array_ty, 256) - ptr = generic_to_shared(pointer(shared_mem)) - $func(ptr, $data, 16) - - for i = 1:256 - @inbounds output_dev[i] = shared_mem[i] - end - else - $func(pointer(output_dev), $data, 16) - end - - return - end - - @cuda threads=32 kernel(output_dev) - @test all(Array(output_dev) .== 42.0) - end - end - - @testset "llvm_wmma_mma" begin - @testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"], - b_layout in ["row", "col"], - shape in ["m16n16k16"], - d_elem_type in ["f16", "f32"], - c_elem_type in ["f16", "f32"] - - # Type-dependent variables - d_ty = d_elem_type == "f16" ? Float16 : Float32 - c_ty = c_elem_type == "f16" ? Float16 : Float32 - - # Get the function names - lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_m16n16k16_stride_f16")) - ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_m16n16k16_stride_f16")) - ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_m16n16k16_stride_$(c_elem_type)")) - mma_func = getfield(Main, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m16n16k16_$(d_elem_type)_$(c_elem_type)")) - std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_m16n16k16_stride_$(d_elem_type)")) - - # Generate input matrices - a = rand(Float16, (16, 16)) - a_dev = CuArray(a) - b = rand(Float16, (16, 16)) - b_dev = CuArray(b) - c = rand(c_ty, (16, 16)) - c_dev = CuArray(c) - - # Reserve space for result - d = Array{d_ty}(undef, (16, 16)) - d_dev = CuArray(d) - - # Matrix MAC kernel (D = A * B + C) - function kernel(a_dev, b_dev, c_dev, d_dev) - a_frag = lda_func(pointer(a_dev), 16) - b_frag = ldb_func(pointer(b_dev), 16) - c_frag = ldc_func(pointer(c_dev), 16) - - d_frag = mma_func(a_frag, b_frag, c_frag) - - std_func(pointer(d_dev), d_frag, 16) - return - end - - @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) - - new_a = (a_layout == "col" ? a : transpose(a)) - new_b = (b_layout == "col" ? b : transpose(b)) - - @test all(isapprox.(new_a * new_b + c, Array(d_dev); rtol=sqrt(eps(Float16)))) - end - end + #= @testset "llvm_wmma_store" begin =# + #= @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], =# + #= layout in ["row", "col"], =# + #= shape in ["m16n16k16"], =# + #= addr_space in ["", "_global", "_shared"], =# + #= stride in ["stride"], =# + #= elem_type in ["f16", "f32"] =# + + #= # Type-dependent variables =# + #= array_ty = elem_type == "f16" ? Float16 : Float32 =# + #= data = elem_type == "f16" ? =# + #= ( =# + #= (VecElement{Float16}(42), VecElement{Float16}(42)), =# + #= (VecElement{Float16}(42), VecElement{Float16}(42)), =# + #= (VecElement{Float16}(42), VecElement{Float16}(42)), =# + #= (VecElement{Float16}(42), VecElement{Float16}(42)) =# + #= ) : (42, 42, 42, 42, 42, 42, 42, 42) =# + + #= # Get the function name =# + #= func = Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") =# + + #= # Address-space dependent variables =# + #= do_shared_test = (addr_space == "_shared") =# + + #= output = Array{array_ty}(undef, (16, 16)) =# + #= output_dev = CuArray(output) =# + + + #= @eval function kernel(output_dev) =# + #= if $do_shared_test =# + #= shared_mem = @cuStaticSharedMem($array_ty, 256) =# + #= ptr = generic_to_shared(pointer(shared_mem)) =# + #= $func(ptr, $data, 16) =# + + #= for i = 1:256 =# + #= @inbounds output_dev[i] = shared_mem[i] =# + #= end =# + #= else =# + #= $func(pointer(output_dev), $data, 16) =# + #= end =# + + #= return =# + #= end =# + + #= @cuda threads=32 kernel(output_dev) =# + #= @test all(Array(output_dev) .== 42.0) =# + #= end =# + #= end =# + + #= @testset "llvm_wmma_mma" begin =# + #= @testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"], =# + #= b_layout in ["row", "col"], =# + #= shape in ["m16n16k16"], =# + #= d_elem_type in ["f16", "f32"], =# + #= c_elem_type in ["f16", "f32"] =# + + #= # Type-dependent variables =# + #= d_ty = d_elem_type == "f16" ? Float16 : Float32 =# + #= c_ty = c_elem_type == "f16" ? Float16 : Float32 =# + + #= # Get the function names =# + #= lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_m16n16k16_stride_f16")) =# + #= ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_m16n16k16_stride_f16")) =# + #= ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_m16n16k16_stride_$(c_elem_type)")) =# + #= mma_func = getfield(Main, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m16n16k16_$(d_elem_type)_$(c_elem_type)")) =# + #= std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_m16n16k16_stride_$(d_elem_type)")) =# + + #= # Generate input matrices =# + #= a = rand(Float16, (16, 16)) =# + #= a_dev = CuArray(a) =# + #= b = rand(Float16, (16, 16)) =# + #= b_dev = CuArray(b) =# + #= c = rand(c_ty, (16, 16)) =# + #= c_dev = CuArray(c) =# + + #= # Reserve space for result =# + #= d = Array{d_ty}(undef, (16, 16)) =# + #= d_dev = CuArray(d) =# + + #= # Matrix MAC kernel (D = A * B + C) =# + #= function kernel(a_dev, b_dev, c_dev, d_dev) =# + #= a_frag = lda_func(pointer(a_dev), 16) =# + #= b_frag = ldb_func(pointer(b_dev), 16) =# + #= c_frag = ldc_func(pointer(c_dev), 16) =# + + #= d_frag = mma_func(a_frag, b_frag, c_frag) =# + + #= std_func(pointer(d_dev), d_frag, 16) =# + #= return =# + #= end =# + + #= @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) =# + + #= new_a = (a_layout == "col" ? a : transpose(a)) =# + #= new_b = (b_layout == "col" ? b : transpose(b)) =# + + #= @test all(isapprox.(new_a * new_b + c, Array(d_dev); rtol=sqrt(eps(Float16)))) =# + #= end =# + #= end =# end ################################################################################ diff --git a/test/runtests.jl b/test/runtests.jl index 18f933b9..1d70f530 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,9 +55,9 @@ if length(devices()) > 0 end cap = CUDAnative.current_capability() -include("base.jl") -include("pointer.jl") -include("codegen.jl") +#= include("base.jl") =# +#= include("pointer.jl") =# +#= include("codegen.jl") =# if dev === nothing @warn("No CUDA-capable devices available; skipping on-device tests.") @@ -65,14 +65,14 @@ else if capability(dev) < v"2.0" @warn("native execution not supported on SM < 2.0") else - include("device/codegen.jl") - include("device/execution.jl") - include("device/pointer.jl") - include("device/array.jl") - include("device/cuda.jl") + #= include("device/codegen.jl") =# + #= include("device/execution.jl") =# + #= include("device/pointer.jl") =# + #= include("device/array.jl") =# + #= include("device/cuda.jl") =# include("device/wmma.jl") - include("examples.jl") + #= include("examples.jl") =# end end From b1865dd0c0ddfddd9803a367745709f3d94cddef Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 26 Nov 2019 23:22:44 +0100 Subject: [PATCH 23/81] Use ccall for store --- src/device/cuda/wmma.jl | 16 +++---- test/device/wmma.jl | 92 ++++++++++++++++++++--------------------- 2 files changed, 55 insertions(+), 53 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 05588152..4c02b3d1 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -131,7 +131,7 @@ for mat in ["d"], func_name = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) # Name of the LLVM intrinsic - llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "store", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") + llvm_intr = join_nonempty("llvm", "nvvm", "wmma", "store", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") # Determine types for this (matrix, elem_type) combination sz = get_frag_sz(mat, elem_type) @@ -152,12 +152,14 @@ for mat in ["d"], ret void ") - @eval $func_name(dst_addr, data, stride) = Base.llvmcall($ir, - Nothing, - Tuple{Int64, NTuple{$sz, $jl_ty}, Int32}, - convert(Int64, dst_addr), - convert(NTuple{$sz, $jl_ty}, data), - convert(Int32, stride)) + ccall_name = "extern $llvm_intr" + base_type = elem_type == "f16" ? Float16 : Float32 + + if sz == 4 + @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$base_type}, $jl_ty, $jl_ty, $jl_ty, $jl_ty, Int32), dst_addr, data[1], data[2], data[3], data[4], stride) + elseif sz == 8 + @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$base_type}, $jl_ty, $jl_ty, $jl_ty, $jl_ty, $jl_ty, $jl_ty, $jl_ty, $jl_ty, Int32), dst_addr, data[1], data[2], data[3], data[4], data[5], data[6], data[7], data[8], stride) + end @eval export $func_name end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index c9164f49..0fb19ea2 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -20,7 +20,7 @@ @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["a", "b", "c"], layout in ["row", "col"], shape in ["m16n16k16"], - addr_space in ["", "_global" #=, "_shared" =#], + addr_space in ["", "_global"], stride in ["stride"], elem_type in ["f16", "f32"] @@ -65,54 +65,54 @@ end end - #= @testset "llvm_wmma_store" begin =# - #= @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], =# - #= layout in ["row", "col"], =# - #= shape in ["m16n16k16"], =# - #= addr_space in ["", "_global", "_shared"], =# - #= stride in ["stride"], =# - #= elem_type in ["f16", "f32"] =# + @testset "llvm_wmma_store" begin + @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], + layout in ["row", "col"], + shape in ["m16n16k16"], + addr_space in ["", "_global"], + stride in ["stride"], + elem_type in ["f16", "f32"] - #= # Type-dependent variables =# - #= array_ty = elem_type == "f16" ? Float16 : Float32 =# - #= data = elem_type == "f16" ? =# - #= ( =# - #= (VecElement{Float16}(42), VecElement{Float16}(42)), =# - #= (VecElement{Float16}(42), VecElement{Float16}(42)), =# - #= (VecElement{Float16}(42), VecElement{Float16}(42)), =# - #= (VecElement{Float16}(42), VecElement{Float16}(42)) =# - #= ) : (42, 42, 42, 42, 42, 42, 42, 42) =# - - #= # Get the function name =# - #= func = Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") =# - - #= # Address-space dependent variables =# - #= do_shared_test = (addr_space == "_shared") =# - - #= output = Array{array_ty}(undef, (16, 16)) =# - #= output_dev = CuArray(output) =# - - - #= @eval function kernel(output_dev) =# - #= if $do_shared_test =# - #= shared_mem = @cuStaticSharedMem($array_ty, 256) =# - #= ptr = generic_to_shared(pointer(shared_mem)) =# - #= $func(ptr, $data, 16) =# - - #= for i = 1:256 =# - #= @inbounds output_dev[i] = shared_mem[i] =# - #= end =# - #= else =# - #= $func(pointer(output_dev), $data, 16) =# - #= end =# + # Type-dependent variables + array_ty = elem_type == "f16" ? Float16 : Float32 + data = elem_type == "f16" ? + ( + (VecElement{Float16}(42), VecElement{Float16}(42)), + (VecElement{Float16}(42), VecElement{Float16}(42)), + (VecElement{Float16}(42), VecElement{Float16}(42)), + (VecElement{Float16}(42), VecElement{Float16}(42)) + ) : (42, 42, 42, 42, 42, 42, 42, 42) - #= return =# - #= end =# + # Get the function name + func = Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") - #= @cuda threads=32 kernel(output_dev) =# - #= @test all(Array(output_dev) .== 42.0) =# - #= end =# - #= end =# + # Address-space dependent variables + do_shared_test = (addr_space == "_shared") + + output = Array{array_ty}(undef, (16, 16)) + output_dev = CuArray(output) + + + @eval function kernel(output_dev) + if $do_shared_test + shared_mem = @cuStaticSharedMem($array_ty, 256) + ptr = generic_to_shared(pointer(shared_mem)) + $func(ptr, $data, 16) + + for i = 1:256 + @inbounds output_dev[i] = shared_mem[i] + end + else + $func(pointer(output_dev), $data, 16) + end + + return + end + + @cuda threads=32 kernel(output_dev) + @test all(Array(output_dev) .== 42.0) + end + end #= @testset "llvm_wmma_mma" begin =# #= @testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"], =# From 865aac5a067942949f206dbce3fe9a54b37f1ae5 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 27 Nov 2019 10:01:55 +0100 Subject: [PATCH 24/81] Cleanup store --- src/device/cuda/wmma.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 4c02b3d1..51c09a41 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -154,12 +154,10 @@ for mat in ["d"], ccall_name = "extern $llvm_intr" base_type = elem_type == "f16" ? Float16 : Float32 + frag_types = ntuple(i -> jl_ty, sz) + frag_vars = ntuple(i -> :(data[$i]), sz) - if sz == 4 - @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$base_type}, $jl_ty, $jl_ty, $jl_ty, $jl_ty, Int32), dst_addr, data[1], data[2], data[3], data[4], stride) - elseif sz == 8 - @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$base_type}, $jl_ty, $jl_ty, $jl_ty, $jl_ty, $jl_ty, $jl_ty, $jl_ty, $jl_ty, Int32), dst_addr, data[1], data[2], data[3], data[4], data[5], data[6], data[7], data[8], stride) - end + @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$base_type}, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) @eval export $func_name end From 4e2cb3c8cae2d4b8c9db822123d824d8810f33ea Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 27 Nov 2019 10:10:53 +0100 Subject: [PATCH 25/81] Fix wmma --- src/device/cuda/wmma.jl | 27 ++++++++--- test/device/wmma.jl | 100 ++++++++++++++++++++-------------------- 2 files changed, 70 insertions(+), 57 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 51c09a41..a079da16 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -178,7 +178,7 @@ for a_layout in ["col", "row"], func_name = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) # Name of the LLVM intrinsic - llvm_intr = join_nonempty("@llvm", "nvvm", "wmma", "mma", "sync", a_layout, b_layout, shape, d_elem_type, c_elem_type, ".") + llvm_intr = join_nonempty("llvm", "nvvm", "wmma", "mma", "sync", a_layout, b_layout, shape, d_elem_type, c_elem_type, ".") # Determine types for the (matrix, elem_type) combinations for matrix A a_sz = get_frag_sz("a", a_elem_type) @@ -234,12 +234,25 @@ for a_layout in ["col", "row"], ret [$d_sz x $d_lc_ty] %d.aggr.$(d_sz-1) ") - @eval $func_name(a, b, c) = Base.llvmcall($ir, - NTuple{$d_sz, $d_jl_ty}, - Tuple{NTuple{$a_sz, $a_jl_ty}, NTuple{$b_sz, $b_jl_ty}, NTuple{$c_sz, $c_jl_ty}}, - convert(NTuple{$a_sz, $a_jl_ty}, a), - convert(NTuple{$b_sz, $b_jl_ty}, b), - convert(NTuple{$c_sz, $c_jl_ty}, c)) + #= @eval $func_name(a, b, c) = Base.llvmcall($ir, =# + #= NTuple{$d_sz, $d_jl_ty}, =# + #= Tuple{NTuple{$a_sz, $a_jl_ty}, NTuple{$b_sz, $b_jl_ty}, NTuple{$c_sz, $c_jl_ty}}, =# + #= convert(NTuple{$a_sz, $a_jl_ty}, a), =# + #= convert(NTuple{$b_sz, $b_jl_ty}, b), =# + #= convert(NTuple{$c_sz, $c_jl_ty}, c)) =# + + ccall_name = "extern $llvm_intr" + + a_types = ntuple(i -> a_jl_ty, a_sz) + b_types = ntuple(i -> b_jl_ty, b_sz) + c_types = ntuple(i -> c_jl_ty, c_sz) + + a_vars = ntuple(i -> :(a[$i]), a_sz) + b_vars = ntuple(i -> :(b[$i]), b_sz) + c_vars = ntuple(i -> :(c[$i]), c_sz) + + + @eval $func_name(a, b, c) = ccall($ccall_name, llvmcall, NTuple{$d_sz, $d_jl_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)) @eval export $func_name end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 0fb19ea2..37b94627 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -114,56 +114,56 @@ end end - #= @testset "llvm_wmma_mma" begin =# - #= @testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"], =# - #= b_layout in ["row", "col"], =# - #= shape in ["m16n16k16"], =# - #= d_elem_type in ["f16", "f32"], =# - #= c_elem_type in ["f16", "f32"] =# - - #= # Type-dependent variables =# - #= d_ty = d_elem_type == "f16" ? Float16 : Float32 =# - #= c_ty = c_elem_type == "f16" ? Float16 : Float32 =# - - #= # Get the function names =# - #= lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_m16n16k16_stride_f16")) =# - #= ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_m16n16k16_stride_f16")) =# - #= ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_m16n16k16_stride_$(c_elem_type)")) =# - #= mma_func = getfield(Main, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m16n16k16_$(d_elem_type)_$(c_elem_type)")) =# - #= std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_m16n16k16_stride_$(d_elem_type)")) =# - - #= # Generate input matrices =# - #= a = rand(Float16, (16, 16)) =# - #= a_dev = CuArray(a) =# - #= b = rand(Float16, (16, 16)) =# - #= b_dev = CuArray(b) =# - #= c = rand(c_ty, (16, 16)) =# - #= c_dev = CuArray(c) =# - - #= # Reserve space for result =# - #= d = Array{d_ty}(undef, (16, 16)) =# - #= d_dev = CuArray(d) =# - - #= # Matrix MAC kernel (D = A * B + C) =# - #= function kernel(a_dev, b_dev, c_dev, d_dev) =# - #= a_frag = lda_func(pointer(a_dev), 16) =# - #= b_frag = ldb_func(pointer(b_dev), 16) =# - #= c_frag = ldc_func(pointer(c_dev), 16) =# - - #= d_frag = mma_func(a_frag, b_frag, c_frag) =# - - #= std_func(pointer(d_dev), d_frag, 16) =# - #= return =# - #= end =# - - #= @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) =# - - #= new_a = (a_layout == "col" ? a : transpose(a)) =# - #= new_b = (b_layout == "col" ? b : transpose(b)) =# - - #= @test all(isapprox.(new_a * new_b + c, Array(d_dev); rtol=sqrt(eps(Float16)))) =# - #= end =# - #= end =# + @testset "llvm_wmma_mma" begin + @testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"], + b_layout in ["row", "col"], + shape in ["m16n16k16"], + d_elem_type in ["f16", "f32"], + c_elem_type in ["f16", "f32"] + + # Type-dependent variables + d_ty = d_elem_type == "f16" ? Float16 : Float32 + c_ty = c_elem_type == "f16" ? Float16 : Float32 + + # Get the function names + lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_m16n16k16_stride_f16")) + ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_m16n16k16_stride_f16")) + ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_m16n16k16_stride_$(c_elem_type)")) + mma_func = getfield(Main, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m16n16k16_$(d_elem_type)_$(c_elem_type)")) + std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_m16n16k16_stride_$(d_elem_type)")) + + # Generate input matrices + a = rand(Float16, (16, 16)) + a_dev = CuArray(a) + b = rand(Float16, (16, 16)) + b_dev = CuArray(b) + c = rand(c_ty, (16, 16)) + c_dev = CuArray(c) + + # Reserve space for result + d = Array{d_ty}(undef, (16, 16)) + d_dev = CuArray(d) + + # Matrix MAC kernel (D = A * B + C) + function kernel(a_dev, b_dev, c_dev, d_dev) + a_frag = lda_func(pointer(a_dev), 16) + b_frag = ldb_func(pointer(b_dev), 16) + c_frag = ldc_func(pointer(c_dev), 16) + + d_frag = mma_func(a_frag, b_frag, c_frag) + + std_func(pointer(d_dev), d_frag, 16) + return + end + + @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) + + new_a = (a_layout == "col" ? a : transpose(a)) + new_b = (b_layout == "col" ? b : transpose(b)) + + @test all(isapprox.(new_a * new_b + c, Array(d_dev); rtol=sqrt(eps(Float16)))) + end + end end ################################################################################ From c85e4d5c633d5e4379c86c4fa71086224361263c Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 27 Nov 2019 10:22:49 +0100 Subject: [PATCH 26/81] Fix shared tests --- test/device/wmma.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 37b94627..f7b75e6c 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -2,13 +2,13 @@ ################################################################################ -@eval generic_to_shared(ptr) = Base.llvmcall( +@eval generic_to_shared(ptr::CUDAnative.DevicePtr{T, AS.Shared}) where T = Base.llvmcall( " %ptr.generic = inttoptr i64 %0 to i8* %ptr.shared = addrspacecast i8* %ptr.generic to i8 addrspace(3)* %ret = ptrtoint i8 addrspace(3)* %ptr.shared to i64 ret i64 %ret", - Int64, + CUDAnative.DevicePtr{T, AS.Shared}, Tuple{Int64}, convert(Int64, ptr)) @@ -20,7 +20,7 @@ @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["a", "b", "c"], layout in ["row", "col"], shape in ["m16n16k16"], - addr_space in ["", "_global"], + addr_space in ["", "_global", "_shared"], stride in ["stride"], elem_type in ["f16", "f32"] @@ -69,7 +69,7 @@ @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], layout in ["row", "col"], shape in ["m16n16k16"], - addr_space in ["", "_global"], + addr_space in ["", "_global", "_shared"], stride in ["stride"], elem_type in ["f16", "f32"] From 84f43f0439f1cf0abfef4a0797f0282351a00b77 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 27 Nov 2019 11:58:53 +0100 Subject: [PATCH 27/81] Clean up wrappers --- src/device/cuda/wmma.jl | 172 +++++++--------------------------------- 1 file changed, 29 insertions(+), 143 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index a079da16..e056a5d4 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -2,23 +2,17 @@ # CONSTANTS ################################################################################ -# Maps PTX types to LLVM types -map_ptx_to_llvm = Dict( - "f16" => "<2 x half>", - "f32" => "float" - ) - -# Maps PTX types to the LLVM type that llvmcall expects -map_ptx_to_llvmcall = Dict( - "f16" => "<2 x i16>", - "f32" => "float" - ) - -# Maps PTX types to Julia types -map_ptx_to_jl = Dict( - "f16" => NTuple{2, VecElement{Float16}}, - "f32" => Float32 - ) +# Maps PTX types to Julia array types +map_ptx_to_jl_array = Dict( + "f16" => Float16, + "f32" => Float32 + ) + +# Maps PTX types to Julia fragment types +map_ptx_to_jl_frag = Dict( + "f16" => NTuple{2, VecElement{Float16}}, + "f32" => Float32 + ) # Maps matrix & PTX types to fragment sizes map_frag_sizes = Dict( @@ -34,12 +28,6 @@ map_frag_sizes = Dict( # HELPER FUNCTIONS ################################################################################ -macro gen_ir(template, count, delim="\n") - return quote - join([$(esc(template)) for $(esc(:i)) in 0:$(esc(count))-1], $(esc(delim))) - end -end - function join_nonempty(args...) delim = args[end] arr = [args[1:end-1]...] @@ -47,13 +35,8 @@ function join_nonempty(args...) return join(arr[arr .!= ""], delim) end -get_llvm_ty(matrix, ptx_el_type) = map_ptx_to_llvm[ptx_el_type] - -get_llvmcall_ty(matrix, ptx_el_type) = map_ptx_to_llvmcall[ptx_el_type] - -get_jl_ty(matrix, ptx_el_type) = map_ptx_to_jl[ptx_el_type] - -get_frag_sz(matrix, ptx_el_type) = map_frag_sizes["$matrix.$ptx_el_type"] +# Returns (Julia array type, Julia fragment type, fragment size) +get_frag_info(matrix, ptx_el_type) = (map_ptx_to_jl_array[ptx_el_type], map_ptx_to_jl_frag[elem_type], map_frag_sizes["$matrix.$ptx_el_type"]) ################################################################################ # LOW LEVEL API @@ -83,34 +66,12 @@ for mat in ["a", "b", "c"], # Name of the LLVM intrinsic llvm_intr = join_nonempty("llvm", "nvvm", "wmma", "load", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") - # Determine types for this (matrix, elem_type) combination - sz = get_frag_sz(mat, elem_type) - llvm_ty = get_llvm_ty(mat, elem_type) - struct_ty = "{ $(@gen_ir(llvm_ty, sz, ", ")) }" - lc_ty = get_llvmcall_ty(mat, elem_type) - jl_ty = get_jl_ty(mat, elem_type) - - # Generate LLVM IR - ir = ("declare $struct_ty $llvm_intr(i8*, i32)", - " - %src_ptr = inttoptr i64 %0 to i8* - - %ret.llvm = call $struct_ty $llvm_intr(i8* %src_ptr, i32 %1) - - $(@gen_ir("%ret.llvm.$i = extractvalue $struct_ty %ret.llvm, $i", sz)) - - $(@gen_ir("%ret.jl.$i = bitcast $llvm_ty %ret.llvm.$i to $lc_ty", sz)) - - $(@gen_ir("%ret.aggr.$i = insertvalue [$sz x $lc_ty] $(i == 0 ? "undef" : "%ret.aggr.$(i-1)"), $lc_ty %ret.jl.$i, $i", sz)) - - ret [$sz x $lc_ty] %ret.aggr.$(sz-1) - ") - - base_type = elem_type == "f16" ? Float16 : Float32 + # Determine types + size for this (matrix, elem_type) combination + arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) ccall_name = "extern $llvm_intr" - @eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $jl_ty}, (Ref{$base_type}, Int32), src_addr, stride) + @eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $frag_ty}, (Ref{$arr_ty}, Int32), src_addr, stride) @eval export $func_name end @@ -133,32 +94,14 @@ for mat in ["d"], # Name of the LLVM intrinsic llvm_intr = join_nonempty("llvm", "nvvm", "wmma", "store", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") - # Determine types for this (matrix, elem_type) combination - sz = get_frag_sz(mat, elem_type) - llvm_ty = get_llvm_ty(mat, elem_type) - lc_ty = get_llvmcall_ty(mat, elem_type) - jl_ty = get_jl_ty(mat, elem_type) - - # Generate LLVM IR - ir = ("declare void $llvm_intr(i8*, $(@gen_ir("$llvm_ty", sz, ", ")), i32)", - " - %dst_ptr = inttoptr i64 %0 to i8* - - $(@gen_ir("%data.jl.$i = extractvalue [$sz x $lc_ty] %1, $i", sz)) - - $(@gen_ir("%data.llvm.$i = bitcast $lc_ty %data.jl.$i to $llvm_ty", sz)) - - call void $llvm_intr(i8* %dst_ptr, $(@gen_ir("$llvm_ty %data.llvm.$i", sz, ", ")) , i32 %2) - ret void - ") + # Determine types + size for this (matrix, elem_type) combination + arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) ccall_name = "extern $llvm_intr" - base_type = elem_type == "f16" ? Float16 : Float32 - frag_types = ntuple(i -> jl_ty, sz) + frag_types = ntuple(i -> frag_ty, sz) frag_vars = ntuple(i -> :(data[$i]), sz) - @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$base_type}, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) - + @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$arr_ty}, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) @eval export $func_name end @@ -180,79 +123,22 @@ for a_layout in ["col", "row"], # Name of the LLVM intrinsic llvm_intr = join_nonempty("llvm", "nvvm", "wmma", "mma", "sync", a_layout, b_layout, shape, d_elem_type, c_elem_type, ".") - # Determine types for the (matrix, elem_type) combinations for matrix A - a_sz = get_frag_sz("a", a_elem_type) - a_llvm_ty = get_llvm_ty("a", a_elem_type) - a_lc_ty = get_llvmcall_ty("a", a_elem_type) - a_jl_ty = get_jl_ty("a", a_elem_type) - - # Determine types for the (matrix, elem_type) combinations for matrix B - b_sz = get_frag_sz("b", b_elem_type) - b_llvm_ty = get_llvm_ty("b", b_elem_type) - b_lc_ty = get_llvmcall_ty("b", b_elem_type) - b_jl_ty = get_jl_ty("b", b_elem_type) - - # Determine types for the (matrix, elem_type) combinations for matrix C - c_sz = get_frag_sz("c", c_elem_type) - c_llvm_ty = get_llvm_ty("c", c_elem_type) - c_lc_ty = get_llvmcall_ty("c", c_elem_type) - c_jl_ty = get_jl_ty("c", c_elem_type) - - # Determine types for the (matrix, elem_type) combinations for matrix D - d_sz = get_frag_sz("d", d_elem_type) - d_llvm_ty = get_llvm_ty("d", d_elem_type) - d_lc_ty = get_llvmcall_ty("d", d_elem_type) - d_jl_ty = get_jl_ty("d", d_elem_type) - d_struct_ty = "{ $(@gen_ir(d_llvm_ty, d_sz, ", ")) }" - - # Create the argument string to the IR call - args = join([ - @gen_ir("$a_llvm_ty %a.llvm.$i", a_sz, ", "), - @gen_ir("$b_llvm_ty %b.llvm.$i", b_sz, ", "), - @gen_ir("$c_llvm_ty %c.llvm.$i", c_sz, ", ")] - , ", ") - - # Generate LLVM IR - ir = ("declare $d_struct_ty $llvm_intr($args)", - " - $(@gen_ir("%a.jl.$i = extractvalue [$a_sz x $a_lc_ty] %0, $i", a_sz)) - $(@gen_ir("%b.jl.$i = extractvalue [$b_sz x $b_lc_ty] %1, $i", b_sz)) - $(@gen_ir("%c.jl.$i = extractvalue [$c_sz x $c_lc_ty] %2, $i", c_sz)) - - $(@gen_ir("%a.llvm.$i = bitcast $a_lc_ty %a.jl.$i to $a_llvm_ty", a_sz)) - $(@gen_ir("%b.llvm.$i = bitcast $b_lc_ty %b.jl.$i to $b_llvm_ty", b_sz)) - $(@gen_ir("%c.llvm.$i = bitcast $c_lc_ty %c.jl.$i to $c_llvm_ty", c_sz)) - - %d.llvm = call $d_struct_ty $llvm_intr($args) - - $(@gen_ir("%d.llvm.$i = extractvalue $d_struct_ty %d.llvm, $i", d_sz)) - - $(@gen_ir("%d.jl.$i = bitcast $d_llvm_ty %d.llvm.$i to $d_lc_ty", d_sz)) - - $(@gen_ir("%d.aggr.$i = insertvalue [$d_sz x $d_lc_ty] $(i == 0 ? "undef" : "%d.aggr.$(i-1)"), $d_lc_ty %d.jl.$i, $i", d_sz)) - - ret [$d_sz x $d_lc_ty] %d.aggr.$(d_sz-1) - ") - - #= @eval $func_name(a, b, c) = Base.llvmcall($ir, =# - #= NTuple{$d_sz, $d_jl_ty}, =# - #= Tuple{NTuple{$a_sz, $a_jl_ty}, NTuple{$b_sz, $b_jl_ty}, NTuple{$c_sz, $c_jl_ty}}, =# - #= convert(NTuple{$a_sz, $a_jl_ty}, a), =# - #= convert(NTuple{$b_sz, $b_jl_ty}, b), =# - #= convert(NTuple{$c_sz, $c_jl_ty}, c)) =# + # Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D + a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type) + b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type) + c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type) + d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type) ccall_name = "extern $llvm_intr" - a_types = ntuple(i -> a_jl_ty, a_sz) - b_types = ntuple(i -> b_jl_ty, b_sz) - c_types = ntuple(i -> c_jl_ty, c_sz) + a_types = ntuple(i -> a_frag_ty, a_sz) + b_types = ntuple(i -> b_frag_ty, b_sz) + c_types = ntuple(i -> c_frag_ty, c_sz) a_vars = ntuple(i -> :(a[$i]), a_sz) b_vars = ntuple(i -> :(b[$i]), b_sz) c_vars = ntuple(i -> :(c[$i]), c_sz) - - @eval $func_name(a, b, c) = ccall($ccall_name, llvmcall, NTuple{$d_sz, $d_jl_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)) - + @eval $func_name(a, b, c) = ccall($ccall_name, llvmcall, NTuple{$d_sz, $d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)) @eval export $func_name end From 6b6e65b2390bfbad089fdf3569e1fcf4e3b9e7a9 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 27 Nov 2019 11:59:31 +0100 Subject: [PATCH 28/81] Fix indenting --- src/device/cuda/wmma.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index e056a5d4..01e230d2 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -36,7 +36,11 @@ function join_nonempty(args...) end # Returns (Julia array type, Julia fragment type, fragment size) -get_frag_info(matrix, ptx_el_type) = (map_ptx_to_jl_array[ptx_el_type], map_ptx_to_jl_frag[elem_type], map_frag_sizes["$matrix.$ptx_el_type"]) +get_frag_info(matrix, ptx_el_type) = ( + map_ptx_to_jl_array[ptx_el_type], + map_ptx_to_jl_frag[elem_type], + map_frag_sizes["$matrix.$ptx_el_type"] + ) ################################################################################ # LOW LEVEL API From 2455e3b80627cf92ec8127f4ec8cc891f3a7cf81 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 27 Nov 2019 12:01:30 +0100 Subject: [PATCH 29/81] Fix typo --- src/device/cuda/wmma.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 01e230d2..71fc52a4 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -38,7 +38,7 @@ end # Returns (Julia array type, Julia fragment type, fragment size) get_frag_info(matrix, ptx_el_type) = ( map_ptx_to_jl_array[ptx_el_type], - map_ptx_to_jl_frag[elem_type], + map_ptx_to_jl_frag[ptx_elem_type], map_frag_sizes["$matrix.$ptx_el_type"] ) From 4d39b9ba6545250f04e1979bed9f8f18e4321f67 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 27 Nov 2019 12:58:49 +0100 Subject: [PATCH 30/81] Fix typo --- src/device/cuda/wmma.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 71fc52a4..25bb15db 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -38,7 +38,7 @@ end # Returns (Julia array type, Julia fragment type, fragment size) get_frag_info(matrix, ptx_el_type) = ( map_ptx_to_jl_array[ptx_el_type], - map_ptx_to_jl_frag[ptx_elem_type], + map_ptx_to_jl_frag[ptx_el_type], map_frag_sizes["$matrix.$ptx_el_type"] ) From 579a060dd5a0ad726911e4b6e9f98f66f7c75d4b Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 27 Nov 2019 13:29:02 +0100 Subject: [PATCH 31/81] Cleanup addrspacecast --- src/device/pointer.jl | 29 +++++++++++++++++++++++++++++ test/device/wmma.jl | 22 +++++++++++----------- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/device/pointer.jl b/src/device/pointer.jl index 62bd24df..df924b88 100644 --- a/src/device/pointer.jl +++ b/src/device/pointer.jl @@ -276,3 +276,32 @@ end @inline unsafe_cached_load(p::DevicePtr{T,AS.Global}, i::Integer=1, args...) where {T} = recurse_pointer_invocation(unsafe_cached_load, p+sizeof(T)*Int(i-one(i)), CachedLoadPointers, 1, args...) + +export addrspacecast + +@generated function addrspacecast(p::DevicePtr{T, AS}) where {T, AS} + # types + eltyp = convert(LLVMType, T) + T_ptr = convert(LLVMType, DevicePtr{T, AS}) + + T_actual_ptr = LLVM.PointerType(eltyp) + T_actual_ptr_as = LLVM.PointerType(eltyp, convert(Int, AS)) + + # create function + param_types = [T_ptr] + llvm_f, _ = create_function(T_ptr, param_types) + + # generate LLVM IR + Builder(JuliaContext()) do builder + entry = BasicBlock(llvm_f, "entry", JuliaContext()) + position!(builder, entry) + + ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr) + ptr_with_as = addrspacecast!(builder, ptr, T_actual_ptr_as) + ret = ptrtoint!(builder, ptr_with_as, LLVM.Int64Type(JuliaContext())) + + ret!(builder, ret) + end + + call_function(llvm_f, DevicePtr{T, AS}, Tuple{DevicePtr{T, AS}}, :((p,))) +end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index f7b75e6c..41294437 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -2,15 +2,15 @@ ################################################################################ -@eval generic_to_shared(ptr::CUDAnative.DevicePtr{T, AS.Shared}) where T = Base.llvmcall( - " - %ptr.generic = inttoptr i64 %0 to i8* - %ptr.shared = addrspacecast i8* %ptr.generic to i8 addrspace(3)* - %ret = ptrtoint i8 addrspace(3)* %ptr.shared to i64 - ret i64 %ret", - CUDAnative.DevicePtr{T, AS.Shared}, - Tuple{Int64}, - convert(Int64, ptr)) +#= @eval generic_to_shared(ptr::CUDAnative.DevicePtr{T, AS.Shared}) where T = Base.llvmcall( =# +#= " =# +#= %ptr.generic = inttoptr i64 %0 to i8* =# +#= %ptr.shared = addrspacecast i8* %ptr.generic to i8 addrspace(3)* =# +#= %ret = ptrtoint i8 addrspace(3)* %ptr.shared to i64 =# +#= ret i64 %ret", =# +#= CUDAnative.DevicePtr{T, AS.Shared}, =# +#= Tuple{Int64}, =# +#= convert(Int64, ptr)) =# ################################################################################ @@ -49,7 +49,7 @@ input_shared = @cuStaticSharedMem($array_ty, 256) fill!(input_shared, 42) - data = $func(generic_to_shared(input_shared.ptr), 16) + data = $func(addrspacecast(input_shared.ptr), 16) result_dev[1] = all(val -> val == $expected, data) else @@ -96,7 +96,7 @@ @eval function kernel(output_dev) if $do_shared_test shared_mem = @cuStaticSharedMem($array_ty, 256) - ptr = generic_to_shared(pointer(shared_mem)) + ptr = addrspacecast(pointer(shared_mem)) $func(ptr, $data, 16) for i = 1:256 From 56063686c7db1c595333125ba90d0c8306bd0ce9 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 27 Nov 2019 13:29:44 +0100 Subject: [PATCH 32/81] Re-enable tests --- test/runtests.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1d70f530..18f933b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,9 +55,9 @@ if length(devices()) > 0 end cap = CUDAnative.current_capability() -#= include("base.jl") =# -#= include("pointer.jl") =# -#= include("codegen.jl") =# +include("base.jl") +include("pointer.jl") +include("codegen.jl") if dev === nothing @warn("No CUDA-capable devices available; skipping on-device tests.") @@ -65,14 +65,14 @@ else if capability(dev) < v"2.0" @warn("native execution not supported on SM < 2.0") else - #= include("device/codegen.jl") =# - #= include("device/execution.jl") =# - #= include("device/pointer.jl") =# - #= include("device/array.jl") =# - #= include("device/cuda.jl") =# + include("device/codegen.jl") + include("device/execution.jl") + include("device/pointer.jl") + include("device/array.jl") + include("device/cuda.jl") include("device/wmma.jl") - #= include("examples.jl") =# + include("examples.jl") end end From fc108a615f7f3563a6c1a0c18c12015647a97343 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Thu, 28 Nov 2019 01:13:56 +0100 Subject: [PATCH 33/81] Fix intrinsics for LLVM 8 --- src/device/cuda/wmma.jl | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 25bb15db..cde2b621 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -23,6 +23,12 @@ map_frag_sizes = Dict( "d.f16" => 4, "d.f32" => 8 ) +# Maps PTX AS to Int +map_ptx_as_to_int = Dict( + "" => 0, + "shared" => 3, + "global" => 1 + ) ################################################################################ # HELPER FUNCTIONS @@ -42,6 +48,8 @@ get_frag_info(matrix, ptx_el_type) = ( map_frag_sizes["$matrix.$ptx_el_type"] ) +get_addrspace_info(addr_space) = map_ptx_as_to_int[addr_space] + ################################################################################ # LOW LEVEL API ################################################################################ @@ -64,11 +72,13 @@ for mat in ["a", "b", "c"], continue end + addr_space_int = get_addrspace_info(addr_space) + # Name of the Julia wrapper function func_name = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) # Name of the LLVM intrinsic - llvm_intr = join_nonempty("llvm", "nvvm", "wmma", "load", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") + llvm_intr = "llvm.nvvm.wmma.$shape.load.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8" # Determine types + size for this (matrix, elem_type) combination arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) @@ -92,11 +102,13 @@ for mat in ["d"], # TODO: Non-stride versions? + addr_space_int = get_addrspace_info(addr_space) + # Name of the Julia wrapper function func_name = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) # Name of the LLVM intrinsic - llvm_intr = join_nonempty("llvm", "nvvm", "wmma", "store", mat, "sync", layout, shape, addr_space, stride, elem_type, ".") + llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8" # Determine types + size for this (matrix, elem_type) combination arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) @@ -125,7 +137,7 @@ for a_layout in ["col", "row"], func_name = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) # Name of the LLVM intrinsic - llvm_intr = join_nonempty("llvm", "nvvm", "wmma", "mma", "sync", a_layout, b_layout, shape, d_elem_type, c_elem_type, ".") + llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type" # Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type) From 3c16a2039dff11d3b31276f833539bb866bc4654 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Thu, 28 Nov 2019 12:05:57 +0100 Subject: [PATCH 34/81] Fix shared tests --- src/device/pointer.jl | 29 ----------------------------- test/device/wmma.jl | 16 ++-------------- 2 files changed, 2 insertions(+), 43 deletions(-) diff --git a/src/device/pointer.jl b/src/device/pointer.jl index df924b88..62bd24df 100644 --- a/src/device/pointer.jl +++ b/src/device/pointer.jl @@ -276,32 +276,3 @@ end @inline unsafe_cached_load(p::DevicePtr{T,AS.Global}, i::Integer=1, args...) where {T} = recurse_pointer_invocation(unsafe_cached_load, p+sizeof(T)*Int(i-one(i)), CachedLoadPointers, 1, args...) - -export addrspacecast - -@generated function addrspacecast(p::DevicePtr{T, AS}) where {T, AS} - # types - eltyp = convert(LLVMType, T) - T_ptr = convert(LLVMType, DevicePtr{T, AS}) - - T_actual_ptr = LLVM.PointerType(eltyp) - T_actual_ptr_as = LLVM.PointerType(eltyp, convert(Int, AS)) - - # create function - param_types = [T_ptr] - llvm_f, _ = create_function(T_ptr, param_types) - - # generate LLVM IR - Builder(JuliaContext()) do builder - entry = BasicBlock(llvm_f, "entry", JuliaContext()) - position!(builder, entry) - - ptr = inttoptr!(builder, parameters(llvm_f)[1], T_actual_ptr) - ptr_with_as = addrspacecast!(builder, ptr, T_actual_ptr_as) - ret = ptrtoint!(builder, ptr_with_as, LLVM.Int64Type(JuliaContext())) - - ret!(builder, ret) - end - - call_function(llvm_f, DevicePtr{T, AS}, Tuple{DevicePtr{T, AS}}, :((p,))) -end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 41294437..d57e8ac0 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -1,17 +1,5 @@ @testset "WMMA" begin -################################################################################ - -#= @eval generic_to_shared(ptr::CUDAnative.DevicePtr{T, AS.Shared}) where T = Base.llvmcall( =# -#= " =# -#= %ptr.generic = inttoptr i64 %0 to i8* =# -#= %ptr.shared = addrspacecast i8* %ptr.generic to i8 addrspace(3)* =# -#= %ret = ptrtoint i8 addrspace(3)* %ptr.shared to i64 =# -#= ret i64 %ret", =# -#= CUDAnative.DevicePtr{T, AS.Shared}, =# -#= Tuple{Int64}, =# -#= convert(Int64, ptr)) =# - ################################################################################ @testset "LLVM intrinsics" begin @@ -49,7 +37,7 @@ input_shared = @cuStaticSharedMem($array_ty, 256) fill!(input_shared, 42) - data = $func(addrspacecast(input_shared.ptr), 16) + data = $func(input_shared.ptr, 16) result_dev[1] = all(val -> val == $expected, data) else @@ -96,7 +84,7 @@ @eval function kernel(output_dev) if $do_shared_test shared_mem = @cuStaticSharedMem($array_ty, 256) - ptr = addrspacecast(pointer(shared_mem)) + ptr = pointer(shared_mem) $func(ptr, $data, 16) for i = 1:256 From 2d2b5929cc02bf4e4d2190d319a160d17d69d9d1 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Thu, 28 Nov 2019 12:53:58 +0100 Subject: [PATCH 35/81] Clean up tests --- test/device/wmma.jl | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index d57e8ac0..cfb73095 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -19,7 +19,7 @@ # Type-dependent variables array_ty = elem_type == "f16" ? Float16 : Float32 - expected = elem_type == "f16" ? (VecElement{Float16}(42), VecElement{Float16}(42)) : Float32(42) + expected = elem_type == "f16" ? ntuple(i -> VecElement{Float16}(42), 2) : Float32(42) # Address-space dependent variables do_shared_test = (addr_space == "_shared") @@ -37,14 +37,13 @@ input_shared = @cuStaticSharedMem($array_ty, 256) fill!(input_shared, 42) - data = $func(input_shared.ptr, 16) - - result_dev[1] = all(val -> val == $expected, data) + data = $func(pointer(input_shared), 16) else data = $func(pointer(input_dev), 16) - result_dev[1] = all(val -> val == $expected, data) end + result_dev[1] = all(val -> val == $expected, data) + return end @@ -63,13 +62,7 @@ # Type-dependent variables array_ty = elem_type == "f16" ? Float16 : Float32 - data = elem_type == "f16" ? - ( - (VecElement{Float16}(42), VecElement{Float16}(42)), - (VecElement{Float16}(42), VecElement{Float16}(42)), - (VecElement{Float16}(42), VecElement{Float16}(42)), - (VecElement{Float16}(42), VecElement{Float16}(42)) - ) : (42, 42, 42, 42, 42, 42, 42, 42) + data = elem_type == "f16" ? ntuple(i -> ntuple(j -> VecElement{Float16}(42), 2), 4) : ntuple(i -> 42, 8) # Get the function name func = Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") @@ -80,12 +73,10 @@ output = Array{array_ty}(undef, (16, 16)) output_dev = CuArray(output) - @eval function kernel(output_dev) if $do_shared_test shared_mem = @cuStaticSharedMem($array_ty, 256) - ptr = pointer(shared_mem) - $func(ptr, $data, 16) + $func(pointer(shared_mem), $data, 16) for i = 1:256 @inbounds output_dev[i] = shared_mem[i] From 2d0c7cf263ac046ab4b35c53316184ae20aa8d55 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Thu, 28 Nov 2019 16:47:37 +0100 Subject: [PATCH 36/81] Fixes --- src/device/cuda/wmma.jl | 75 +++++++++++++---------------------------- 1 file changed, 23 insertions(+), 52 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index e607d53f..7ed94592 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -323,10 +323,7 @@ for mat in ["a", "b", "c"], wrapper = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) # Get fragment size - frag_sz = get_frag_sz(mat, elem_type) - - # Get Julia element type - julia_type = get_jl_ty(mat, elem_type) + arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) # Get matrix use type matrix_use = get_matrix_use(mat) @@ -335,18 +332,15 @@ for mat in ["a", "b", "c"], layout_ty = (layout == "col") ? wmma_col_major : wmma_row_major layout_frag_ty = (mat == "c") ? wmma_unspecified : layout_ty - # Get pointer type - ptr_ty = (elem_type == "f32") ? Float32 : Float16 - # Get address space type as_ty = get_address_space(addr_space) - @eval function $func_name(addr::DevicePtr{$ptr_ty, $as_ty}, + @eval function $func_name(addr::DevicePtr{$arr_ty, $as_ty}, stride::Number, layout::Type{$layout_ty}, config::Type{wmma_config{16, 16, 16, d_type}}) where d_type x = $wrapper(addr, stride) - return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_frag_ty, $matrix_use}(x) + return wmma_fragment{16, 16, 16, $sz, $frag_ty, $layout_frag_ty, $matrix_use}(x) end end @@ -387,34 +381,23 @@ for a_layout in ["col", "row"], # Name of the Julia wrapper wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) - # Information about a - a_frag_sz = get_frag_sz("a", a_elem_type) - a_julia_type = get_jl_ty("a", a_elem_type) + # Get types + a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type) a_layout_ty = (a_layout == "col") ? wmma_col_major : wmma_row_major - # Information about b - b_frag_sz = get_frag_sz("b", b_elem_type) - b_julia_type = get_jl_ty("b", b_elem_type) + b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type) b_layout_ty = (b_layout == "col") ? wmma_col_major : wmma_row_major - # Information about c - c_frag_sz = get_frag_sz("c", c_elem_type) - c_julia_type = get_jl_ty("c", c_elem_type) - - # Information about d - d_frag_sz = get_frag_sz("d", d_elem_type) - d_julia_type = get_jl_ty("d", d_elem_type) + c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type) - # We need some way to select if we want d to be 16 or 32-bit floating point - # during dispatch. - dispatch_ty = (d_elem_type == "f16") ? Float16 : Float32 + d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type) - @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_frag_sz, $a_julia_type, $a_layout_ty, wmma_matrix_a}, - b::wmma_fragment{16, 16, 16, $b_frag_sz, $b_julia_type, $b_layout_ty, wmma_matrix_b}, - c::wmma_fragment{16, 16, 16, $c_frag_sz, $c_julia_type, wmma_unspecified, wmma_accumulator}, - conf::Type{wmma_config{16, 16, 16, $dispatch_ty}}) + @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_sz, $a_frag_ty, $a_layout_ty, wmma_matrix_a}, + b::wmma_fragment{16, 16, 16, $b_sz, $b_frag_ty, $b_layout_ty, wmma_matrix_b}, + c::wmma_fragment{16, 16, 16, $c_sz, $c_frag_ty, wmma_unspecified, wmma_accumulator}, + conf::Type{wmma_config{16, 16, 16, $d_arr_ty}}) x = $wrapper(a.x, b.x, c.x) - return wmma_fragment{16, 16, 16, $d_frag_sz, $d_julia_type, wmma_unspecified, wmma_accumulator}(x) + return wmma_fragment{16, 16, 16, $d_sz, $d_frag_ty, wmma_unspecified, wmma_accumulator}(x) end end @@ -459,11 +442,8 @@ for mat in ["d"], # Name of the Julia wrapper wrapper = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) - # Get fragment size - frag_sz = get_frag_sz(mat, elem_type) - - # Get Julia element type - julia_type = get_jl_ty(mat, elem_type) + # Get types + arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) # Get matrix use type matrix_use = get_matrix_use(mat) @@ -472,14 +452,11 @@ for mat in ["d"], layout_ty = (layout == "col") ? wmma_col_major : wmma_row_major layout_frag_ty = wmma_unspecified - # Get pointer type - ptr_ty = (elem_type == "f32") ? Float32 : Float16 - # Get address space type as_ty = get_address_space(addr_space) - @eval function $func_name(addr::DevicePtr{$ptr_ty, $as_ty}, - d::wmma_fragment{16, 16, 16, $frag_sz, $julia_type, $layout_frag_ty, $matrix_use}, + @eval function $func_name(addr::DevicePtr{$arr_ty, $as_ty}, + d::wmma_fragment{16, 16, 16, $sz, $frag_ty, $layout_frag_ty, $matrix_use}, stride::Number, layout::Type{$layout_ty}, config::Type{wmma_config{16, 16, 16, d_type}}) where d_type @@ -515,26 +492,20 @@ for mat in ["c"], # Name of the Julia function func_name = Symbol("wmma_fill_$mat") - # Get fragment size - frag_sz = get_frag_sz(mat, elem_type) - - # Get Julia type - julia_type = get_jl_ty(mat, elem_type) - - # Value type - val_type = (elem_type == "f16") ? Float16 : Float32 + # Get fragment types and size + arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) # Returned tuple if elem_type == "f16" - tuple = :(ntuple(i -> ntuple(j -> VecElement{Float16}(value), 2), $frag_sz)) + tuple = :(ntuple(i -> ntuple(j -> VecElement{Float16}(value), 2), $sz)) else - tuple = :(ntuple(i -> value, $frag_sz)) + tuple = :(ntuple(i -> value, $sz)) end - @eval function $func_name(value::$val_type, + @eval function $func_name(value::$arr_ty, config::Type{wmma_config{M, N, K, d_type}}) where {M, N, K, d_type} x = $tuple - return wmma_fragment{16, 16, 16, $frag_sz, $julia_type, wmma_unspecified, wmma_accumulator}(x) + return wmma_fragment{16, 16, 16, $sz, $frag_ty, wmma_unspecified, wmma_accumulator}(x) end end From 9373c557f3ad0f19e0f915433011d5e419a80303 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Thu, 28 Nov 2019 18:12:57 +0100 Subject: [PATCH 37/81] Reenable tests --- test/runtests.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1d70f530..18f933b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,9 +55,9 @@ if length(devices()) > 0 end cap = CUDAnative.current_capability() -#= include("base.jl") =# -#= include("pointer.jl") =# -#= include("codegen.jl") =# +include("base.jl") +include("pointer.jl") +include("codegen.jl") if dev === nothing @warn("No CUDA-capable devices available; skipping on-device tests.") @@ -65,14 +65,14 @@ else if capability(dev) < v"2.0" @warn("native execution not supported on SM < 2.0") else - #= include("device/codegen.jl") =# - #= include("device/execution.jl") =# - #= include("device/pointer.jl") =# - #= include("device/array.jl") =# - #= include("device/cuda.jl") =# + include("device/codegen.jl") + include("device/execution.jl") + include("device/pointer.jl") + include("device/array.jl") + include("device/cuda.jl") include("device/wmma.jl") - #= include("examples.jl") =# + include("examples.jl") end end From ba6ff5bcb4efa225201ed61868f0ad30ccb0d836 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 29 Nov 2019 12:03:36 +0100 Subject: [PATCH 38/81] Add whitespace --- src/device/cuda/wmma.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 7ed94592..9030cf86 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -23,6 +23,7 @@ map_frag_sizes = Dict( "d.f16" => 4, "d.f32" => 8 ) + # Maps PTX AS to Int map_ptx_as_to_int = Dict( "" => 0, From 06f1f1ecf93db38e4049c8272407afbc4212f7b6 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 29 Nov 2019 13:03:48 +0100 Subject: [PATCH 39/81] Use separate frag size variable for high-level API --- src/device/cuda/wmma.jl | 30 +++++++++++++++++++++++------- test/runtests.jl | 18 +++++++++--------- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 9030cf86..e199f77e 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -269,6 +269,16 @@ map_address_space_to_ty = Dict( "global" => AS.Global ) +# Maps matrix & PTX types to number of elements (size after flattening) +map_num_elements = Dict( + "a.f16" => 8, + "b.f16" => 8, + "c.f16" => 4, + "c.f32" => 8, + "d.f16" => 4, + "d.f32" => 8 + ) + # ---------------- # Helper functions # ---------------- @@ -276,6 +286,12 @@ map_address_space_to_ty = Dict( get_matrix_use(mat) = map_matrix_to_use[mat] get_address_space(as) = map_address_space_to_ty[as] +get_hl_frag_info(matrix, ptx_el_type) = ( + map_ptx_to_jl_array[ptx_el_type], + map_ptx_to_jl_frag[ptx_el_type], + map_num_elements["$matrix.$ptx_el_type"] + ) + # --------- # WMMA load # --------- @@ -324,7 +340,7 @@ for mat in ["a", "b", "c"], wrapper = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) # Get fragment size - arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) + arr_ty, frag_ty, sz = get_hl_frag_info(mat, elem_type) # Get matrix use type matrix_use = get_matrix_use(mat) @@ -383,15 +399,15 @@ for a_layout in ["col", "row"], wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) # Get types - a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type) + a_arr_ty, a_frag_ty, a_sz = get_hl_frag_info("a", a_elem_type) a_layout_ty = (a_layout == "col") ? wmma_col_major : wmma_row_major - b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type) + b_arr_ty, b_frag_ty, b_sz = get_hl_frag_info("b", b_elem_type) b_layout_ty = (b_layout == "col") ? wmma_col_major : wmma_row_major - c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type) + c_arr_ty, c_frag_ty, c_sz = get_hl_frag_info("c", c_elem_type) - d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type) + d_arr_ty, d_frag_ty, d_sz = get_hl_frag_info("d", d_elem_type) @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_sz, $a_frag_ty, $a_layout_ty, wmma_matrix_a}, b::wmma_fragment{16, 16, 16, $b_sz, $b_frag_ty, $b_layout_ty, wmma_matrix_b}, @@ -444,7 +460,7 @@ for mat in ["d"], wrapper = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) # Get types - arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) + arr_ty, frag_ty, sz = get_hl_frag_info(mat, elem_type) # Get matrix use type matrix_use = get_matrix_use(mat) @@ -494,7 +510,7 @@ for mat in ["c"], func_name = Symbol("wmma_fill_$mat") # Get fragment types and size - arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) + arr_ty, frag_ty, sz = get_hl_frag_info(mat, elem_type) # Returned tuple if elem_type == "f16" diff --git a/test/runtests.jl b/test/runtests.jl index 18f933b9..1d70f530 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,9 +55,9 @@ if length(devices()) > 0 end cap = CUDAnative.current_capability() -include("base.jl") -include("pointer.jl") -include("codegen.jl") +#= include("base.jl") =# +#= include("pointer.jl") =# +#= include("codegen.jl") =# if dev === nothing @warn("No CUDA-capable devices available; skipping on-device tests.") @@ -65,14 +65,14 @@ else if capability(dev) < v"2.0" @warn("native execution not supported on SM < 2.0") else - include("device/codegen.jl") - include("device/execution.jl") - include("device/pointer.jl") - include("device/array.jl") - include("device/cuda.jl") + #= include("device/codegen.jl") =# + #= include("device/execution.jl") =# + #= include("device/pointer.jl") =# + #= include("device/array.jl") =# + #= include("device/cuda.jl") =# include("device/wmma.jl") - include("examples.jl") + #= include("examples.jl") =# end end From f5ebc8ed74eca2d0991910490e2d00ca79817666 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 29 Nov 2019 22:37:59 +0100 Subject: [PATCH 40/81] Implement flattening --- src/device/cuda/wmma.jl | 93 ++++++++++++++++++++++++++++------------- 1 file changed, 64 insertions(+), 29 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index e199f77e..f67d3018 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -160,6 +160,43 @@ for a_layout in ["col", "row"], @eval export $func_name end +################################################################################ +# FLATTENING/UNFLATTENING LOGIC +################################################################################ + +# Base case (Float16, Float32, ...) +flatten_recurse(typ, e) = [:($e)] +unflatten_recurse(typ, e, idx) = :($e[$idx]), idx + 1 + +# VecElements +flatten_recurse(typ::Type{VecElement{T}}, e) where T = [:($e.value)] +unflatten_recurse(typ::Type{VecElement{T}}, e, idx) where T = :(VecElement{$T}($e[$idx])), idx + 1 + +# NTuples +function flatten_recurse(typ::Type{NTuple{N, T}}, e) where {N, T} + ret = Expr[] + + for (i, eltyp) in enumerate(typ.types) + append!(ret, flatten_recurse(eltyp, :($e[$i]))) + end + + return ret +end + +function unflatten_recurse(typ::Type{NTuple{N, T}}, e, idx) where {N, T} + ret = Expr(:tuple) + + for (i, eltyp) in enumerate(typ.types) + arg, idx = unflatten_recurse(eltyp, e, idx) + push!(ret.args, arg) + end + + return ret, idx +end + +@generated flatten(x::typ) where typ = Expr(:tuple, flatten_recurse(typ, :x)...) +@generated unflatten(::Type{typ}, x) where typ = unflatten_recurse(typ, :x, 1)[1] + ################################################################################ # HIGH LEVEL (CUDA-STYLE API) ################################################################################ @@ -271,11 +308,11 @@ map_address_space_to_ty = Dict( # Maps matrix & PTX types to number of elements (size after flattening) map_num_elements = Dict( - "a.f16" => 8, - "b.f16" => 8, - "c.f16" => 4, + "a.f16" => 16, + "b.f16" => 16, + "c.f16" => 8, "c.f32" => 8, - "d.f16" => 4, + "d.f16" => 8, "d.f32" => 8 ) @@ -289,6 +326,7 @@ get_address_space(as) = map_address_space_to_ty[as] get_hl_frag_info(matrix, ptx_el_type) = ( map_ptx_to_jl_array[ptx_el_type], map_ptx_to_jl_frag[ptx_el_type], + map_frag_sizes["$matrix.$ptx_el_type"], map_num_elements["$matrix.$ptx_el_type"] ) @@ -340,7 +378,7 @@ for mat in ["a", "b", "c"], wrapper = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) # Get fragment size - arr_ty, frag_ty, sz = get_hl_frag_info(mat, elem_type) + arr_ty, _, _, sz = get_hl_frag_info(mat, elem_type) # Get matrix use type matrix_use = get_matrix_use(mat) @@ -356,8 +394,8 @@ for mat in ["a", "b", "c"], stride::Number, layout::Type{$layout_ty}, config::Type{wmma_config{16, 16, 16, d_type}}) where d_type - x = $wrapper(addr, stride) - return wmma_fragment{16, 16, 16, $sz, $frag_ty, $layout_frag_ty, $matrix_use}(x) + x = flatten($wrapper(addr, stride)) + return wmma_fragment{16, 16, 16, $sz, $arr_ty, $layout_frag_ty, $matrix_use}(x) end end @@ -399,22 +437,25 @@ for a_layout in ["col", "row"], wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) # Get types - a_arr_ty, a_frag_ty, a_sz = get_hl_frag_info("a", a_elem_type) + a_arr_ty, a_frag_ty, a_sz_unfl, a_sz = get_hl_frag_info("a", a_elem_type) a_layout_ty = (a_layout == "col") ? wmma_col_major : wmma_row_major - b_arr_ty, b_frag_ty, b_sz = get_hl_frag_info("b", b_elem_type) + b_arr_ty, b_frag_ty, b_sz_unfl, b_sz = get_hl_frag_info("b", b_elem_type) b_layout_ty = (b_layout == "col") ? wmma_col_major : wmma_row_major - c_arr_ty, c_frag_ty, c_sz = get_hl_frag_info("c", c_elem_type) + c_arr_ty, c_frag_ty, c_sz_unfl, c_sz = get_hl_frag_info("c", c_elem_type) - d_arr_ty, d_frag_ty, d_sz = get_hl_frag_info("d", d_elem_type) + d_arr_ty, _, _, d_sz = get_hl_frag_info("d", d_elem_type) - @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_sz, $a_frag_ty, $a_layout_ty, wmma_matrix_a}, - b::wmma_fragment{16, 16, 16, $b_sz, $b_frag_ty, $b_layout_ty, wmma_matrix_b}, - c::wmma_fragment{16, 16, 16, $c_sz, $c_frag_ty, wmma_unspecified, wmma_accumulator}, + @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_sz, $a_arr_ty, $a_layout_ty, wmma_matrix_a}, + b::wmma_fragment{16, 16, 16, $b_sz, $b_arr_ty, $b_layout_ty, wmma_matrix_b}, + c::wmma_fragment{16, 16, 16, $c_sz, $c_arr_ty, wmma_unspecified, wmma_accumulator}, conf::Type{wmma_config{16, 16, 16, $d_arr_ty}}) - x = $wrapper(a.x, b.x, c.x) - return wmma_fragment{16, 16, 16, $d_sz, $d_frag_ty, wmma_unspecified, wmma_accumulator}(x) + a_unfl = unflatten(NTuple{$a_sz_unfl, $a_frag_ty}, a.x) + b_unfl = unflatten(NTuple{$b_sz_unfl, $b_frag_ty}, b.x) + c_unfl = unflatten(NTuple{$c_sz_unfl, $c_frag_ty}, c.x) + x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) + return wmma_fragment{16, 16, 16, $d_sz, $d_arr_ty, wmma_unspecified, wmma_accumulator}(x) end end @@ -460,7 +501,7 @@ for mat in ["d"], wrapper = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) # Get types - arr_ty, frag_ty, sz = get_hl_frag_info(mat, elem_type) + arr_ty, frag_ty, sz_unfl, sz = get_hl_frag_info(mat, elem_type) # Get matrix use type matrix_use = get_matrix_use(mat) @@ -473,11 +514,12 @@ for mat in ["d"], as_ty = get_address_space(addr_space) @eval function $func_name(addr::DevicePtr{$arr_ty, $as_ty}, - d::wmma_fragment{16, 16, 16, $sz, $frag_ty, $layout_frag_ty, $matrix_use}, + d::wmma_fragment{16, 16, 16, $sz, $arr_ty, $layout_frag_ty, $matrix_use}, stride::Number, layout::Type{$layout_ty}, config::Type{wmma_config{16, 16, 16, d_type}}) where d_type - $wrapper(addr, d.x, stride) + d_unfl = unflatten(NTuple{$sz_unfl, $frag_ty}, d.x) + $wrapper(addr, d_unfl, stride) return nothing end @@ -510,19 +552,12 @@ for mat in ["c"], func_name = Symbol("wmma_fill_$mat") # Get fragment types and size - arr_ty, frag_ty, sz = get_hl_frag_info(mat, elem_type) - - # Returned tuple - if elem_type == "f16" - tuple = :(ntuple(i -> ntuple(j -> VecElement{Float16}(value), 2), $sz)) - else - tuple = :(ntuple(i -> value, $sz)) - end + arr_ty, _, _, sz = get_hl_frag_info(mat, elem_type) @eval function $func_name(value::$arr_ty, config::Type{wmma_config{M, N, K, d_type}}) where {M, N, K, d_type} - x = $tuple - return wmma_fragment{16, 16, 16, $sz, $frag_ty, wmma_unspecified, wmma_accumulator}(x) + x = ntuple(i -> value, $sz) + return wmma_fragment{16, 16, 16, $sz, $arr_ty, wmma_unspecified, wmma_accumulator}(x) end end From 08a6e6c7513439ec033c9eeb0fc78d922e30c590 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 29 Nov 2019 22:52:36 +0100 Subject: [PATCH 41/81] Test elementwise op --- test/device/wmma.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 43de5b23..e2b7680e 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -167,7 +167,10 @@ c_dev = CuArray(c) d_dev = CuArray(d) - @eval function kernel(a_dev, b_dev, c_dev, d_dev) + alpha = rand() + beta = rand() + + @eval function kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) conf = wmma_config{16, 16, 16, $d_type} a_frag = wmma_load_a(pointer(a_dev), 16, $a_layout, conf) @@ -179,6 +182,10 @@ c_frag = wmma_fill_c($c_type(0), conf) end + # TODO: Make this less awkward by implementing Base.broadcast for wmma_fragment + a_frag = typeof(a_frag)(alpha .* a_frag.x) + c_frag = typeof(c_frag)(beta .* c_frag.x) + d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) wmma_store_d(pointer(d_dev), d_frag, 16, $d_layout, conf) @@ -186,7 +193,7 @@ return end - @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) + @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) d = Array(d_dev) new_a = (a_layout == wmma_col_major) ? a : transpose(a) @@ -195,9 +202,9 @@ new_d = (d_layout == wmma_col_major) ? d : transpose(d) if do_mac - @test all(isapprox.(new_a * new_b + new_c, new_d; rtol=sqrt(eps(Float16)))) + @test all(isapprox.(alpha * new_a * new_b + beta * new_c, new_d; rtol=sqrt(eps(Float16)))) else - @test all(isapprox.(new_a * new_b, new_d; rtol=sqrt(eps(Float16)))) + @test all(isapprox.(alpha * new_a * new_b, new_d; rtol=sqrt(eps(Float16)))) end end From 05ec4a53d6db1ffcfd62f04b035667be3954b436 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 29 Nov 2019 23:20:21 +0100 Subject: [PATCH 42/81] Change comment --- test/device/wmma.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index e2b7680e..f7cd4cdb 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -182,7 +182,7 @@ c_frag = wmma_fill_c($c_type(0), conf) end - # TODO: Make this less awkward by implementing Base.broadcast for wmma_fragment + # TODO: Make this less awkward, see https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting-1 a_frag = typeof(a_frag)(alpha .* a_frag.x) c_frag = typeof(c_frag)(beta .* c_frag.x) From a83bbfa5ad9aea2a8be9adb55890e85d8552686e Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 1 Dec 2019 00:02:00 +0100 Subject: [PATCH 43/81] Only run WMMA test for recent Julia --- test/device/wmma.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index f7cd4cdb..0c702f04 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -1,3 +1,5 @@ +# Need https://github.com/JuliaLang/julia/pull/33970 +if VERSION >= v"1.4.0-DEV.534" @testset "WMMA" begin ################################################################################ @@ -212,3 +214,4 @@ ################################################################################ end +end From 062038d189d0ce6682f21e0ae9608293e377db78 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 1 Dec 2019 00:39:35 +0100 Subject: [PATCH 44/81] Reenable other tests --- test/runtests.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1d70f530..18f933b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,9 +55,9 @@ if length(devices()) > 0 end cap = CUDAnative.current_capability() -#= include("base.jl") =# -#= include("pointer.jl") =# -#= include("codegen.jl") =# +include("base.jl") +include("pointer.jl") +include("codegen.jl") if dev === nothing @warn("No CUDA-capable devices available; skipping on-device tests.") @@ -65,14 +65,14 @@ else if capability(dev) < v"2.0" @warn("native execution not supported on SM < 2.0") else - #= include("device/codegen.jl") =# - #= include("device/execution.jl") =# - #= include("device/pointer.jl") =# - #= include("device/array.jl") =# - #= include("device/cuda.jl") =# + include("device/codegen.jl") + include("device/execution.jl") + include("device/pointer.jl") + include("device/array.jl") + include("device/cuda.jl") include("device/wmma.jl") - #= include("examples.jl") =# + include("examples.jl") end end From 62607ef3eca3f066e2e6c9e9445808cef6910da7 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 1 Dec 2019 00:52:14 +0100 Subject: [PATCH 45/81] Add minimum Julia version to documentation --- docs/src/lib/device/wmma.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md index 0b4e1c6e..9b4c2be6 100644 --- a/docs/src/lib/device/wmma.md +++ b/docs/src/lib/device/wmma.md @@ -5,6 +5,20 @@ This interface enables programmatic access to Tensor Cores, a new hardware featu Access to WMMA using CUDAnative is available in two levels: low level wrappers around the LLVM intrinsics, and a higher-level API, similar to that of CUDA C. +Note that to use the WMMA intrinsics, you need a sufficiently recent version of Julia: `v1.4.0-DEV.534` or later. +You can check this by running the following in the REPL: +```julia +VERSION >= v"1.4.0-DEV.534" +``` + +!!! note + + If you're running into the following error while using the WMMA interfaces: + ``` + LLVM error: Do not know how to split the result of this operator! + ``` + then make sure you are running Julia v1.4.0-DEV.534 or later! + ## Introduction of Terminology The WMMA operations perform a matrix multiply-accumulate. From df1ca7db76580f06e09b4d32fd3b42a7a061da82 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 1 Dec 2019 14:27:21 +0100 Subject: [PATCH 46/81] Refactor load to use @generated --- src/device/cuda/wmma.jl | 143 ++++++++++++++++++++++++++++------------ test/runtests.jl | 18 ++--- 2 files changed, 110 insertions(+), 51 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index f67d3018..05973516 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -293,12 +293,6 @@ struct wmma_config{M, N, K, d_type} end # Constants # --------- -map_matrix_to_use = Dict( - "a" => wmma_matrix_a, - "b" => wmma_matrix_b, - "c" => wmma_accumulator, - "d" => wmma_accumulator - ) map_address_space_to_ty = Dict( "" => AS.Generic, @@ -316,6 +310,43 @@ map_num_elements = Dict( "d.f32" => 8 ) + + + + +# Maps Julia array types to string +map_jl_array_to_str = Dict(val => key for (key, val) in map_ptx_to_jl_array) + +# Maps CUDAnative.AS types to string +map_as_ty_to_str = Dict( + AS.Generic => "", + AS.Shared => "shared", + AS.Global => "global" + ) + +# Maps layout types to string +map_layout_ty_to_str = Dict( + wmma_row_major => "row", + wmma_col_major => "col" + ) + +# Maps matrix & type to number of elements (size after flattening) +map_num_elems = Dict( + ("a", Float16) => 16, + ("b", Float16) => 16, + ("c", Float16) => 8, + ("c", Float32) => 8, + ("d", Float16) => 8, + ("d", Float32) => 8 + ) + +# Maps matrix to its use +map_matrix_to_use = Dict( + "a" => wmma_matrix_a, + "b" => wmma_matrix_b, + "c" => wmma_accumulator, + "d" => wmma_accumulator + ) # ---------------- # Helper functions # ---------------- @@ -330,6 +361,50 @@ get_hl_frag_info(matrix, ptx_el_type) = ( map_num_elements["$matrix.$ptx_el_type"] ) +function get_hl_ptr_info(T, AS) + arr_str, as_str = nothing, nothing + + try + arr_str = map_jl_array_to_str[T] + catch + error("Invalid element type for WMMA: $T") + end + + try + as_str = map_as_ty_to_str[AS] + catch + error("Invalid address space for WMMA: $AS") + end + + return (arr_str, as_str) +end + +function get_hl_layout(L) + try + return map_layout_ty_to_str[L] + catch + error("Invalid layout for WMMA: $L") + end +end + +function get_hl_shape(M, N, K) + if (M, N, K) != (16, 16, 16) + error("Invalid shape for WMMA: (M, N, K) = ($M, $N, $K)") + end + + return "m$(M)n$(N)k$(K)" +end + +function get_hl_num_elems(matrix, T) + try + return map_num_elems[(matrix, T)] + catch + error("Invalid type $T for matrix $matrix") + end +end + +get_hl_mat_use(mat) = map_matrix_to_use[mat] + # --------- # WMMA load # --------- @@ -358,44 +433,28 @@ See also: [`wmma_fragment`](@ref), [`wmma_fragment_layout`](@ref), [`wmma_config """ wmma_load_a, wmma_load_b, wmma_load_c -for mat in ["a", "b", "c"], - layout in ["col", "row"], - shape in ["m16n16k16"], - addr_space in ["", "shared", "global"], - stride in ["stride"], - elem_type in ["f16", "f32"] - - - # Float32 is only supported for C - if (elem_type == "f32") && (mat != "c") - continue - end - - # Name of Julia function +for mat in ["a", "b", "c"] func_name = Symbol("wmma_load_$mat") - # Name of the Julia wrapper - wrapper = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) - - # Get fragment size - arr_ty, _, _, sz = get_hl_frag_info(mat, elem_type) - - # Get matrix use type - matrix_use = get_matrix_use(mat) - - # Get layout type - layout_ty = (layout == "col") ? wmma_col_major : wmma_row_major - layout_frag_ty = (mat == "c") ? wmma_unspecified : layout_ty - - # Get address space type - as_ty = get_address_space(addr_space) - - @eval function $func_name(addr::DevicePtr{$arr_ty, $as_ty}, - stride::Number, - layout::Type{$layout_ty}, - config::Type{wmma_config{16, 16, 16, d_type}}) where d_type - x = flatten($wrapper(addr, stride)) - return wmma_fragment{16, 16, 16, $sz, $arr_ty, $layout_frag_ty, $matrix_use}(x) + @eval @generated function $func_name(addr::DevicePtr{T, AS}, + stride::Number, + layout::Type{L}, + config::Type{wmma_config{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} + + arr_str, as_str = get_hl_ptr_info(T, AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) + num_els = get_hl_num_elems($mat, T) + U = get_hl_mat_use($mat) + L_ret = ($mat == "c") ? wmma_unspecified : L + + # Name of the Julia wrapper + wrapper = Symbol(join_nonempty("llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str, "_")) + + return quote + x = flatten($wrapper(addr, stride)) + return wmma_fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 18f933b9..1d70f530 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,9 +55,9 @@ if length(devices()) > 0 end cap = CUDAnative.current_capability() -include("base.jl") -include("pointer.jl") -include("codegen.jl") +#= include("base.jl") =# +#= include("pointer.jl") =# +#= include("codegen.jl") =# if dev === nothing @warn("No CUDA-capable devices available; skipping on-device tests.") @@ -65,14 +65,14 @@ else if capability(dev) < v"2.0" @warn("native execution not supported on SM < 2.0") else - include("device/codegen.jl") - include("device/execution.jl") - include("device/pointer.jl") - include("device/array.jl") - include("device/cuda.jl") + #= include("device/codegen.jl") =# + #= include("device/execution.jl") =# + #= include("device/pointer.jl") =# + #= include("device/array.jl") =# + #= include("device/cuda.jl") =# include("device/wmma.jl") - include("examples.jl") + #= include("examples.jl") =# end end From 9fd6b7426415f82178a7f988239bab35eb652aa3 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 1 Dec 2019 14:57:43 +0100 Subject: [PATCH 47/81] Refactor store to use @generated --- src/device/cuda/wmma.jl | 75 +++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 40 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 05973516..f3cd2b77 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -314,6 +314,8 @@ map_num_elements = Dict( + + # Maps Julia array types to string map_jl_array_to_str = Dict(val => key for (key, val) in map_ptx_to_jl_array) @@ -354,7 +356,7 @@ map_matrix_to_use = Dict( get_matrix_use(mat) = map_matrix_to_use[mat] get_address_space(as) = map_address_space_to_ty[as] -get_hl_frag_info(matrix, ptx_el_type) = ( +get_hl_frag_info_old(matrix, ptx_el_type) = ( map_ptx_to_jl_array[ptx_el_type], map_ptx_to_jl_frag[ptx_el_type], map_frag_sizes["$matrix.$ptx_el_type"], @@ -395,16 +397,26 @@ function get_hl_shape(M, N, K) return "m$(M)n$(N)k$(K)" end -function get_hl_num_elems(matrix, T) +get_hl_mat_use(mat) = map_matrix_to_use[mat] + +function get_hl_frag_info(matrix, T) + ptx_ty = nothing + + try + ptx_ty = map_jl_array_to_str[T] + catch + error("Invalid element type for WMMA: $T") + end + try - return map_num_elems[(matrix, T)] + return (map_num_elems[(matrix, T)], + map_frag_sizes["$matrix.$ptx_ty"], + map_ptx_to_jl_frag[ptx_ty]) catch error("Invalid type $T for matrix $matrix") end end -get_hl_mat_use(mat) = map_matrix_to_use[mat] - # --------- # WMMA load # --------- @@ -444,7 +456,7 @@ for mat in ["a", "b", "c"] arr_str, as_str = get_hl_ptr_info(T, AS) layout = get_hl_layout(L) shape = get_hl_shape(M, N, K) - num_els = get_hl_num_elems($mat, T) + num_els, _, _ = get_hl_frag_info($mat, T) U = get_hl_mat_use($mat) L_ret = ($mat == "c") ? wmma_unspecified : L @@ -496,15 +508,15 @@ for a_layout in ["col", "row"], wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) # Get types - a_arr_ty, a_frag_ty, a_sz_unfl, a_sz = get_hl_frag_info("a", a_elem_type) + a_arr_ty, a_frag_ty, a_sz_unfl, a_sz = get_hl_frag_info_old("a", a_elem_type) a_layout_ty = (a_layout == "col") ? wmma_col_major : wmma_row_major - b_arr_ty, b_frag_ty, b_sz_unfl, b_sz = get_hl_frag_info("b", b_elem_type) + b_arr_ty, b_frag_ty, b_sz_unfl, b_sz = get_hl_frag_info_old("b", b_elem_type) b_layout_ty = (b_layout == "col") ? wmma_col_major : wmma_row_major - c_arr_ty, c_frag_ty, c_sz_unfl, c_sz = get_hl_frag_info("c", c_elem_type) + c_arr_ty, c_frag_ty, c_sz_unfl, c_sz = get_hl_frag_info_old("c", c_elem_type) - d_arr_ty, _, _, d_sz = get_hl_frag_info("d", d_elem_type) + d_arr_ty, _, _, d_sz = get_hl_frag_info_old("d", d_elem_type) @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_sz, $a_arr_ty, $a_layout_ty, wmma_matrix_a}, b::wmma_fragment{16, 16, 16, $b_sz, $b_arr_ty, $b_layout_ty, wmma_matrix_b}, @@ -546,42 +558,25 @@ See also: [`wmma_fragment`](@ref), [`wmma_fragment_layout`](@ref), [`wmma_config """ wmma_store_d -for mat in ["d"], - layout in ["col", "row"], - shape in ["m16n16k16"], - addr_space in ["", "shared", "global"], - stride in ["stride"], - elem_type in ["f16", "f32"] +@generated function wmma_store_d(addr::DevicePtr{T, AS}, + d::wmma_fragment{M, N, K, D_SZ, T, wmma_unspecified, wmma_accumulator}, + stride::Number, + layout::Type{L}, + config::Type{wmma_config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} - # Name of Julia function - func_name = Symbol("wmma_store_$mat") + arr_str, as_str = get_hl_ptr_info(T, AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) + num_els, frag_sz, frag_ty = get_hl_frag_info("d", T) # Name of the Julia wrapper - wrapper = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) + wrapper = Symbol(join_nonempty("llvm", "wmma", "store", "d", layout, shape, as_str, "stride", arr_str, "_")) - # Get types - arr_ty, frag_ty, sz_unfl, sz = get_hl_frag_info(mat, elem_type) - - # Get matrix use type - matrix_use = get_matrix_use(mat) - - # Get layout type - layout_ty = (layout == "col") ? wmma_col_major : wmma_row_major - layout_frag_ty = wmma_unspecified - - # Get address space type - as_ty = get_address_space(addr_space) - - @eval function $func_name(addr::DevicePtr{$arr_ty, $as_ty}, - d::wmma_fragment{16, 16, 16, $sz, $arr_ty, $layout_frag_ty, $matrix_use}, - stride::Number, - layout::Type{$layout_ty}, - config::Type{wmma_config{16, 16, 16, d_type}}) where d_type - d_unfl = unflatten(NTuple{$sz_unfl, $frag_ty}, d.x) + return quote + d_unfl = unflatten(NTuple{$frag_sz, $frag_ty}, d.x) $wrapper(addr, d_unfl, stride) return nothing end - end @@ -611,7 +606,7 @@ for mat in ["c"], func_name = Symbol("wmma_fill_$mat") # Get fragment types and size - arr_ty, _, _, sz = get_hl_frag_info(mat, elem_type) + arr_ty, _, _, sz = get_hl_frag_info_old(mat, elem_type) @eval function $func_name(value::$arr_ty, config::Type{wmma_config{M, N, K, d_type}}) where {M, N, K, d_type} From c7eef61a1c9810e9136c71bdc38116a9306ef1e6 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 1 Dec 2019 15:41:39 +0100 Subject: [PATCH 48/81] Refactor fill to use @generated --- src/device/cuda/wmma.jl | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index f3cd2b77..ffe37a9e 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -599,19 +599,17 @@ This operation is useful if you want to implement a matrix multiplication (and t """ wmma_fill_c -for mat in ["c"], - elem_type in ["f16", "f32"] - - # Name of the Julia function - func_name = Symbol("wmma_fill_$mat") +@generated function wmma_fill_c(value::T, + config::Type{wmma_config{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} - # Get fragment types and size - arr_ty, _, _, sz = get_hl_frag_info_old(mat, elem_type) + # We can't use closures in @generated functions, so we'll have to do it this way instead of + # ntuple(i -> val, $num_els). + num_els, _, _ = get_hl_frag_info("c", T) - @eval function $func_name(value::$arr_ty, - config::Type{wmma_config{M, N, K, d_type}}) where {M, N, K, d_type} + args = [:value for i=1:num_els] + expr = :(tuple($(args...))) - x = ntuple(i -> value, $sz) - return wmma_fragment{16, 16, 16, $sz, $arr_ty, wmma_unspecified, wmma_accumulator}(x) + return quote + return wmma_fragment{$M, $N, $K, $num_els, $T, wmma_unspecified, wmma_accumulator}($expr) end end From 0d6777c3f59b27b7a67aae453dc77302da8dfb2e Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 1 Dec 2019 16:08:11 +0100 Subject: [PATCH 49/81] Refactor wmma to use @generated --- src/device/cuda/wmma.jl | 133 +++++++++++++++++++++++++--------------- 1 file changed, 84 insertions(+), 49 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index ffe37a9e..90b9f410 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -363,22 +363,12 @@ get_hl_frag_info_old(matrix, ptx_el_type) = ( map_num_elements["$matrix.$ptx_el_type"] ) -function get_hl_ptr_info(T, AS) - arr_str, as_str = nothing, nothing - - try - arr_str = map_jl_array_to_str[T] - catch - error("Invalid element type for WMMA: $T") - end - +function get_hl_as_info(AS) try - as_str = map_as_ty_to_str[AS] + return map_as_ty_to_str[AS] catch error("Invalid address space for WMMA: $AS") end - - return (arr_str, as_str) end function get_hl_layout(L) @@ -411,7 +401,8 @@ function get_hl_frag_info(matrix, T) try return (map_num_elems[(matrix, T)], map_frag_sizes["$matrix.$ptx_ty"], - map_ptx_to_jl_frag[ptx_ty]) + map_ptx_to_jl_frag[ptx_ty], + ptx_ty) catch error("Invalid type $T for matrix $matrix") end @@ -453,12 +444,12 @@ for mat in ["a", "b", "c"] layout::Type{L}, config::Type{wmma_config{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} - arr_str, as_str = get_hl_ptr_info(T, AS) - layout = get_hl_layout(L) - shape = get_hl_shape(M, N, K) - num_els, _, _ = get_hl_frag_info($mat, T) - U = get_hl_mat_use($mat) - L_ret = ($mat == "c") ? wmma_unspecified : L + as_str = get_hl_as_info(AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) + num_els, _, _, arr_str = get_hl_frag_info($mat, T) + U = get_hl_mat_use($mat) + L_ret = ($mat == "c") ? wmma_unspecified : L # Name of the Julia wrapper wrapper = Symbol(join_nonempty("llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str, "_")) @@ -496,40 +487,67 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. """ wmma_mma -for a_layout in ["col", "row"], - b_layout in ["col", "row"], - shape in ["m16n16k16"], - d_elem_type in ["f16", "f32"], - c_elem_type in ["f16", "f32"], - b_elem_type in ["f16"], - a_elem_type in ["f16"] - - # Name of the Julia wrapper - wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) +@generated function wmma_mma(a::wmma_fragment{M, N, K, A_SZ, A_T, A_L, wmma_matrix_a}, + b::wmma_fragment{M, N, K, B_SZ, B_T, B_L, wmma_matrix_b}, + c::wmma_fragment{M, N, K, C_SZ, C_T, wmma_unspecified, wmma_accumulator}, + config::Type{wmma_config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} - # Get types - a_arr_ty, a_frag_ty, a_sz_unfl, a_sz = get_hl_frag_info_old("a", a_elem_type) - a_layout_ty = (a_layout == "col") ? wmma_col_major : wmma_row_major + _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T) + _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T) + _, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T) + d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T) - b_arr_ty, b_frag_ty, b_sz_unfl, b_sz = get_hl_frag_info_old("b", b_elem_type) - b_layout_ty = (b_layout == "col") ? wmma_col_major : wmma_row_major + a_layout = get_hl_layout(A_L) + b_layout = get_hl_layout(B_L) + shape = get_hl_shape(M, N, K) - c_arr_ty, c_frag_ty, c_sz_unfl, c_sz = get_hl_frag_info_old("c", c_elem_type) + # Name of the Julia wrapper + wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str, "_")) - d_arr_ty, _, _, d_sz = get_hl_frag_info_old("d", d_elem_type) + return quote + a_unfl = unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x) + b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x) + c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x) - @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_sz, $a_arr_ty, $a_layout_ty, wmma_matrix_a}, - b::wmma_fragment{16, 16, 16, $b_sz, $b_arr_ty, $b_layout_ty, wmma_matrix_b}, - c::wmma_fragment{16, 16, 16, $c_sz, $c_arr_ty, wmma_unspecified, wmma_accumulator}, - conf::Type{wmma_config{16, 16, 16, $d_arr_ty}}) - a_unfl = unflatten(NTuple{$a_sz_unfl, $a_frag_ty}, a.x) - b_unfl = unflatten(NTuple{$b_sz_unfl, $b_frag_ty}, b.x) - c_unfl = unflatten(NTuple{$c_sz_unfl, $c_frag_ty}, c.x) x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) - return wmma_fragment{16, 16, 16, $d_sz, $d_arr_ty, wmma_unspecified, wmma_accumulator}(x) + return wmma_fragment{$M, $N, $K, $d_num_els, $D_T, wmma_unspecified, wmma_accumulator}(x) end end +#= for a_layout in ["col", "row"], =# +#= b_layout in ["col", "row"], =# +#= shape in ["m16n16k16"], =# +#= d_elem_type in ["f16", "f32"], =# +#= c_elem_type in ["f16", "f32"], =# +#= b_elem_type in ["f16"], =# +#= a_elem_type in ["f16"] =# + +#= # Name of the Julia wrapper =# +#= wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) =# + +#= # Get types =# +#= a_arr_ty, a_frag_ty, a_sz_unfl, a_sz = get_hl_frag_info_old("a", a_elem_type) =# +#= a_layout_ty = (a_layout == "col") ? wmma_col_major : wmma_row_major =# + +#= b_arr_ty, b_frag_ty, b_sz_unfl, b_sz = get_hl_frag_info_old("b", b_elem_type) =# +#= b_layout_ty = (b_layout == "col") ? wmma_col_major : wmma_row_major =# + +#= c_arr_ty, c_frag_ty, c_sz_unfl, c_sz = get_hl_frag_info_old("c", c_elem_type) =# + +#= d_arr_ty, _, _, d_sz = get_hl_frag_info_old("d", d_elem_type) =# + +#= @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_sz, $a_arr_ty, $a_layout_ty, wmma_matrix_a}, =# +#= b::wmma_fragment{16, 16, 16, $b_sz, $b_arr_ty, $b_layout_ty, wmma_matrix_b}, =# +#= c::wmma_fragment{16, 16, 16, $c_sz, $c_arr_ty, wmma_unspecified, wmma_accumulator}, =# +#= conf::Type{wmma_config{16, 16, 16, $d_arr_ty}}) =# +#= a_unfl = unflatten(NTuple{$a_sz_unfl, $a_frag_ty}, a.x) =# +#= b_unfl = unflatten(NTuple{$b_sz_unfl, $b_frag_ty}, b.x) =# +#= c_unfl = unflatten(NTuple{$c_sz_unfl, $c_frag_ty}, c.x) =# +#= x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) =# +#= return wmma_fragment{16, 16, 16, $d_sz, $d_arr_ty, wmma_unspecified, wmma_accumulator}(x) =# +#= end =# +#= end =# + # ---------- # WMMA store @@ -564,10 +582,10 @@ wmma_store_d layout::Type{L}, config::Type{wmma_config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} - arr_str, as_str = get_hl_ptr_info(T, AS) - layout = get_hl_layout(L) - shape = get_hl_shape(M, N, K) - num_els, frag_sz, frag_ty = get_hl_frag_info("d", T) + as_str = get_hl_as_info(AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) + num_els, frag_sz, frag_ty, arr_str = get_hl_frag_info("d", T) # Name of the Julia wrapper wrapper = Symbol(join_nonempty("llvm", "wmma", "store", "d", layout, shape, as_str, "stride", arr_str, "_")) @@ -603,7 +621,7 @@ wmma_fill_c config::Type{wmma_config{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} # We can't use closures in @generated functions, so we'll have to do it this way instead of - # ntuple(i -> val, $num_els). + # ntuple(i -> val, $num_els) num_els, _, _ = get_hl_frag_info("c", T) args = [:value for i=1:num_els] @@ -613,3 +631,20 @@ wmma_fill_c return wmma_fragment{$M, $N, $K, $num_els, $T, wmma_unspecified, wmma_accumulator}($expr) end end + +#= for mat in ["c"], =# +#= elem_type in ["f16", "f32"] =# + +#= # Name of the Julia function =# +#= func_name = Symbol("wmma_fill_$mat") =# + +#= # Get fragment types and size =# +#= arr_ty, _, _, sz = get_hl_frag_info_old(mat, elem_type) =# + +#= @eval function $func_name(value::$arr_ty, =# +#= config::Type{wmma_config{M, N, K, d_type}}) where {M, N, K, d_type} =# + +#= x = ntuple(i -> value, $sz) =# +#= return wmma_fragment{16, 16, 16, $sz, $arr_ty, wmma_unspecified, wmma_accumulator}(x) =# +#= end =# +#= end =# From 702d37215d41e24b45aa5e45420135e3c9c5d0aa Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 1 Dec 2019 16:09:53 +0100 Subject: [PATCH 50/81] Cleanup --- src/device/cuda/wmma.jl | 85 +---------------------------------------- 1 file changed, 1 insertion(+), 84 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 90b9f410..780ce8b4 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -293,29 +293,6 @@ struct wmma_config{M, N, K, d_type} end # Constants # --------- - -map_address_space_to_ty = Dict( - "" => AS.Generic, - "shared" => AS.Shared, - "global" => AS.Global - ) - -# Maps matrix & PTX types to number of elements (size after flattening) -map_num_elements = Dict( - "a.f16" => 16, - "b.f16" => 16, - "c.f16" => 8, - "c.f32" => 8, - "d.f16" => 8, - "d.f32" => 8 - ) - - - - - - - # Maps Julia array types to string map_jl_array_to_str = Dict(val => key for (key, val) in map_ptx_to_jl_array) @@ -349,20 +326,11 @@ map_matrix_to_use = Dict( "c" => wmma_accumulator, "d" => wmma_accumulator ) + # ---------------- # Helper functions # ---------------- -get_matrix_use(mat) = map_matrix_to_use[mat] -get_address_space(as) = map_address_space_to_ty[as] - -get_hl_frag_info_old(matrix, ptx_el_type) = ( - map_ptx_to_jl_array[ptx_el_type], - map_ptx_to_jl_frag[ptx_el_type], - map_frag_sizes["$matrix.$ptx_el_type"], - map_num_elements["$matrix.$ptx_el_type"] - ) - function get_hl_as_info(AS) try return map_as_ty_to_str[AS] @@ -514,40 +482,6 @@ wmma_mma end end -#= for a_layout in ["col", "row"], =# -#= b_layout in ["col", "row"], =# -#= shape in ["m16n16k16"], =# -#= d_elem_type in ["f16", "f32"], =# -#= c_elem_type in ["f16", "f32"], =# -#= b_elem_type in ["f16"], =# -#= a_elem_type in ["f16"] =# - -#= # Name of the Julia wrapper =# -#= wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) =# - -#= # Get types =# -#= a_arr_ty, a_frag_ty, a_sz_unfl, a_sz = get_hl_frag_info_old("a", a_elem_type) =# -#= a_layout_ty = (a_layout == "col") ? wmma_col_major : wmma_row_major =# - -#= b_arr_ty, b_frag_ty, b_sz_unfl, b_sz = get_hl_frag_info_old("b", b_elem_type) =# -#= b_layout_ty = (b_layout == "col") ? wmma_col_major : wmma_row_major =# - -#= c_arr_ty, c_frag_ty, c_sz_unfl, c_sz = get_hl_frag_info_old("c", c_elem_type) =# - -#= d_arr_ty, _, _, d_sz = get_hl_frag_info_old("d", d_elem_type) =# - -#= @eval function wmma_mma(a::wmma_fragment{16, 16, 16, $a_sz, $a_arr_ty, $a_layout_ty, wmma_matrix_a}, =# -#= b::wmma_fragment{16, 16, 16, $b_sz, $b_arr_ty, $b_layout_ty, wmma_matrix_b}, =# -#= c::wmma_fragment{16, 16, 16, $c_sz, $c_arr_ty, wmma_unspecified, wmma_accumulator}, =# -#= conf::Type{wmma_config{16, 16, 16, $d_arr_ty}}) =# -#= a_unfl = unflatten(NTuple{$a_sz_unfl, $a_frag_ty}, a.x) =# -#= b_unfl = unflatten(NTuple{$b_sz_unfl, $b_frag_ty}, b.x) =# -#= c_unfl = unflatten(NTuple{$c_sz_unfl, $c_frag_ty}, c.x) =# -#= x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) =# -#= return wmma_fragment{16, 16, 16, $d_sz, $d_arr_ty, wmma_unspecified, wmma_accumulator}(x) =# -#= end =# -#= end =# - # ---------- # WMMA store @@ -631,20 +565,3 @@ wmma_fill_c return wmma_fragment{$M, $N, $K, $num_els, $T, wmma_unspecified, wmma_accumulator}($expr) end end - -#= for mat in ["c"], =# -#= elem_type in ["f16", "f32"] =# - -#= # Name of the Julia function =# -#= func_name = Symbol("wmma_fill_$mat") =# - -#= # Get fragment types and size =# -#= arr_ty, _, _, sz = get_hl_frag_info_old(mat, elem_type) =# - -#= @eval function $func_name(value::$arr_ty, =# -#= config::Type{wmma_config{M, N, K, d_type}}) where {M, N, K, d_type} =# - -#= x = ntuple(i -> value, $sz) =# -#= return wmma_fragment{16, 16, 16, $sz, $arr_ty, wmma_unspecified, wmma_accumulator}(x) =# -#= end =# -#= end =# From 9a3036c6c8804f12277fd9a89b01a0cc719bb672 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 1 Dec 2019 16:13:04 +0100 Subject: [PATCH 51/81] Reenable tests --- test/runtests.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 1d70f530..18f933b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,9 +55,9 @@ if length(devices()) > 0 end cap = CUDAnative.current_capability() -#= include("base.jl") =# -#= include("pointer.jl") =# -#= include("codegen.jl") =# +include("base.jl") +include("pointer.jl") +include("codegen.jl") if dev === nothing @warn("No CUDA-capable devices available; skipping on-device tests.") @@ -65,14 +65,14 @@ else if capability(dev) < v"2.0" @warn("native execution not supported on SM < 2.0") else - #= include("device/codegen.jl") =# - #= include("device/execution.jl") =# - #= include("device/pointer.jl") =# - #= include("device/array.jl") =# - #= include("device/cuda.jl") =# + include("device/codegen.jl") + include("device/execution.jl") + include("device/pointer.jl") + include("device/array.jl") + include("device/cuda.jl") include("device/wmma.jl") - #= include("examples.jl") =# + include("examples.jl") end end From e4fe14355b0c019ede17438cbefce0003868c1ea Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 3 Dec 2019 00:24:57 +0100 Subject: [PATCH 52/81] Implement broadcasting --- src/device/cuda/wmma.jl | 49 +++++++++++++++++++++++++++++++++++++++++ test/device/wmma.jl | 5 ++--- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 780ce8b4..763e7840 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -565,3 +565,52 @@ wmma_fill_c return wmma_fragment{$M, $N, $K, $num_els, $T, wmma_unspecified, wmma_accumulator}($expr) end end + +################################################################################ +# BROADCASTING OVER WMMA_FRAGMENTS +################################################################################ + +# Based on broadcasting implementation of Tuples in +# https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl + + +# Custom broadcast style for wmma_fragments +struct wmma_fragment_broadcast_style <: Broadcast.BroadcastStyle end + +# Use this broadcasting style for wmma_fragments +Base.BroadcastStyle(::Type{<:wmma_fragment}) = wmma_fragment_broadcast_style() + +# Broadcast style precedence rules +# If we broadcast a fragment with a scalar, we want the wmma_fragment style to take precedence +Base.BroadcastStyle(s::wmma_fragment_broadcast_style, t::Broadcast.DefaultArrayStyle{0}) = s + +# We don't want to convert fragments before broadcasting +Base.broadcastable(frag::wmma_fragment) = frag + +# Needed for broadcast machinery +Base.axes(frag::wmma_fragment) = axes(frag.x) + +# Helper functions to get element at specified index +@inline get_index(x, i) = x # scalar +@inline get_index(frag::wmma_fragment, i) = frag.x[i] # wmma_fragment + +# Helper functions to get first fragment in broadcast call +@inline find_first_fragment(args::Tuple) = find_first_fragment(args[1], Base.tail(args)) +@inline find_first_fragment(a::wmma_fragment, tail) = a +@inline find_first_fragment(::Any, tail) = find_first_fragment(tail) + +# Custom broadcast implementation that returns a wmma_fragment +@inline function Base.copy(bc::Broadcast.Broadcasted{wmma_fragment_broadcast_style}) + dim = Broadcast.combine_axes(bc.args...) + + if length(dim) != 1 + throw(DimensionMismatch("WMMA fragment broadcast only supports one dimension!")) + end + + N = length(dim[1]) + + tuple = ntuple(i -> bc.f(map(arg -> get_index(arg, i), bc.args)...), Val(N)) + + frag_ty = typeof(find_first_fragment(bc.args)) + return frag_ty(tuple) +end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 0c702f04..e4dd095e 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -184,9 +184,8 @@ if VERSION >= v"1.4.0-DEV.534" c_frag = wmma_fill_c($c_type(0), conf) end - # TODO: Make this less awkward, see https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting-1 - a_frag = typeof(a_frag)(alpha .* a_frag.x) - c_frag = typeof(c_frag)(beta .* c_frag.x) + a_frag = alpha .* a_frag + c_frag = beta .* c_frag d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) From 556bbdfaa1081f2b932ad840673d2cc527e53098 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 3 Dec 2019 20:19:29 +0100 Subject: [PATCH 53/81] Use correct type for alpha and beta --- test/device/wmma.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index e4dd095e..5bd0dc28 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -169,8 +169,8 @@ if VERSION >= v"1.4.0-DEV.534" c_dev = CuArray(c) d_dev = CuArray(d) - alpha = rand() - beta = rand() + alpha = rand(Float16) + beta = rand(c_type) @eval function kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) conf = wmma_config{16, 16, 16, $d_type} From 40ff035c7b371cfdcb53e7dc7666bd5448759342 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 4 Dec 2019 11:27:03 +0100 Subject: [PATCH 54/81] Add tests for flattening and unflattening --- test/device/wmma.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 5bd0dc28..7824b46e 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -147,6 +147,30 @@ if VERSION >= v"1.4.0-DEV.534" end end +################################################################################ + + @testset "Flattening/unflattening" begin + @testset "Flattening" begin + @test flatten(5) == (5,) + @test flatten(5.0) == (5.0,) + @test flatten(VecElement{Float16}(5)) == (Float16(5),) + @test flatten(ntuple(i -> i, 8)) == ntuple(i -> i, 8) + @test flatten(ntuple(i -> VecElement{Float16}(i), 8)) == ntuple(i -> Float16(i), 8) + @test flatten(ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8)) == ntuple(i -> i, 2 * 8) + @test flatten(ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)) == ntuple(i -> Float16(i), 2 * 8) + end + + @testset "Unflattening" begin + @test unflatten(Int64, (5,)) == 5 + @test unflatten(Float64, (5.0,)) == 5.0 + @test unflatten(VecElement{Float16}, (Float16(5),)) == VecElement{Float16}(5) + @test unflatten(NTuple{8, Int64}, ntuple(i -> i, 8)) == ntuple(i -> i, 8) + @test unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) + @test unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8) + @test unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8) + end + end + ################################################################################ @testset "CUDA C-style API" begin From 39118711832005e72b32e79cc848b0154a4fa080 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 4 Dec 2019 11:38:52 +0100 Subject: [PATCH 55/81] Add tests for broadcasting --- test/device/wmma.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 7824b46e..755e600b 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -171,6 +171,14 @@ if VERSION >= v"1.4.0-DEV.534" end end +################################################################################ + + @testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5], + ty = [Float16, Float32] + @test ty(5) .* wmma_fragment{16, 16, 16, sz, ty, wmma_row_major, wmma_matrix_a}(ntuple(i -> ty(i), sz)) == wmma_fragment{16, 16, 16, sz, ty, wmma_row_major, wmma_matrix_a}(ntuple(i -> ty(5 * i), sz)) + @test ty(5) .+ wmma_fragment{16, 16, 16, sz, ty, wmma_row_major, wmma_matrix_a}(ntuple(i -> ty(i), sz)) == wmma_fragment{16, 16, 16, sz, ty, wmma_row_major, wmma_matrix_a}(ntuple(i -> ty(5 + i), sz)) + end + ################################################################################ @testset "CUDA C-style API" begin From 70f35df8388afadf63ef1fcf7d9f10f0c1840f34 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Wed, 4 Dec 2019 20:30:25 +0100 Subject: [PATCH 56/81] Add CUDAnative prefix --- test/device/wmma.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 755e600b..f474378d 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -151,23 +151,23 @@ if VERSION >= v"1.4.0-DEV.534" @testset "Flattening/unflattening" begin @testset "Flattening" begin - @test flatten(5) == (5,) - @test flatten(5.0) == (5.0,) - @test flatten(VecElement{Float16}(5)) == (Float16(5),) - @test flatten(ntuple(i -> i, 8)) == ntuple(i -> i, 8) - @test flatten(ntuple(i -> VecElement{Float16}(i), 8)) == ntuple(i -> Float16(i), 8) - @test flatten(ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8)) == ntuple(i -> i, 2 * 8) - @test flatten(ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)) == ntuple(i -> Float16(i), 2 * 8) + @test CUDAnative.flatten(5) == (5,) + @test CUDAnative.flatten(5.0) == (5.0,) + @test CUDAnative.flatten(VecElement{Float16}(5)) == (Float16(5),) + @test CUDAnative.flatten(ntuple(i -> i, 8)) == ntuple(i -> i, 8) + @test CUDAnative.flatten(ntuple(i -> VecElement{Float16}(i), 8)) == ntuple(i -> Float16(i), 8) + @test CUDAnative.flatten(ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8)) == ntuple(i -> i, 2 * 8) + @test CUDAnative.flatten(ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)) == ntuple(i -> Float16(i), 2 * 8) end @testset "Unflattening" begin - @test unflatten(Int64, (5,)) == 5 - @test unflatten(Float64, (5.0,)) == 5.0 - @test unflatten(VecElement{Float16}, (Float16(5),)) == VecElement{Float16}(5) - @test unflatten(NTuple{8, Int64}, ntuple(i -> i, 8)) == ntuple(i -> i, 8) - @test unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) - @test unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8) - @test unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8) + @test CUDAnative.unflatten(Int64, (5,)) == 5 + @test CUDAnative.unflatten(Float64, (5.0,)) == 5.0 + @test CUDAnative.unflatten(VecElement{Float16}, (Float16(5),)) == VecElement{Float16}(5) + @test CUDAnative.unflatten(NTuple{8, Int64}, ntuple(i -> i, 8)) == ntuple(i -> i, 8) + @test CUDAnative.unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) + @test CUDAnative.unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8) + @test CUDAnative.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8) end end From c99d9f66c0cc89e297be956285071d6f4129c1cb Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 6 Dec 2019 12:03:51 +0100 Subject: [PATCH 57/81] Adhere to Julia naming convention for types --- docs/src/lib/device/wmma.md | 26 +++---- src/device/cuda/wmma.jl | 138 ++++++++++++++++++------------------ test/device/wmma.jl | 22 +++--- 3 files changed, 93 insertions(+), 93 deletions(-) diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md index 9b4c2be6..31c56e43 100644 --- a/docs/src/lib/device/wmma.md +++ b/docs/src/lib/device/wmma.md @@ -161,24 +161,24 @@ Note that, in CUDA C++, the fragment is responsible for both the storage of inte All CUDA C++ WMMA calls are function templates that take the resultant fragment as a by-reference argument. As a result, the type of this argument can be used during overload resolution to select the correct WMMA instruction to call. -In contrast, the API in Julia separates the WMMA storage ([`wmma_fragment`](@ref)) and configuration ([`wmma_config`](@ref)). +In contrast, the API in Julia separates the WMMA storage ([`WmmaFragment`](@ref)) and configuration ([`WmmaConfig`](@ref)). Instead of taking the resultant fragment by reference, the Julia functions just return it. This makes the dataflow clearer, but it also means that the type of that fragment cannot be used for selection of the correct WMMA instruction. Thus, there is still a limited amount of information that cannot be inferred from the argument types, but must nonetheless match for all WMMA operations, such as the overall shape of the MMA. -This is accomplished by a separate "WMMA configuration" (see [`wmma_config`](@ref)) that you create once, and then give as an argument to all intrinsics. +This is accomplished by a separate "WMMA configuration" (see [`WmmaConfig`](@ref)) that you create once, and then give as an argument to all intrinsics. ### Fragment ```@docs -CUDAnative.wmma_fragment_layout -CUDAnative.wmma_row_major -CUDAnative.wmma_col_major -CUDAnative.wmma_unspecified -CUDAnative.wmma_fragment +CUDAnative.WmmaFragmentLayout +CUDAnative.WmmaRowMajor +CUDAnative.WmmaColMajor +CUDAnative.WmmaUnspecified +CUDAnative.WmmaFragment ``` ### WMMA configuration ```@docs -CUDAnative.wmma_config +CUDAnative.WmmaConfig ``` ### Load matrix @@ -220,15 +220,15 @@ c_dev = CuArray(c) d_dev = similar(c_dev) function kernel(a_dev, b_dev, c_dev, d_dev) - conf = wmma_config{16, 16, 16, Float32} + conf = WmmaConfig{16, 16, 16, Float32} - a_frag = wmma_load_a(pointer(a_dev), 16, wmma_col_major, conf) - b_frag = wmma_load_b(pointer(b_dev), 16, wmma_col_major, conf) - c_frag = wmma_load_c(pointer(c_dev), 16, wmma_col_major, conf) + a_frag = wmma_load_a(pointer(a_dev), 16, WmmaColMajor, conf) + b_frag = wmma_load_b(pointer(b_dev), 16, WmmaColMajor, conf) + c_frag = wmma_load_c(pointer(c_dev), 16, WmmaColMajor, conf) d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) - wmma_store_d(pointer(d_dev), d_frag, 16, wmma_col_major, conf) + wmma_store_d(pointer(d_dev), d_frag, 16, WmmaColMajor, conf) return end diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 763e7840..8e14d6c1 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -205,33 +205,33 @@ end # WMMA fragment # ------------- -export wmma_fragment_layout, wmma_row_major, wmma_col_major, wmma_unspecified +export WmmaFragmentLayout, WmmaRowMajor, WmmaColMajor, WmmaUnspecified """ - wmma_fragment_layout + WmmaFragmentLayout Abstract type that specifies the storage layout of a matrix. -Possible values are [`wmma_row_major`](@ref), [`wmma_col_major`](@ref) and [`wmma_unspecified`](@ref). +Possible values are [`WmmaRowMajor`](@ref), [`WmmaColMajor`](@ref) and [`WmmaUnspecified`](@ref). """ -abstract type wmma_fragment_layout end +abstract type WmmaFragmentLayout end """ - wmma_row_major + WmmaRowMajor Type that represents a matrix stored in row major (C style) order. """ -struct wmma_row_major <: wmma_fragment_layout end +struct WmmaRowMajor <: WmmaFragmentLayout end """ - wmma_col_major + WmmaColMajor Type that represents a matrix stored in column major (Julia style) order. """ -struct wmma_col_major <: wmma_fragment_layout end +struct WmmaColMajor <: WmmaFragmentLayout end """ - wmma_unspecified + WmmaUnspecified Type that represents a matrix stored in an unspecified order. @@ -239,27 +239,27 @@ Type that represents a matrix stored in an unspecified order. This storage format is not valid for all WMMA operations! """ -struct wmma_unspecified <: wmma_fragment_layout end +struct WmmaUnspecified <: WmmaFragmentLayout end -export wmma_matrix_a, wmma_matrix_b, wmma_accumulator +export WmmaMatrixA, WmmaMatrixB, WmmaAccumulator -abstract type wmma_fragment_use end -struct wmma_matrix_a <: wmma_fragment_use end -struct wmma_matrix_b <: wmma_fragment_use end -struct wmma_accumulator <: wmma_fragment_use end +abstract type WmmaFragmentUse end +struct WmmaMatrixA <: WmmaFragmentUse end +struct WmmaMatrixB <: WmmaFragmentUse end +struct WmmaAccumulator <: WmmaFragmentUse end -export wmma_fragment +export WmmaFragment """ - wmma_fragment + WmmaFragment Type that represents per-thread intermediate results of WMMA operations. You can access individual elements using the `x` member, but beware that the exact ordering of elements is unspecified. """ -struct wmma_fragment{M, N, K, FS, T, L <: wmma_fragment_layout, U <: wmma_fragment_use} +struct WmmaFragment{M, N, K, FS, T, L <: WmmaFragmentLayout, U <: WmmaFragmentUse} x::NTuple{FS, T} end @@ -267,10 +267,10 @@ end # WMMA configuration # ------------------ -export wmma_config +export WmmaConfig """ - wmma_config{M, N, K, d_type} + WmmaConfig{M, N, K, d_type} Type that contains all information for WMMA operations that cannot be inferred from the argument's types. @@ -279,15 +279,15 @@ WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\c `d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`. -All WMMA operations take a `wmma_config` as their final argument. +All WMMA operations take a `WmmaConfig` as their final argument. # Examples ```jldoctest -julia> config = wmma_config{16, 16, 16, Float32} -wmma_config{16,16,16,Float32} +julia> config = WmmaConfig{16, 16, 16, Float32} +WmmaConfig{16,16,16,Float32} ``` """ -struct wmma_config{M, N, K, d_type} end +struct WmmaConfig{M, N, K, d_type} end # --------- # Constants @@ -305,8 +305,8 @@ map_as_ty_to_str = Dict( # Maps layout types to string map_layout_ty_to_str = Dict( - wmma_row_major => "row", - wmma_col_major => "col" + WmmaRowMajor => "row", + WmmaColMajor => "col" ) # Maps matrix & type to number of elements (size after flattening) @@ -321,10 +321,10 @@ map_num_elems = Dict( # Maps matrix to its use map_matrix_to_use = Dict( - "a" => wmma_matrix_a, - "b" => wmma_matrix_b, - "c" => wmma_accumulator, - "d" => wmma_accumulator + "a" => WmmaMatrixA, + "b" => WmmaMatrixB, + "c" => WmmaAccumulator, + "d" => WmmaAccumulator ) # ---------------- @@ -387,15 +387,15 @@ export wmma_load_a, wmma_load_b, wmma_load_c wmma_load_b(addr, stride, layout, config) wmma_load_c(addr, stride, layout, config) -Load the matrix `a`, `b` or `c` from the memory location indicated by `addr`, and return the resulting [`wmma_fragment`](@ref). +Load the matrix `a`, `b` or `c` from the memory location indicated by `addr`, and return the resulting [`WmmaFragment`](@ref). # Arguments - `addr`: The address to load the matrix from. - `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. -- `layout`: The storage layout of the matrix. Possible values are [`wmma_row_major`](@ref) and [`wmma_col_major`](@ref). -- `config`: The WMMA configuration that should be used for loading this matrix. See [`wmma_config`](@ref). +- `layout`: The storage layout of the matrix. Possible values are [`WmmaRowMajor`](@ref) and [`WmmaColMajor`](@ref). +- `config`: The WMMA configuration that should be used for loading this matrix. See [`WmmaConfig`](@ref). -See also: [`wmma_fragment`](@ref), [`wmma_fragment_layout`](@ref), [`wmma_config`](@ref) +See also: [`WmmaFragment`](@ref), [`WmmaFragmentLayout`](@ref), [`WmmaConfig`](@ref) !!! warning @@ -410,21 +410,21 @@ for mat in ["a", "b", "c"] @eval @generated function $func_name(addr::DevicePtr{T, AS}, stride::Number, layout::Type{L}, - config::Type{wmma_config{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} + config::Type{WmmaConfig{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) shape = get_hl_shape(M, N, K) num_els, _, _, arr_str = get_hl_frag_info($mat, T) U = get_hl_mat_use($mat) - L_ret = ($mat == "c") ? wmma_unspecified : L + L_ret = ($mat == "c") ? WmmaUnspecified : L # Name of the Julia wrapper wrapper = Symbol(join_nonempty("llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str, "_")) return quote x = flatten($wrapper(addr, stride)) - return wmma_fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) + return WmmaFragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) end end end @@ -443,10 +443,10 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. # Arguments -- `a`: The [`wmma_fragment`](@ref) corresponding to the matrix ``A``. -- `b`: The [`wmma_fragment`](@ref) corresponding to the matrix ``B``. -- `c`: The [`wmma_fragment`](@ref) corresponding to the matrix ``C``. -- `conf`: The [`wmma_config`](@ref) that should be used in this WMMA operation. +- `a`: The [`WmmaFragment`](@ref) corresponding to the matrix ``A``. +- `b`: The [`WmmaFragment`](@ref) corresponding to the matrix ``B``. +- `c`: The [`WmmaFragment`](@ref) corresponding to the matrix ``C``. +- `conf`: The [`WmmaConfig`](@ref) that should be used in this WMMA operation. !!! warning @@ -455,10 +455,10 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. """ wmma_mma -@generated function wmma_mma(a::wmma_fragment{M, N, K, A_SZ, A_T, A_L, wmma_matrix_a}, - b::wmma_fragment{M, N, K, B_SZ, B_T, B_L, wmma_matrix_b}, - c::wmma_fragment{M, N, K, C_SZ, C_T, wmma_unspecified, wmma_accumulator}, - config::Type{wmma_config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} +@generated function wmma_mma(a::WmmaFragment{M, N, K, A_SZ, A_T, A_L, WmmaMatrixA}, + b::WmmaFragment{M, N, K, B_SZ, B_T, B_L, WmmaMatrixB}, + c::WmmaFragment{M, N, K, C_SZ, C_T, WmmaUnspecified, WmmaAccumulator}, + config::Type{WmmaConfig{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T) _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T) @@ -478,7 +478,7 @@ wmma_mma c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x) x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) - return wmma_fragment{$M, $N, $K, $d_num_els, $D_T, wmma_unspecified, wmma_accumulator}(x) + return WmmaFragment{$M, $N, $K, $d_num_els, $D_T, WmmaUnspecified, WmmaAccumulator}(x) end end @@ -496,12 +496,12 @@ Store the result matrix `d` to the memory location indicated by `addr`. # Arguments - `addr`: The address to store the matrix to. -- `d`: The [`wmma_fragment`](@ref) corresponding to the `d` matrix. +- `d`: The [`WmmaFragment`](@ref) corresponding to the `d` matrix. - `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. -- `layout`: The storage layout of the matrix. Possible values are [`wmma_row_major`](@ref) and [`wmma_col_major`](@ref). -- `config`: The WMMA configuration that should be used for storing this matrix. See [`wmma_config`](@ref). +- `layout`: The storage layout of the matrix. Possible values are [`WmmaRowMajor`](@ref) and [`WmmaColMajor`](@ref). +- `config`: The WMMA configuration that should be used for storing this matrix. See [`WmmaConfig`](@ref). -See also: [`wmma_fragment`](@ref), [`wmma_fragment_layout`](@ref), [`wmma_config`](@ref) +See also: [`WmmaFragment`](@ref), [`WmmaFragmentLayout`](@ref), [`WmmaConfig`](@ref) !!! warning @@ -511,10 +511,10 @@ See also: [`wmma_fragment`](@ref), [`wmma_fragment_layout`](@ref), [`wmma_config wmma_store_d @generated function wmma_store_d(addr::DevicePtr{T, AS}, - d::wmma_fragment{M, N, K, D_SZ, T, wmma_unspecified, wmma_accumulator}, + d::WmmaFragment{M, N, K, D_SZ, T, WmmaUnspecified, WmmaAccumulator}, stride::Number, layout::Type{L}, - config::Type{wmma_config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} + config::Type{WmmaConfig{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) @@ -541,18 +541,18 @@ export wmma_fill_c """ wmma_fill_c(value, config) -Return a [`wmma_fragment`](@ref) filled with the value `value`. +Return a [`WmmaFragment`](@ref) filled with the value `value`. This operation is useful if you want to implement a matrix multiplication (and thus want to set ``C = O``). # Arguments - `value`: The value used to fill the fragment. Can be a `Float16` or `Float32`. -- `config`: The WMMA configuration that should be used for this WMMA operation. See [`wmma_config`](@ref). +- `config`: The WMMA configuration that should be used for this WMMA operation. See [`WmmaConfig`](@ref). """ wmma_fill_c @generated function wmma_fill_c(value::T, - config::Type{wmma_config{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} + config::Type{WmmaConfig{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} # We can't use closures in @generated functions, so we'll have to do it this way instead of # ntuple(i -> val, $num_els) @@ -562,45 +562,45 @@ wmma_fill_c expr = :(tuple($(args...))) return quote - return wmma_fragment{$M, $N, $K, $num_els, $T, wmma_unspecified, wmma_accumulator}($expr) + return WmmaFragment{$M, $N, $K, $num_els, $T, WmmaUnspecified, WmmaAccumulator}($expr) end end ################################################################################ -# BROADCASTING OVER WMMA_FRAGMENTS +# BROADCASTING OVER WMMA FRAGMENTS ################################################################################ # Based on broadcasting implementation of Tuples in # https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl -# Custom broadcast style for wmma_fragments -struct wmma_fragment_broadcast_style <: Broadcast.BroadcastStyle end +# Custom broadcast style for WmmaFragments +struct WmmaFragmentBroadcastStyle <: Broadcast.BroadcastStyle end -# Use this broadcasting style for wmma_fragments -Base.BroadcastStyle(::Type{<:wmma_fragment}) = wmma_fragment_broadcast_style() +# Use this broadcasting style for WmmaFragments +Base.BroadcastStyle(::Type{<:WmmaFragment}) = WmmaFragmentBroadcastStyle() # Broadcast style precedence rules -# If we broadcast a fragment with a scalar, we want the wmma_fragment style to take precedence -Base.BroadcastStyle(s::wmma_fragment_broadcast_style, t::Broadcast.DefaultArrayStyle{0}) = s +# If we broadcast a fragment with a scalar, we want the WmmaFragment style to take precedence +Base.BroadcastStyle(s::WmmaFragmentBroadcastStyle, t::Broadcast.DefaultArrayStyle{0}) = s # We don't want to convert fragments before broadcasting -Base.broadcastable(frag::wmma_fragment) = frag +Base.broadcastable(frag::WmmaFragment) = frag # Needed for broadcast machinery -Base.axes(frag::wmma_fragment) = axes(frag.x) +Base.axes(frag::WmmaFragment) = axes(frag.x) # Helper functions to get element at specified index @inline get_index(x, i) = x # scalar -@inline get_index(frag::wmma_fragment, i) = frag.x[i] # wmma_fragment +@inline get_index(frag::WmmaFragment, i) = frag.x[i] # WmmaFragment # Helper functions to get first fragment in broadcast call @inline find_first_fragment(args::Tuple) = find_first_fragment(args[1], Base.tail(args)) -@inline find_first_fragment(a::wmma_fragment, tail) = a +@inline find_first_fragment(a::WmmaFragment, tail) = a @inline find_first_fragment(::Any, tail) = find_first_fragment(tail) -# Custom broadcast implementation that returns a wmma_fragment -@inline function Base.copy(bc::Broadcast.Broadcasted{wmma_fragment_broadcast_style}) +# Custom broadcast implementation that returns a WmmaFragment +@inline function Base.copy(bc::Broadcast.Broadcasted{WmmaFragmentBroadcastStyle}) dim = Broadcast.combine_axes(bc.args...) if length(dim) != 1 diff --git a/test/device/wmma.jl b/test/device/wmma.jl index f474378d..306eee60 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -175,18 +175,18 @@ if VERSION >= v"1.4.0-DEV.534" @testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5], ty = [Float16, Float32] - @test ty(5) .* wmma_fragment{16, 16, 16, sz, ty, wmma_row_major, wmma_matrix_a}(ntuple(i -> ty(i), sz)) == wmma_fragment{16, 16, 16, sz, ty, wmma_row_major, wmma_matrix_a}(ntuple(i -> ty(5 * i), sz)) - @test ty(5) .+ wmma_fragment{16, 16, 16, sz, ty, wmma_row_major, wmma_matrix_a}(ntuple(i -> ty(i), sz)) == wmma_fragment{16, 16, 16, sz, ty, wmma_row_major, wmma_matrix_a}(ntuple(i -> ty(5 + i), sz)) + @test ty(5) .* WmmaFragment{16, 16, 16, sz, ty, WmmaRowMajor, WmmaMatrixA}(ntuple(i -> ty(i), sz)) == WmmaFragment{16, 16, 16, sz, ty, WmmaRowMajor, WmmaMatrixA}(ntuple(i -> ty(5 * i), sz)) + @test ty(5) .+ WmmaFragment{16, 16, 16, sz, ty, WmmaRowMajor, WmmaMatrixA}(ntuple(i -> ty(i), sz)) == WmmaFragment{16, 16, 16, sz, ty, WmmaRowMajor, WmmaMatrixA}(ntuple(i -> ty(5 + i), sz)) end ################################################################################ @testset "CUDA C-style API" begin - @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [wmma_col_major, wmma_row_major], - b_layout in [wmma_col_major, wmma_row_major], - c_layout in [wmma_col_major, wmma_row_major], - d_layout in [wmma_col_major, wmma_row_major], + @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [WmmaColMajor, WmmaRowMajor], + b_layout in [WmmaColMajor, WmmaRowMajor], + c_layout in [WmmaColMajor, WmmaRowMajor], + d_layout in [WmmaColMajor, WmmaRowMajor], c_type in [Float16, Float32], d_type in [Float16, Float32], do_mac in [true, false] @@ -205,7 +205,7 @@ if VERSION >= v"1.4.0-DEV.534" beta = rand(c_type) @eval function kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) - conf = wmma_config{16, 16, 16, $d_type} + conf = WmmaConfig{16, 16, 16, $d_type} a_frag = wmma_load_a(pointer(a_dev), 16, $a_layout, conf) b_frag = wmma_load_b(pointer(b_dev), 16, $b_layout, conf) @@ -229,10 +229,10 @@ if VERSION >= v"1.4.0-DEV.534" @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) d = Array(d_dev) - new_a = (a_layout == wmma_col_major) ? a : transpose(a) - new_b = (b_layout == wmma_col_major) ? b : transpose(b) - new_c = (c_layout == wmma_col_major) ? c : transpose(c) - new_d = (d_layout == wmma_col_major) ? d : transpose(d) + new_a = (a_layout == WmmaColMajor) ? a : transpose(a) + new_b = (b_layout == WmmaColMajor) ? b : transpose(b) + new_c = (c_layout == WmmaColMajor) ? c : transpose(c) + new_d = (d_layout == WmmaColMajor) ? d : transpose(d) if do_mac @test all(isapprox.(alpha * new_a * new_b + beta * new_c, new_d; rtol=sqrt(eps(Float16)))) From 765a7c2da367771fa5c89f2d8a2e2c617d66c878 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 6 Dec 2019 15:01:02 +0100 Subject: [PATCH 58/81] Capitalise WMMA --- docs/src/lib/device/wmma.md | 26 +++---- src/device/cuda/wmma.jl | 136 ++++++++++++++++++------------------ test/device/wmma.jl | 22 +++--- 3 files changed, 92 insertions(+), 92 deletions(-) diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md index 31c56e43..56edee45 100644 --- a/docs/src/lib/device/wmma.md +++ b/docs/src/lib/device/wmma.md @@ -161,24 +161,24 @@ Note that, in CUDA C++, the fragment is responsible for both the storage of inte All CUDA C++ WMMA calls are function templates that take the resultant fragment as a by-reference argument. As a result, the type of this argument can be used during overload resolution to select the correct WMMA instruction to call. -In contrast, the API in Julia separates the WMMA storage ([`WmmaFragment`](@ref)) and configuration ([`WmmaConfig`](@ref)). +In contrast, the API in Julia separates the WMMA storage ([`WMMAFragment`](@ref)) and configuration ([`WMMAConfig`](@ref)). Instead of taking the resultant fragment by reference, the Julia functions just return it. This makes the dataflow clearer, but it also means that the type of that fragment cannot be used for selection of the correct WMMA instruction. Thus, there is still a limited amount of information that cannot be inferred from the argument types, but must nonetheless match for all WMMA operations, such as the overall shape of the MMA. -This is accomplished by a separate "WMMA configuration" (see [`WmmaConfig`](@ref)) that you create once, and then give as an argument to all intrinsics. +This is accomplished by a separate "WMMA configuration" (see [`WMMAConfig`](@ref)) that you create once, and then give as an argument to all intrinsics. ### Fragment ```@docs -CUDAnative.WmmaFragmentLayout -CUDAnative.WmmaRowMajor -CUDAnative.WmmaColMajor -CUDAnative.WmmaUnspecified -CUDAnative.WmmaFragment +CUDAnative.WMMAFragmentLayout +CUDAnative.WMMARowMajor +CUDAnative.WMMAColMajor +CUDAnative.WMMAUnspecified +CUDAnative.WMMAFragment ``` ### WMMA configuration ```@docs -CUDAnative.WmmaConfig +CUDAnative.WMMAConfig ``` ### Load matrix @@ -220,15 +220,15 @@ c_dev = CuArray(c) d_dev = similar(c_dev) function kernel(a_dev, b_dev, c_dev, d_dev) - conf = WmmaConfig{16, 16, 16, Float32} + conf = WMMAConfig{16, 16, 16, Float32} - a_frag = wmma_load_a(pointer(a_dev), 16, WmmaColMajor, conf) - b_frag = wmma_load_b(pointer(b_dev), 16, WmmaColMajor, conf) - c_frag = wmma_load_c(pointer(c_dev), 16, WmmaColMajor, conf) + a_frag = wmma_load_a(pointer(a_dev), 16, WMMAColMajor, conf) + b_frag = wmma_load_b(pointer(b_dev), 16, WMMAColMajor, conf) + c_frag = wmma_load_c(pointer(c_dev), 16, WMMAColMajor, conf) d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) - wmma_store_d(pointer(d_dev), d_frag, 16, WmmaColMajor, conf) + wmma_store_d(pointer(d_dev), d_frag, 16, WMMAColMajor, conf) return end diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 8e14d6c1..a387327b 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -205,33 +205,33 @@ end # WMMA fragment # ------------- -export WmmaFragmentLayout, WmmaRowMajor, WmmaColMajor, WmmaUnspecified +export WMMAFragmentLayout, WMMARowMajor, WMMAColMajor, WMMAUnspecified """ - WmmaFragmentLayout + WMMAFragmentLayout Abstract type that specifies the storage layout of a matrix. -Possible values are [`WmmaRowMajor`](@ref), [`WmmaColMajor`](@ref) and [`WmmaUnspecified`](@ref). +Possible values are [`WMMARowMajor`](@ref), [`WMMAColMajor`](@ref) and [`WMMAUnspecified`](@ref). """ -abstract type WmmaFragmentLayout end +abstract type WMMAFragmentLayout end """ - WmmaRowMajor + WMMARowMajor Type that represents a matrix stored in row major (C style) order. """ -struct WmmaRowMajor <: WmmaFragmentLayout end +struct WMMARowMajor <: WMMAFragmentLayout end """ - WmmaColMajor + WMMAColMajor Type that represents a matrix stored in column major (Julia style) order. """ -struct WmmaColMajor <: WmmaFragmentLayout end +struct WMMAColMajor <: WMMAFragmentLayout end """ - WmmaUnspecified + WMMAUnspecified Type that represents a matrix stored in an unspecified order. @@ -239,27 +239,27 @@ Type that represents a matrix stored in an unspecified order. This storage format is not valid for all WMMA operations! """ -struct WmmaUnspecified <: WmmaFragmentLayout end +struct WMMAUnspecified <: WMMAFragmentLayout end -export WmmaMatrixA, WmmaMatrixB, WmmaAccumulator +export WMMAMatrixA, WMMAMatrixB, WMMAAccumulator -abstract type WmmaFragmentUse end -struct WmmaMatrixA <: WmmaFragmentUse end -struct WmmaMatrixB <: WmmaFragmentUse end -struct WmmaAccumulator <: WmmaFragmentUse end +abstract type WMMAFragmentUse end +struct WMMAMatrixA <: WMMAFragmentUse end +struct WMMAMatrixB <: WMMAFragmentUse end +struct WMMAAccumulator <: WMMAFragmentUse end -export WmmaFragment +export WMMAFragment """ - WmmaFragment + WMMAFragment Type that represents per-thread intermediate results of WMMA operations. You can access individual elements using the `x` member, but beware that the exact ordering of elements is unspecified. """ -struct WmmaFragment{M, N, K, FS, T, L <: WmmaFragmentLayout, U <: WmmaFragmentUse} +struct WMMAFragment{M, N, K, FS, T, L <: WMMAFragmentLayout, U <: WMMAFragmentUse} x::NTuple{FS, T} end @@ -267,10 +267,10 @@ end # WMMA configuration # ------------------ -export WmmaConfig +export WMMAConfig """ - WmmaConfig{M, N, K, d_type} + WMMAConfig{M, N, K, d_type} Type that contains all information for WMMA operations that cannot be inferred from the argument's types. @@ -279,15 +279,15 @@ WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\c `d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`. -All WMMA operations take a `WmmaConfig` as their final argument. +All WMMA operations take a `WMMAConfig` as their final argument. # Examples ```jldoctest -julia> config = WmmaConfig{16, 16, 16, Float32} -WmmaConfig{16,16,16,Float32} +julia> config = WMMAConfig{16, 16, 16, Float32} +WMMAConfig{16,16,16,Float32} ``` """ -struct WmmaConfig{M, N, K, d_type} end +struct WMMAConfig{M, N, K, d_type} end # --------- # Constants @@ -305,8 +305,8 @@ map_as_ty_to_str = Dict( # Maps layout types to string map_layout_ty_to_str = Dict( - WmmaRowMajor => "row", - WmmaColMajor => "col" + WMMARowMajor => "row", + WMMAColMajor => "col" ) # Maps matrix & type to number of elements (size after flattening) @@ -321,10 +321,10 @@ map_num_elems = Dict( # Maps matrix to its use map_matrix_to_use = Dict( - "a" => WmmaMatrixA, - "b" => WmmaMatrixB, - "c" => WmmaAccumulator, - "d" => WmmaAccumulator + "a" => WMMAMatrixA, + "b" => WMMAMatrixB, + "c" => WMMAAccumulator, + "d" => WMMAAccumulator ) # ---------------- @@ -387,15 +387,15 @@ export wmma_load_a, wmma_load_b, wmma_load_c wmma_load_b(addr, stride, layout, config) wmma_load_c(addr, stride, layout, config) -Load the matrix `a`, `b` or `c` from the memory location indicated by `addr`, and return the resulting [`WmmaFragment`](@ref). +Load the matrix `a`, `b` or `c` from the memory location indicated by `addr`, and return the resulting [`WMMAFragment`](@ref). # Arguments - `addr`: The address to load the matrix from. - `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. -- `layout`: The storage layout of the matrix. Possible values are [`WmmaRowMajor`](@ref) and [`WmmaColMajor`](@ref). -- `config`: The WMMA configuration that should be used for loading this matrix. See [`WmmaConfig`](@ref). +- `layout`: The storage layout of the matrix. Possible values are [`WMMARowMajor`](@ref) and [`WMMAColMajor`](@ref). +- `config`: The WMMA configuration that should be used for loading this matrix. See [`WMMAConfig`](@ref). -See also: [`WmmaFragment`](@ref), [`WmmaFragmentLayout`](@ref), [`WmmaConfig`](@ref) +See also: [`WMMAFragment`](@ref), [`WMMAFragmentLayout`](@ref), [`WMMAConfig`](@ref) !!! warning @@ -410,21 +410,21 @@ for mat in ["a", "b", "c"] @eval @generated function $func_name(addr::DevicePtr{T, AS}, stride::Number, layout::Type{L}, - config::Type{WmmaConfig{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} + config::Type{WMMAConfig{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) shape = get_hl_shape(M, N, K) num_els, _, _, arr_str = get_hl_frag_info($mat, T) U = get_hl_mat_use($mat) - L_ret = ($mat == "c") ? WmmaUnspecified : L + L_ret = ($mat == "c") ? WMMAUnspecified : L # Name of the Julia wrapper wrapper = Symbol(join_nonempty("llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str, "_")) return quote x = flatten($wrapper(addr, stride)) - return WmmaFragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) + return WMMAFragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) end end end @@ -443,10 +443,10 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. # Arguments -- `a`: The [`WmmaFragment`](@ref) corresponding to the matrix ``A``. -- `b`: The [`WmmaFragment`](@ref) corresponding to the matrix ``B``. -- `c`: The [`WmmaFragment`](@ref) corresponding to the matrix ``C``. -- `conf`: The [`WmmaConfig`](@ref) that should be used in this WMMA operation. +- `a`: The [`WMMAFragment`](@ref) corresponding to the matrix ``A``. +- `b`: The [`WMMAFragment`](@ref) corresponding to the matrix ``B``. +- `c`: The [`WMMAFragment`](@ref) corresponding to the matrix ``C``. +- `conf`: The [`WMMAConfig`](@ref) that should be used in this WMMA operation. !!! warning @@ -455,10 +455,10 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. """ wmma_mma -@generated function wmma_mma(a::WmmaFragment{M, N, K, A_SZ, A_T, A_L, WmmaMatrixA}, - b::WmmaFragment{M, N, K, B_SZ, B_T, B_L, WmmaMatrixB}, - c::WmmaFragment{M, N, K, C_SZ, C_T, WmmaUnspecified, WmmaAccumulator}, - config::Type{WmmaConfig{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} +@generated function wmma_mma(a::WMMAFragment{M, N, K, A_SZ, A_T, A_L, WMMAMatrixA}, + b::WMMAFragment{M, N, K, B_SZ, B_T, B_L, WMMAMatrixB}, + c::WMMAFragment{M, N, K, C_SZ, C_T, WMMAUnspecified, WMMAAccumulator}, + config::Type{WMMAConfig{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T) _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T) @@ -478,7 +478,7 @@ wmma_mma c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x) x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) - return WmmaFragment{$M, $N, $K, $d_num_els, $D_T, WmmaUnspecified, WmmaAccumulator}(x) + return WMMAFragment{$M, $N, $K, $d_num_els, $D_T, WMMAUnspecified, WMMAAccumulator}(x) end end @@ -496,12 +496,12 @@ Store the result matrix `d` to the memory location indicated by `addr`. # Arguments - `addr`: The address to store the matrix to. -- `d`: The [`WmmaFragment`](@ref) corresponding to the `d` matrix. +- `d`: The [`WMMAFragment`](@ref) corresponding to the `d` matrix. - `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. -- `layout`: The storage layout of the matrix. Possible values are [`WmmaRowMajor`](@ref) and [`WmmaColMajor`](@ref). -- `config`: The WMMA configuration that should be used for storing this matrix. See [`WmmaConfig`](@ref). +- `layout`: The storage layout of the matrix. Possible values are [`WMMARowMajor`](@ref) and [`WMMAColMajor`](@ref). +- `config`: The WMMA configuration that should be used for storing this matrix. See [`WMMAConfig`](@ref). -See also: [`WmmaFragment`](@ref), [`WmmaFragmentLayout`](@ref), [`WmmaConfig`](@ref) +See also: [`WMMAFragment`](@ref), [`WMMAFragmentLayout`](@ref), [`WMMAConfig`](@ref) !!! warning @@ -511,10 +511,10 @@ See also: [`WmmaFragment`](@ref), [`WmmaFragmentLayout`](@ref), [`WmmaConfig`](@ wmma_store_d @generated function wmma_store_d(addr::DevicePtr{T, AS}, - d::WmmaFragment{M, N, K, D_SZ, T, WmmaUnspecified, WmmaAccumulator}, + d::WMMAFragment{M, N, K, D_SZ, T, WMMAUnspecified, WMMAAccumulator}, stride::Number, layout::Type{L}, - config::Type{WmmaConfig{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} + config::Type{WMMAConfig{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) @@ -541,18 +541,18 @@ export wmma_fill_c """ wmma_fill_c(value, config) -Return a [`WmmaFragment`](@ref) filled with the value `value`. +Return a [`WMMAFragment`](@ref) filled with the value `value`. This operation is useful if you want to implement a matrix multiplication (and thus want to set ``C = O``). # Arguments - `value`: The value used to fill the fragment. Can be a `Float16` or `Float32`. -- `config`: The WMMA configuration that should be used for this WMMA operation. See [`WmmaConfig`](@ref). +- `config`: The WMMA configuration that should be used for this WMMA operation. See [`WMMAConfig`](@ref). """ wmma_fill_c @generated function wmma_fill_c(value::T, - config::Type{WmmaConfig{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} + config::Type{WMMAConfig{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} # We can't use closures in @generated functions, so we'll have to do it this way instead of # ntuple(i -> val, $num_els) @@ -562,7 +562,7 @@ wmma_fill_c expr = :(tuple($(args...))) return quote - return WmmaFragment{$M, $N, $K, $num_els, $T, WmmaUnspecified, WmmaAccumulator}($expr) + return WMMAFragment{$M, $N, $K, $num_els, $T, WMMAUnspecified, WMMAAccumulator}($expr) end end @@ -574,33 +574,33 @@ end # https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl -# Custom broadcast style for WmmaFragments -struct WmmaFragmentBroadcastStyle <: Broadcast.BroadcastStyle end +# Custom broadcast style for WMMAFragments +struct WMMAFragmentBroadcastStyle <: Broadcast.BroadcastStyle end -# Use this broadcasting style for WmmaFragments -Base.BroadcastStyle(::Type{<:WmmaFragment}) = WmmaFragmentBroadcastStyle() +# Use this broadcasting style for WMMAFragments +Base.BroadcastStyle(::Type{<:WMMAFragment}) = WMMAFragmentBroadcastStyle() # Broadcast style precedence rules -# If we broadcast a fragment with a scalar, we want the WmmaFragment style to take precedence -Base.BroadcastStyle(s::WmmaFragmentBroadcastStyle, t::Broadcast.DefaultArrayStyle{0}) = s +# If we broadcast a fragment with a scalar, we want the WMMAFragment style to take precedence +Base.BroadcastStyle(s::WMMAFragmentBroadcastStyle, t::Broadcast.DefaultArrayStyle{0}) = s # We don't want to convert fragments before broadcasting -Base.broadcastable(frag::WmmaFragment) = frag +Base.broadcastable(frag::WMMAFragment) = frag # Needed for broadcast machinery -Base.axes(frag::WmmaFragment) = axes(frag.x) +Base.axes(frag::WMMAFragment) = axes(frag.x) # Helper functions to get element at specified index @inline get_index(x, i) = x # scalar -@inline get_index(frag::WmmaFragment, i) = frag.x[i] # WmmaFragment +@inline get_index(frag::WMMAFragment, i) = frag.x[i] # WMMAFragment # Helper functions to get first fragment in broadcast call @inline find_first_fragment(args::Tuple) = find_first_fragment(args[1], Base.tail(args)) -@inline find_first_fragment(a::WmmaFragment, tail) = a +@inline find_first_fragment(a::WMMAFragment, tail) = a @inline find_first_fragment(::Any, tail) = find_first_fragment(tail) -# Custom broadcast implementation that returns a WmmaFragment -@inline function Base.copy(bc::Broadcast.Broadcasted{WmmaFragmentBroadcastStyle}) +# Custom broadcast implementation that returns a WMMAFragment +@inline function Base.copy(bc::Broadcast.Broadcasted{WMMAFragmentBroadcastStyle}) dim = Broadcast.combine_axes(bc.args...) if length(dim) != 1 diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 306eee60..e1927df9 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -175,18 +175,18 @@ if VERSION >= v"1.4.0-DEV.534" @testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5], ty = [Float16, Float32] - @test ty(5) .* WmmaFragment{16, 16, 16, sz, ty, WmmaRowMajor, WmmaMatrixA}(ntuple(i -> ty(i), sz)) == WmmaFragment{16, 16, 16, sz, ty, WmmaRowMajor, WmmaMatrixA}(ntuple(i -> ty(5 * i), sz)) - @test ty(5) .+ WmmaFragment{16, 16, 16, sz, ty, WmmaRowMajor, WmmaMatrixA}(ntuple(i -> ty(i), sz)) == WmmaFragment{16, 16, 16, sz, ty, WmmaRowMajor, WmmaMatrixA}(ntuple(i -> ty(5 + i), sz)) + @test ty(5) .* WMMAFragment{16, 16, 16, sz, ty, WMMARowMajor, WMMAMatrixA}(ntuple(i -> ty(i), sz)) == WMMAFragment{16, 16, 16, sz, ty, WMMARowMajor, WMMAMatrixA}(ntuple(i -> ty(5 * i), sz)) + @test ty(5) .+ WMMAFragment{16, 16, 16, sz, ty, WMMARowMajor, WMMAMatrixA}(ntuple(i -> ty(i), sz)) == WMMAFragment{16, 16, 16, sz, ty, WMMARowMajor, WMMAMatrixA}(ntuple(i -> ty(5 + i), sz)) end ################################################################################ @testset "CUDA C-style API" begin - @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [WmmaColMajor, WmmaRowMajor], - b_layout in [WmmaColMajor, WmmaRowMajor], - c_layout in [WmmaColMajor, WmmaRowMajor], - d_layout in [WmmaColMajor, WmmaRowMajor], + @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [WMMAColMajor, WMMARowMajor], + b_layout in [WMMAColMajor, WMMARowMajor], + c_layout in [WMMAColMajor, WMMARowMajor], + d_layout in [WMMAColMajor, WMMARowMajor], c_type in [Float16, Float32], d_type in [Float16, Float32], do_mac in [true, false] @@ -205,7 +205,7 @@ if VERSION >= v"1.4.0-DEV.534" beta = rand(c_type) @eval function kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) - conf = WmmaConfig{16, 16, 16, $d_type} + conf = WMMAConfig{16, 16, 16, $d_type} a_frag = wmma_load_a(pointer(a_dev), 16, $a_layout, conf) b_frag = wmma_load_b(pointer(b_dev), 16, $b_layout, conf) @@ -229,10 +229,10 @@ if VERSION >= v"1.4.0-DEV.534" @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) d = Array(d_dev) - new_a = (a_layout == WmmaColMajor) ? a : transpose(a) - new_b = (b_layout == WmmaColMajor) ? b : transpose(b) - new_c = (c_layout == WmmaColMajor) ? c : transpose(c) - new_d = (d_layout == WmmaColMajor) ? d : transpose(d) + new_a = (a_layout == WMMAColMajor) ? a : transpose(a) + new_b = (b_layout == WMMAColMajor) ? b : transpose(b) + new_c = (c_layout == WMMAColMajor) ? c : transpose(c) + new_d = (d_layout == WMMAColMajor) ? d : transpose(d) if do_mac @test all(isapprox.(alpha * new_a * new_b + beta * new_c, new_d; rtol=sqrt(eps(Float16)))) From e5e6965e03e4446934c002fad4ba95504839924d Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 7 Dec 2019 17:25:21 +0100 Subject: [PATCH 59/81] Move examples to separate folder --- Manifest.toml | 50 ++++++++++++++++++++++++++ Project.toml | 3 +- docs/src/lib/device/wmma.md | 72 +++++++------------------------------ examples/wmma/high-level.jl | 31 ++++++++++++++++ examples/wmma/low-level.jl | 29 +++++++++++++++ 5 files changed, 124 insertions(+), 61 deletions(-) create mode 100644 examples/wmma/high-level.jl create mode 100644 examples/wmma/low-level.jl diff --git a/Manifest.toml b/Manifest.toml index f3504f26..8b0fa308 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,5 +1,11 @@ # This file is machine-generated - editing it directly is not advised +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "0.5.0" + [[Adapt]] deps = ["LinearAlgebra"] git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf" @@ -26,6 +32,18 @@ git-tree-sha1 = "0f39fddace3324707469ace7fbcbc7b28d5cf921" uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde" version = "4.0.4" +[[CUDAnative]] +deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"] +git-tree-sha1 = "a67b38619d1fa131027bac1c4a81f0012254d1fd" +uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17" +version = "2.6.0" + +[[CuArrays]] +deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] +git-tree-sha1 = "e99db1397ce85975203a9d736ab37534730996ca" +uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" +version = "1.5.0" + [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections"] git-tree-sha1 = "1fe8fad5fc84686dcbc674aa255bc867a64f8132" @@ -36,6 +54,12 @@ version = "0.17.5" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[GPUArrays]] +deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] +git-tree-sha1 = "e756da6cee76a5f1436a05827fa8fdf3badc577f" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "2.0.1" + [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -56,10 +80,22 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[MacroTools]] +deps = ["DataStructures", "Markdown", "Random"] +git-tree-sha1 = "e2fc7a55bb2224e203bbd8b59f72b91323233458" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.3" + [[Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[NNlib]] +deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"] +git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.6.0" + [[OrderedCollections]] deps = ["Random", "Serialization", "Test"] git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" @@ -74,12 +110,26 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[Requires]] +deps = ["Test"] +git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "0.5.2" + [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/Project.toml b/Project.toml index ac77fb87..ad238324 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ julia = "1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" [targets] -test = ["Test"] +test = ["Test", "CuArrays"] diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md index 56edee45..a8a40b8f 100644 --- a/docs/src/lib/device/wmma.md +++ b/docs/src/lib/device/wmma.md @@ -115,37 +115,14 @@ In what follows, each of these will be discussed. ### Example +````@eval +using Markdown +Markdown.parse(""" ```julia -using CUDAnative -using CuArrays -using Test - -# Generate input matrices -a = rand(Float16, (16, 16)) -a_dev = CuArray(a) -b = rand(Float16, (16, 16)) -b_dev = CuArray(b) -c = rand(Float32, (16, 16)) -c_dev = CuArray(c) - -# Allocate space for result -d_dev = similar(c_dev) - -# Matrix multiply-accumulate kernel (D = A * B + C) -function kernel(a_dev, b_dev, c_dev, d_dev) - a_frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_dev), 16) - b_frag = llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_dev), 16) - c_frag = llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16) - - d_frag = llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag) - - llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), d_frag, 16) - return -end - -@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) -@test a * b + c ≈ Array(d_dev) rtol=0.01 +$(read("../../../../examples/wmma/low-level.jl", String)) ``` +""") +```` ## CUDA C-like API @@ -205,36 +182,11 @@ CUDAnative.wmma_fill_c ### Example +````@eval +using Markdown +Markdown.parse(""" ```julia -using CUDAnative -using CuArrays -using Test - -a = rand(Float16, (16, 16)) -b = rand(Float16, (16, 16)) -c = rand(Float32, (16, 16)) - -a_dev = CuArray(a) -b_dev = CuArray(b) -c_dev = CuArray(c) -d_dev = similar(c_dev) - -function kernel(a_dev, b_dev, c_dev, d_dev) - conf = WMMAConfig{16, 16, 16, Float32} - - a_frag = wmma_load_a(pointer(a_dev), 16, WMMAColMajor, conf) - b_frag = wmma_load_b(pointer(b_dev), 16, WMMAColMajor, conf) - c_frag = wmma_load_c(pointer(c_dev), 16, WMMAColMajor, conf) - - d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) - - wmma_store_d(pointer(d_dev), d_frag, 16, WMMAColMajor, conf) - - return -end - -@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) -d = Array(d_dev) - -@test a * b + c ≈ d rtol=0.01 +$(read("../../../../examples/wmma/high-level.jl", String)) ``` +""") +```` diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl new file mode 100644 index 00000000..5aabb830 --- /dev/null +++ b/examples/wmma/high-level.jl @@ -0,0 +1,31 @@ +using CUDAnative +using CuArrays +using Test + +a = rand(Float16, (16, 16)) +b = rand(Float16, (16, 16)) +c = rand(Float32, (16, 16)) + +a_dev = CuArray(a) +b_dev = CuArray(b) +c_dev = CuArray(c) +d_dev = similar(c_dev) + +function kernel(a_dev, b_dev, c_dev, d_dev) + conf = WMMAConfig{16, 16, 16, Float32} + + a_frag = wmma_load_a(pointer(a_dev), 16, WMMAColMajor, conf) + b_frag = wmma_load_b(pointer(b_dev), 16, WMMAColMajor, conf) + c_frag = wmma_load_c(pointer(c_dev), 16, WMMAColMajor, conf) + + d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) + + wmma_store_d(pointer(d_dev), d_frag, 16, WMMAColMajor, conf) + + return +end + +@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) +d = Array(d_dev) + +@test all(isapprox.(a * b + c, d; rtol=0.01)) diff --git a/examples/wmma/low-level.jl b/examples/wmma/low-level.jl new file mode 100644 index 00000000..89215ea5 --- /dev/null +++ b/examples/wmma/low-level.jl @@ -0,0 +1,29 @@ +using CUDAnative +using CuArrays +using Test + +# Generate input matrices +a = rand(Float16, (16, 16)) +a_dev = CuArray(a) +b = rand(Float16, (16, 16)) +b_dev = CuArray(b) +c = rand(Float32, (16, 16)) +c_dev = CuArray(c) + +# Allocate space for result +d_dev = similar(c_dev) + +# Matrix multiply-accumulate kernel (D = A * B + C) +function kernel(a_dev, b_dev, c_dev, d_dev) + a_frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_dev), 16) + b_frag = llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_dev), 16) + c_frag = llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16) + + d_frag = llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag) + + llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), d_frag, 16) + return +end + +@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) +@test all(isapprox.(a * b + c, Array(d_dev); rtol=0.01)) From 76ed2bbd014587e814ef83d762b29d48b33fea05 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 7 Dec 2019 18:29:27 +0100 Subject: [PATCH 60/81] Only run WMMA examples for recent Julia --- docs/src/lib/device/wmma.md | 14 ++++++++++++-- examples/wmma/high-level.jl | 7 +++++++ examples/wmma/low-level.jl | 7 +++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md index a8a40b8f..e8a2aa67 100644 --- a/docs/src/lib/device/wmma.md +++ b/docs/src/lib/device/wmma.md @@ -116,10 +116,15 @@ In what follows, each of these will be discussed. ### Example ````@eval +lines = readlines("../../../../examples/wmma/low-level.jl") +start = findfirst(x -> x == "### START", lines) + 1 +stop = findfirst(x -> x == "### END", lines) - 1 +example = join(lines[start:stop], '\n') + using Markdown Markdown.parse(""" ```julia -$(read("../../../../examples/wmma/low-level.jl", String)) +$(example) ``` """) ```` @@ -183,10 +188,15 @@ CUDAnative.wmma_fill_c ### Example ````@eval +lines = readlines("../../../../examples/wmma/high-level.jl") +start = findfirst(x -> x == "### START", lines) + 1 +stop = findfirst(x -> x == "### END", lines) - 1 +example = join(lines[start:stop], '\n') + using Markdown Markdown.parse(""" ```julia -$(read("../../../../examples/wmma/high-level.jl", String)) +$(example) ``` """) ```` diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl index 5aabb830..1efddf6e 100644 --- a/examples/wmma/high-level.jl +++ b/examples/wmma/high-level.jl @@ -1,3 +1,7 @@ +# Need https://github.com/JuliaLang/julia/pull/33970 +if VERSION >= v"1.4.0-DEV.534" + +### START using CUDAnative using CuArrays using Test @@ -29,3 +33,6 @@ end d = Array(d_dev) @test all(isapprox.(a * b + c, d; rtol=0.01)) +### END + +end diff --git a/examples/wmma/low-level.jl b/examples/wmma/low-level.jl index 89215ea5..9a58e92d 100644 --- a/examples/wmma/low-level.jl +++ b/examples/wmma/low-level.jl @@ -1,3 +1,7 @@ +# Need https://github.com/JuliaLang/julia/pull/33970 +if VERSION >= v"1.4.0-DEV.534" + +### START using CUDAnative using CuArrays using Test @@ -27,3 +31,6 @@ end @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) @test all(isapprox.(a * b + c, Array(d_dev); rtol=0.01)) +### END + +end From e181320202309ce1077c7b1ce6ae3f5e73390069 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 10 Dec 2019 15:37:32 +0100 Subject: [PATCH 61/81] Undo changes to Project and Manifest file --- Manifest.toml | 50 -------------------------------------------------- Project.toml | 1 - 2 files changed, 51 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 8b0fa308..f3504f26 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,11 +1,5 @@ # This file is machine-generated - editing it directly is not advised -[[AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "0.5.0" - [[Adapt]] deps = ["LinearAlgebra"] git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf" @@ -32,18 +26,6 @@ git-tree-sha1 = "0f39fddace3324707469ace7fbcbc7b28d5cf921" uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde" version = "4.0.4" -[[CUDAnative]] -deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"] -git-tree-sha1 = "a67b38619d1fa131027bac1c4a81f0012254d1fd" -uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17" -version = "2.6.0" - -[[CuArrays]] -deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] -git-tree-sha1 = "e99db1397ce85975203a9d736ab37534730996ca" -uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" -version = "1.5.0" - [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections"] git-tree-sha1 = "1fe8fad5fc84686dcbc674aa255bc867a64f8132" @@ -54,12 +36,6 @@ version = "0.17.5" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" -[[GPUArrays]] -deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] -git-tree-sha1 = "e756da6cee76a5f1436a05827fa8fdf3badc577f" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "2.0.1" - [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -80,22 +56,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[MacroTools]] -deps = ["DataStructures", "Markdown", "Random"] -git-tree-sha1 = "e2fc7a55bb2224e203bbd8b59f72b91323233458" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.3" - [[Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[NNlib]] -deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"] -git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.6.0" - [[OrderedCollections]] deps = ["Random", "Serialization", "Test"] git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" @@ -110,26 +74,12 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[[Requires]] -deps = ["Test"] -git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "0.5.2" - [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/Project.toml b/Project.toml index 4902a273..230f808a 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,6 @@ julia = "1" [extras] CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" [targets] test = ["Test", "CuArrays"] From 841199ed4dccb4998377d7568f7a6789c7d0c3de Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 10 Dec 2019 16:58:41 +0100 Subject: [PATCH 62/81] Document flattening and broadcast --- docs/src/lib/device/wmma.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md index e8a2aa67..4c902ab5 100644 --- a/docs/src/lib/device/wmma.md +++ b/docs/src/lib/device/wmma.md @@ -185,6 +185,19 @@ CUDAnative.wmma_store_d CUDAnative.wmma_fill_c ``` +### Element access and broadcasting + +Similar to the CUDA C++ WMMA API, [`WMMAFragment`](@ref)s have an `x` member that can be used to access individual elements. +Note that, in contrast to the values returned by the LLVM intrinsics, the `x` member is flattened. +For example, while the `Float16` variants of the `load_a` instrinsics return `NTuple{8, NTuple{2, VecElement{Float16}}}`, the `x` member has type `NTuple{16, Float16}`. + +Typically, you will only need to access the `x` member to perform elementwise operations. +This can be more succinctly expressed using Julia's broadcast mechanism. +For example, to double each element in a fragment, you can simply use: +```julia +frag = 2.0f0 .* frag +``` + ### Example ````@eval From 86af32819affdba07089443912a13063ab361d32 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 10 Dec 2019 17:14:04 +0100 Subject: [PATCH 63/81] Bump minimum Julia version --- docs/src/lib/device/wmma.md | 13 +++++++++---- examples/wmma/high-level.jl | 3 ++- examples/wmma/low-level.jl | 3 ++- test/device/wmma.jl | 3 ++- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/docs/src/lib/device/wmma.md b/docs/src/lib/device/wmma.md index 4c902ab5..0b88998c 100644 --- a/docs/src/lib/device/wmma.md +++ b/docs/src/lib/device/wmma.md @@ -5,19 +5,24 @@ This interface enables programmatic access to Tensor Cores, a new hardware featu Access to WMMA using CUDAnative is available in two levels: low level wrappers around the LLVM intrinsics, and a higher-level API, similar to that of CUDA C. -Note that to use the WMMA intrinsics, you need a sufficiently recent version of Julia: `v1.4.0-DEV.534` or later. +Note that to use the WMMA intrinsics, you need a sufficiently recent version of Julia: `v1.4.0-DEV.564` or later. You can check this by running the following in the REPL: ```julia -VERSION >= v"1.4.0-DEV.534" +VERSION >= v"1.4.0-DEV.564" ``` !!! note - If you're running into the following error while using the WMMA interfaces: + If you're running into any of following errors while using the WMMA interfaces: ``` LLVM error: Do not know how to split the result of this operator! ``` - then make sure you are running Julia v1.4.0-DEV.534 or later! + or + ``` + CUDA error: a PTX JIT compilation failed (code 218, ERROR_INVALID_PTX) + ptxas application ptx input, line ; error : .aligned modifier required for instruction '' + ``` + then make sure you are running Julia v1.4.0-DEV.564 or later! ## Introduction of Terminology diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl index 1efddf6e..a1e4e5ff 100644 --- a/examples/wmma/high-level.jl +++ b/examples/wmma/high-level.jl @@ -1,5 +1,6 @@ # Need https://github.com/JuliaLang/julia/pull/33970 -if VERSION >= v"1.4.0-DEV.534" +# and https://github.com/JuliaLang/julia/pull/34043 +if VERSION >= v"1.4.0-DEV.564" ### START using CUDAnative diff --git a/examples/wmma/low-level.jl b/examples/wmma/low-level.jl index 9a58e92d..5ca8d64d 100644 --- a/examples/wmma/low-level.jl +++ b/examples/wmma/low-level.jl @@ -1,5 +1,6 @@ # Need https://github.com/JuliaLang/julia/pull/33970 -if VERSION >= v"1.4.0-DEV.534" +# and https://github.com/JuliaLang/julia/pull/34043 +if VERSION >= v"1.4.0-DEV.564" ### START using CUDAnative diff --git a/test/device/wmma.jl b/test/device/wmma.jl index e1927df9..1130a67b 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -1,5 +1,6 @@ # Need https://github.com/JuliaLang/julia/pull/33970 -if VERSION >= v"1.4.0-DEV.534" +# and https://github.com/JuliaLang/julia/pull/34043 +if VERSION >= v"1.4.0-DEV.564" @testset "WMMA" begin ################################################################################ From 6d376f3cd6ddf6f62803026bda0a0179b4e32040 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 10 Dec 2019 18:19:55 +0100 Subject: [PATCH 64/81] Use exit() in examples --- examples/wmma/high-level.jl | 6 +++--- examples/wmma/low-level.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl index a1e4e5ff..cda00ba2 100644 --- a/examples/wmma/high-level.jl +++ b/examples/wmma/high-level.jl @@ -1,6 +1,8 @@ # Need https://github.com/JuliaLang/julia/pull/33970 # and https://github.com/JuliaLang/julia/pull/34043 -if VERSION >= v"1.4.0-DEV.564" +if VERSION < v"1.4.0-DEV.564" + exit() +end ### START using CUDAnative @@ -35,5 +37,3 @@ d = Array(d_dev) @test all(isapprox.(a * b + c, d; rtol=0.01)) ### END - -end diff --git a/examples/wmma/low-level.jl b/examples/wmma/low-level.jl index 5ca8d64d..a159a4eb 100644 --- a/examples/wmma/low-level.jl +++ b/examples/wmma/low-level.jl @@ -1,6 +1,8 @@ # Need https://github.com/JuliaLang/julia/pull/33970 # and https://github.com/JuliaLang/julia/pull/34043 -if VERSION >= v"1.4.0-DEV.564" +if VERSION < v"1.4.0-DEV.564" + exit() +end ### START using CUDAnative @@ -33,5 +35,3 @@ end @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) @test all(isapprox.(a * b + c, Array(d_dev); rtol=0.01)) ### END - -end From c50beed60545b8f61f395fbdad40e50dec4d899b Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 10 Dec 2019 18:36:50 +0100 Subject: [PATCH 65/81] Temporarily disable test --- test/device/wmma.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 1130a67b..0cbeda0f 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -166,7 +166,8 @@ if VERSION >= v"1.4.0-DEV.564" @test CUDAnative.unflatten(Float64, (5.0,)) == 5.0 @test CUDAnative.unflatten(VecElement{Float16}, (Float16(5),)) == VecElement{Float16}(5) @test CUDAnative.unflatten(NTuple{8, Int64}, ntuple(i -> i, 8)) == ntuple(i -> i, 8) - @test CUDAnative.unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) + # TODO: Reenable this + #= @test CUDAnative.unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) =# @test CUDAnative.unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8) @test CUDAnative.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8) end From 1b70c909aa61be7a4c2400025f62431b08930964 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Thu, 12 Dec 2019 11:35:16 +0100 Subject: [PATCH 66/81] Bump min SM for Julia nightly tests --- .gitlab-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 78bc8f85..616256c6 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -47,6 +47,7 @@ julia:nightly: - .test tags: - nvidia + - sm_75 allow_failure: true From 05636790fb293ab398dbf93b286ffcbce3b08186 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Thu, 12 Dec 2019 11:35:38 +0100 Subject: [PATCH 67/81] Check capability in WMMA tests --- examples/wmma/high-level.jl | 5 +++++ examples/wmma/low-level.jl | 5 +++++ test/device/wmma.jl | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl index cda00ba2..1edfb1db 100644 --- a/examples/wmma/high-level.jl +++ b/examples/wmma/high-level.jl @@ -4,6 +4,11 @@ if VERSION < v"1.4.0-DEV.564" exit() end +using CUDAnative +if CUDAnative.current_capability() < v"7.0" + exit() +end + ### START using CUDAnative using CuArrays diff --git a/examples/wmma/low-level.jl b/examples/wmma/low-level.jl index a159a4eb..a9f178d9 100644 --- a/examples/wmma/low-level.jl +++ b/examples/wmma/low-level.jl @@ -4,6 +4,11 @@ if VERSION < v"1.4.0-DEV.564" exit() end +using CUDAnative +if CUDAnative.current_capability() < v"7.0" + exit() +end + ### START using CUDAnative using CuArrays diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 0cbeda0f..0161014e 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -1,6 +1,6 @@ # Need https://github.com/JuliaLang/julia/pull/33970 # and https://github.com/JuliaLang/julia/pull/34043 -if VERSION >= v"1.4.0-DEV.564" +if VERSION >= v"1.4.0-DEV.564" && CUDAnative.current_capability() >= v"7.0" @testset "WMMA" begin ################################################################################ From 3b8391d123dfe9ee2b99032b9bc03ad6eb981fde Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sun, 15 Dec 2019 15:54:58 +0100 Subject: [PATCH 68/81] Reenable unflatten test --- test/device/wmma.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 0161014e..5dd82183 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -166,8 +166,7 @@ if VERSION >= v"1.4.0-DEV.564" && CUDAnative.current_capability() >= v"7.0" @test CUDAnative.unflatten(Float64, (5.0,)) == 5.0 @test CUDAnative.unflatten(VecElement{Float16}, (Float16(5),)) == VecElement{Float16}(5) @test CUDAnative.unflatten(NTuple{8, Int64}, ntuple(i -> i, 8)) == ntuple(i -> i, 8) - # TODO: Reenable this - #= @test CUDAnative.unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) =# + @test CUDAnative.unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) @test CUDAnative.unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8) @test CUDAnative.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8) end From fdc8ddef8523ea7db98c6ca0f7595e784ee59f3a Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Thu, 26 Dec 2019 23:06:56 +0100 Subject: [PATCH 69/81] Update version check --- docs/src/device/wmma.md | 6 +++--- examples/wmma/high-level.jl | 2 +- examples/wmma/low-level.jl | 2 +- test/device/wmma.jl | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/src/device/wmma.md b/docs/src/device/wmma.md index 0b88998c..d36661be 100644 --- a/docs/src/device/wmma.md +++ b/docs/src/device/wmma.md @@ -5,10 +5,10 @@ This interface enables programmatic access to Tensor Cores, a new hardware featu Access to WMMA using CUDAnative is available in two levels: low level wrappers around the LLVM intrinsics, and a higher-level API, similar to that of CUDA C. -Note that to use the WMMA intrinsics, you need a sufficiently recent version of Julia: `v1.4.0-DEV.564` or later. +Note that to use the WMMA intrinsics, you need a sufficiently recent version of Julia: `v1.4.0-DEV.666` or later. You can check this by running the following in the REPL: ```julia -VERSION >= v"1.4.0-DEV.564" +VERSION >= v"1.4.0-DEV.666" ``` !!! note @@ -22,7 +22,7 @@ VERSION >= v"1.4.0-DEV.564" CUDA error: a PTX JIT compilation failed (code 218, ERROR_INVALID_PTX) ptxas application ptx input, line ; error : .aligned modifier required for instruction '' ``` - then make sure you are running Julia v1.4.0-DEV.564 or later! + then make sure you are running Julia v1.4.0-DEV.666 or later! ## Introduction of Terminology diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl index 1edfb1db..f97bed4f 100644 --- a/examples/wmma/high-level.jl +++ b/examples/wmma/high-level.jl @@ -1,6 +1,6 @@ # Need https://github.com/JuliaLang/julia/pull/33970 # and https://github.com/JuliaLang/julia/pull/34043 -if VERSION < v"1.4.0-DEV.564" +if VERSION < v"1.4.0-DEV.666" exit() end diff --git a/examples/wmma/low-level.jl b/examples/wmma/low-level.jl index a9f178d9..0607e074 100644 --- a/examples/wmma/low-level.jl +++ b/examples/wmma/low-level.jl @@ -1,6 +1,6 @@ # Need https://github.com/JuliaLang/julia/pull/33970 # and https://github.com/JuliaLang/julia/pull/34043 -if VERSION < v"1.4.0-DEV.564" +if VERSION < v"1.4.0-DEV.666" exit() end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 5dd82183..9193419c 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -1,6 +1,6 @@ # Need https://github.com/JuliaLang/julia/pull/33970 # and https://github.com/JuliaLang/julia/pull/34043 -if VERSION >= v"1.4.0-DEV.564" && CUDAnative.current_capability() >= v"7.0" +if VERSION >= v"1.4.0-DEV.666" && CUDAnative.current_capability() >= v"7.0" @testset "WMMA" begin ################################################################################ From 3f9faace8c2db865e0a722b9fb977fed298db833 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 31 Jan 2020 21:26:03 +0100 Subject: [PATCH 70/81] Set CI_THOROUGH --- .gitlab-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d6d07fc4..d2c66151 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -27,6 +27,8 @@ julia:nightly: tags: - nvidia - sm_75 + variables: + CI_THOROUGH: 'true' allow_failure: true From a1b64f2cf757f4572a2725eed5bcfc89e8694dad Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 31 Jan 2020 21:30:23 +0100 Subject: [PATCH 71/81] Mark constants as 'const' --- src/device/cuda/wmma.jl | 90 ++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index a387327b..6c3d5e7e 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -3,33 +3,33 @@ ################################################################################ # Maps PTX types to Julia array types -map_ptx_to_jl_array = Dict( - "f16" => Float16, - "f32" => Float32 - ) +const map_ptx_to_jl_array = Dict( + "f16" => Float16, + "f32" => Float32 + ) # Maps PTX types to Julia fragment types -map_ptx_to_jl_frag = Dict( - "f16" => NTuple{2, VecElement{Float16}}, - "f32" => Float32 - ) +const map_ptx_to_jl_frag = Dict( + "f16" => NTuple{2, VecElement{Float16}}, + "f32" => Float32 + ) # Maps matrix & PTX types to fragment sizes -map_frag_sizes = Dict( - "a.f16" => 8, - "b.f16" => 8, - "c.f16" => 4, - "c.f32" => 8, - "d.f16" => 4, - "d.f32" => 8 - ) +const map_frag_sizes = Dict( + "a.f16" => 8, + "b.f16" => 8, + "c.f16" => 4, + "c.f32" => 8, + "d.f16" => 4, + "d.f32" => 8 + ) # Maps PTX AS to Int -map_ptx_as_to_int = Dict( - "" => 0, - "shared" => 3, - "global" => 1 - ) +const map_ptx_as_to_int = Dict( + "" => 0, + "shared" => 3, + "global" => 1 + ) ################################################################################ # HELPER FUNCTIONS @@ -294,38 +294,38 @@ struct WMMAConfig{M, N, K, d_type} end # --------- # Maps Julia array types to string -map_jl_array_to_str = Dict(val => key for (key, val) in map_ptx_to_jl_array) +const map_jl_array_to_str = Dict(val => key for (key, val) in map_ptx_to_jl_array) # Maps CUDAnative.AS types to string -map_as_ty_to_str = Dict( - AS.Generic => "", - AS.Shared => "shared", - AS.Global => "global" - ) +const map_as_ty_to_str = Dict( + AS.Generic => "", + AS.Shared => "shared", + AS.Global => "global" + ) # Maps layout types to string -map_layout_ty_to_str = Dict( - WMMARowMajor => "row", - WMMAColMajor => "col" - ) +const map_layout_ty_to_str = Dict( + WMMARowMajor => "row", + WMMAColMajor => "col" + ) # Maps matrix & type to number of elements (size after flattening) -map_num_elems = Dict( - ("a", Float16) => 16, - ("b", Float16) => 16, - ("c", Float16) => 8, - ("c", Float32) => 8, - ("d", Float16) => 8, - ("d", Float32) => 8 - ) +const map_num_elems = Dict( + ("a", Float16) => 16, + ("b", Float16) => 16, + ("c", Float16) => 8, + ("c", Float32) => 8, + ("d", Float16) => 8, + ("d", Float32) => 8 + ) # Maps matrix to its use -map_matrix_to_use = Dict( - "a" => WMMAMatrixA, - "b" => WMMAMatrixB, - "c" => WMMAAccumulator, - "d" => WMMAAccumulator - ) +const map_matrix_to_use = Dict( + "a" => WMMAMatrixA, + "b" => WMMAMatrixB, + "c" => WMMAAccumulator, + "d" => WMMAAccumulator + ) # ---------------- # Helper functions From d215a562d30b48aa569e41670f3553781ff0a37c Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 31 Jan 2020 22:14:19 +0100 Subject: [PATCH 72/81] Remove join_nonempty --- src/device/cuda/wmma.jl | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 6c3d5e7e..50af9157 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -35,13 +35,6 @@ const map_ptx_as_to_int = Dict( # HELPER FUNCTIONS ################################################################################ -function join_nonempty(args...) - delim = args[end] - arr = [args[1:end-1]...] - - return join(arr[arr .!= ""], delim) -end - # Returns (Julia array type, Julia fragment type, fragment size) get_frag_info(matrix, ptx_el_type) = ( map_ptx_to_jl_array[ptx_el_type], @@ -76,7 +69,7 @@ for mat in ["a", "b", "c"], addr_space_int = get_addrspace_info(addr_space) # Name of the Julia wrapper function - func_name = Symbol(join_nonempty("llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type, "_")) + func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type]), "_")) # Name of the LLVM intrinsic llvm_intr = "llvm.nvvm.wmma.$shape.load.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8" @@ -106,7 +99,7 @@ for mat in ["d"], addr_space_int = get_addrspace_info(addr_space) # Name of the Julia wrapper function - func_name = Symbol(join_nonempty("llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type, "_")) + func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type]), "_")) # Name of the LLVM intrinsic llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8" @@ -135,7 +128,7 @@ for a_layout in ["col", "row"], a_elem_type in ["f16"] # Name of the Julia wrapper function - func_name = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type, "_")) + func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_")) # Name of the LLVM intrinsic llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type" @@ -420,7 +413,7 @@ for mat in ["a", "b", "c"] L_ret = ($mat == "c") ? WMMAUnspecified : L # Name of the Julia wrapper - wrapper = Symbol(join_nonempty("llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str, "_")) + wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_")) return quote x = flatten($wrapper(addr, stride)) @@ -470,7 +463,7 @@ wmma_mma shape = get_hl_shape(M, N, K) # Name of the Julia wrapper - wrapper = Symbol(join_nonempty("llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str, "_")) + wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_")) return quote a_unfl = unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x) @@ -522,7 +515,7 @@ wmma_store_d num_els, frag_sz, frag_ty, arr_str = get_hl_frag_info("d", T) # Name of the Julia wrapper - wrapper = Symbol(join_nonempty("llvm", "wmma", "store", "d", layout, shape, as_str, "stride", arr_str, "_")) + wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", "d", layout, shape, as_str, "stride", arr_str]), "_")) return quote d_unfl = unflatten(NTuple{$frag_sz, $frag_ty}, d.x) From d66f738f34aff182ca9a127c3dd1dcbb2511e12a Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 31 Jan 2020 22:52:36 +0100 Subject: [PATCH 73/81] Implement indexing for WMMAFragment --- src/device/cuda/wmma.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 50af9157..bead3e86 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -250,12 +250,20 @@ export WMMAFragment Type that represents per-thread intermediate results of WMMA operations. -You can access individual elements using the `x` member, but beware that the exact ordering of elements is unspecified. +You can access individual elements using the `x` member or [] operator, but beware that the exact ordering of elements is unspecified. """ struct WMMAFragment{M, N, K, FS, T, L <: WMMAFragmentLayout, U <: WMMAFragmentUse} x::NTuple{FS, T} end +# ---------------------- +# WMMA fragment indexing +# ---------------------- + +for f in (:getindex, :setindex!, :firstindex, :lastindex) + @eval Base.$f(frag::WMMAFragment, args...) = $f(frag.x, args...) +end + # ------------------ # WMMA configuration # ------------------ @@ -584,8 +592,8 @@ Base.broadcastable(frag::WMMAFragment) = frag Base.axes(frag::WMMAFragment) = axes(frag.x) # Helper functions to get element at specified index -@inline get_index(x, i) = x # scalar -@inline get_index(frag::WMMAFragment, i) = frag.x[i] # WMMAFragment +@inline get_index(x, i) = x # scalar +@inline get_index(frag::WMMAFragment, i) = frag[i] # WMMAFragment # Helper functions to get first fragment in broadcast call @inline find_first_fragment(args::Tuple) = find_first_fragment(args[1], Base.tail(args)) From dad2a66067d259f585e9940b9f7a0069088ec27d Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 31 Jan 2020 23:07:54 +0100 Subject: [PATCH 74/81] Refactor conversion of AS to Int --- src/device/cuda/wmma.jl | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index bead3e86..145ec482 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -24,12 +24,12 @@ const map_frag_sizes = Dict( "d.f32" => 8 ) -# Maps PTX AS to Int -const map_ptx_as_to_int = Dict( - "" => 0, - "shared" => 3, - "global" => 1 - ) +# Maps PTX AS to CUDAnative.AS +const map_ptx_as_to_as_ty = Dict( + "" => AS.Generic, + "shared" => AS.Shared, + "global" => AS.Global + ) ################################################################################ # HELPER FUNCTIONS @@ -42,7 +42,7 @@ get_frag_info(matrix, ptx_el_type) = ( map_frag_sizes["$matrix.$ptx_el_type"] ) -get_addrspace_info(addr_space) = map_ptx_as_to_int[addr_space] +get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space]) ################################################################################ # LOW LEVEL API @@ -298,11 +298,7 @@ struct WMMAConfig{M, N, K, d_type} end const map_jl_array_to_str = Dict(val => key for (key, val) in map_ptx_to_jl_array) # Maps CUDAnative.AS types to string -const map_as_ty_to_str = Dict( - AS.Generic => "", - AS.Shared => "shared", - AS.Global => "global" - ) +const map_as_ty_to_str = Dict(val => key for (key, val) in map_ptx_as_to_as_ty) # Maps layout types to string const map_layout_ty_to_str = Dict( From 28f44b54a164c77eb7d4284236b557f8e86e3093 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Fri, 31 Jan 2020 23:59:34 +0100 Subject: [PATCH 75/81] Move LLVM instrincs doc to docstrings --- docs/src/device/wmma.md | 60 ++++++------------------------------- src/device/cuda/wmma.jl | 66 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 51 deletions(-) diff --git a/docs/src/device/wmma.md b/docs/src/device/wmma.md index d36661be..5801ecde 100644 --- a/docs/src/device/wmma.md +++ b/docs/src/device/wmma.md @@ -62,61 +62,19 @@ The LLVM intrinsics are subdivided in three categories: load, store and multiply In what follows, each of these will be discussed. ### Load matrix - -**Julia function:** `llvm_wmma_load_{matrix}_{layout}_{shape}_{addr_space}_stride_{elem_type}(src_addr, stride)` - -**Corresponding LLVM instrinsic:** `@llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}` - -**Arguments:** -- `src_addr`: The memory address to load from. -- `stride`: The leading dimension of the matrix, in numbers of elements. - -**Placeholders:** -- `{matrix}`: The matrix to load. Can be `a`, `b` or `c`. -- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. -- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. -- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`. -- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). Note that `f32` is only valid for the matrix ``C``. +```@docs +CUDAnative.llvm_wmma_load +``` ### Perform multiply-accumulate - -**Julia function:** `llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c)` - -**Corresponding LLVM instrinsic:** `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}` - -**Arguments:** -- `a`: The WMMA fragment corresponding to the matrix ``A``. -- `b`: The WMMA fragment corresponding to the matrix ``B``. -- `c`: The WMMA fragment corresponding to the matrix ``C``. - -**Placeholders:** -- `{a_layout}`: The storage layout for matrix ``A``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. -- `{b_layout}`: The storage layout for matrix ``B``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. -- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. -- `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). -- `{c_elem_type}`: The type of each element in the ``C`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). - -!!! warning - - Remember that the shape, type and layout of all operations (be it MMA, load or store) **MUST** match. - Otherwise, the behaviour is undefined! +```@docs +CUDAnative.llvm_wmma_mma +``` ### Store matrix - -**Julia function:** `llvm_wmma_store_d_{layout}_{shape}_{addr_space}_stride_{elem_type}(dst_addr, data, stride)` - -**Corresponding LLVM intrinsic:** `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}` - -**Arguments:** -- `dst_addr`: The memory address to store to. -- `data`: The ``D`` fragment to store. -- `stride`: The leading dimension of the matrix, in numbers of elements. - -**Placeholders:** -- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. -- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. -- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`. -- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). +```@docs +CUDAnative.llvm_wmma_store +``` ### Example diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 145ec482..72152468 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -52,6 +52,25 @@ get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space]) # Matrix load # ----------- +@doc """ + llvm_wmma_load_{matrix}_{layout}_{shape}_{addr_space}_stride_{elem_type}(src_addr, stride) + +Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}`. + +# Arguments +- `src_addr`: The memory address to load from. +- `stride`: The leading dimension of the matrix, in numbers of elements. + +# Placeholders +- `{matrix}`: The matrix to load. Can be `a`, `b` or `c`. +- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. +- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. +- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`. +- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). Note that `f32` is only valid for the matrix ``C``. +""" +llvm_wmma_load() = error("Cannot call llvm_wmma_load without values for placeholders!") +export llvm_wmma_load + for mat in ["a", "b", "c"], layout in ["col", "row"], shape in ["m16n16k16"], @@ -81,12 +100,32 @@ for mat in ["a", "b", "c"], @eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $frag_ty}, (Ref{$arr_ty}, Int32), src_addr, stride) @eval export $func_name + @eval @doc (@doc llvm_wmma_load) $func_name end # ------------ # Matrix store # ------------ +@doc """ + llvm_wmma_store_d_{layout}_{shape}_{addr_space}_stride_{elem_type}(dst_addr, data, stride) + +Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}`. + +# Arguments +- `dst_addr`: The memory address to store to. +- `data`: The ``D`` fragment to store. +- `stride`: The leading dimension of the matrix, in numbers of elements. + +# Placeholders +- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. +- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. +- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`. +- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). +""" +llvm_wmma_store() = error("Cannot call llvm_wmma_store without values for placeholders!") +export llvm_wmma_store + for mat in ["d"], layout in ["col", "row"], shape in ["m16n16k16"], @@ -113,12 +152,38 @@ for mat in ["d"], @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$arr_ty}, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) @eval export $func_name + @eval @doc (@doc llvm_wmma_store) $func_name end # -------------------------- # Matrix multiply accumulate # -------------------------- +@doc """ + llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) + +Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`. + +# Arguments +- `a`: The WMMA fragment corresponding to the matrix ``A``. +- `b`: The WMMA fragment corresponding to the matrix ``B``. +- `c`: The WMMA fragment corresponding to the matrix ``C``. + +# Placeholders +- `{a_layout}`: The storage layout for matrix ``A``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. +- `{b_layout}`: The storage layout for matrix ``B``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. +- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. +- `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). +- `{c_elem_type}`: The type of each element in the ``C`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). + +!!! warning + + Remember that the shape, type and layout of all operations (be it MMA, load or store) **MUST** match. + Otherwise, the behaviour is undefined! +""" +llvm_wmma_mma() = error("Cannot call llvm_wmma_mma without values for placeholders!") +export llvm_wmma_mma + for a_layout in ["col", "row"], b_layout in ["col", "row"], shape in ["m16n16k16"], @@ -151,6 +216,7 @@ for a_layout in ["col", "row"], @eval $func_name(a, b, c) = ccall($ccall_name, llvmcall, NTuple{$d_sz, $d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)) @eval export $func_name + @eval @doc (@doc llvm_wmma_mma) $func_name end ################################################################################ From 9d9d5e2d2d3fa5f65cad0f936599ddc85f547157 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 1 Feb 2020 00:21:54 +0100 Subject: [PATCH 76/81] Fix path to examples in docs --- docs/src/device/wmma.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/device/wmma.md b/docs/src/device/wmma.md index 5801ecde..ab426b60 100644 --- a/docs/src/device/wmma.md +++ b/docs/src/device/wmma.md @@ -79,7 +79,7 @@ CUDAnative.llvm_wmma_store ### Example ````@eval -lines = readlines("../../../../examples/wmma/low-level.jl") +lines = readlines("../../../examples/wmma/low-level.jl") start = findfirst(x -> x == "### START", lines) + 1 stop = findfirst(x -> x == "### END", lines) - 1 example = join(lines[start:stop], '\n') @@ -164,7 +164,7 @@ frag = 2.0f0 .* frag ### Example ````@eval -lines = readlines("../../../../examples/wmma/high-level.jl") +lines = readlines("../../../examples/wmma/high-level.jl") start = findfirst(x -> x == "### START", lines) + 1 stop = findfirst(x -> x == "### END", lines) - 1 example = join(lines[start:stop], '\n') From ae19e0374cbd4432857001b234ef40d93870bbcc Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 1 Feb 2020 01:43:25 +0100 Subject: [PATCH 77/81] Move everything in WMMA submodule --- docs/src/device/wmma.md | 36 +++---- examples/wmma/high-level.jl | 12 +-- examples/wmma/low-level.jl | 10 +- src/device/cuda/wmma.jl | 181 +++++++++++++++++++----------------- test/device/wmma.jl | 65 +++++++------ 5 files changed, 157 insertions(+), 147 deletions(-) diff --git a/docs/src/device/wmma.md b/docs/src/device/wmma.md index ab426b60..b4dcfc1a 100644 --- a/docs/src/device/wmma.md +++ b/docs/src/device/wmma.md @@ -63,17 +63,17 @@ In what follows, each of these will be discussed. ### Load matrix ```@docs -CUDAnative.llvm_wmma_load +CUDAnative.WMMA.llvm_wmma_load ``` ### Perform multiply-accumulate ```@docs -CUDAnative.llvm_wmma_mma +CUDAnative.WMMA.llvm_wmma_mma ``` ### Store matrix ```@docs -CUDAnative.llvm_wmma_store +CUDAnative.WMMA.llvm_wmma_store ``` ### Example @@ -106,51 +106,51 @@ Note that, in CUDA C++, the fragment is responsible for both the storage of inte All CUDA C++ WMMA calls are function templates that take the resultant fragment as a by-reference argument. As a result, the type of this argument can be used during overload resolution to select the correct WMMA instruction to call. -In contrast, the API in Julia separates the WMMA storage ([`WMMAFragment`](@ref)) and configuration ([`WMMAConfig`](@ref)). +In contrast, the API in Julia separates the WMMA storage ([`Fragment`](@ref)) and configuration ([`Config`](@ref)). Instead of taking the resultant fragment by reference, the Julia functions just return it. This makes the dataflow clearer, but it also means that the type of that fragment cannot be used for selection of the correct WMMA instruction. Thus, there is still a limited amount of information that cannot be inferred from the argument types, but must nonetheless match for all WMMA operations, such as the overall shape of the MMA. -This is accomplished by a separate "WMMA configuration" (see [`WMMAConfig`](@ref)) that you create once, and then give as an argument to all intrinsics. +This is accomplished by a separate "WMMA configuration" (see [`Config`](@ref)) that you create once, and then give as an argument to all intrinsics. ### Fragment ```@docs -CUDAnative.WMMAFragmentLayout -CUDAnative.WMMARowMajor -CUDAnative.WMMAColMajor -CUDAnative.WMMAUnspecified -CUDAnative.WMMAFragment +CUDAnative.WMMA.FragmentLayout +CUDAnative.WMMA.RowMajor +CUDAnative.WMMA.ColMajor +CUDAnative.WMMA.Unspecified +CUDAnative.WMMA.Fragment ``` ### WMMA configuration ```@docs -CUDAnative.WMMAConfig +CUDAnative.WMMA.Config ``` ### Load matrix ```@docs -CUDAnative.wmma_load_a -CUDAnative.wmma_load_b -CUDAnative.wmma_load_c +CUDAnative.WMMA.wmma_load_a +CUDAnative.WMMA.wmma_load_b +CUDAnative.WMMA.wmma_load_c ``` ### Perform multiply-accumulate ```@docs -CUDAnative.wmma_mma +CUDAnative.WMMA.wmma_mma ``` ### Store matrix ```@docs -CUDAnative.wmma_store_d +CUDAnative.WMMA.wmma_store_d ``` ### Fill fragment ```@docs -CUDAnative.wmma_fill_c +CUDAnative.WMMA.wmma_fill_c ``` ### Element access and broadcasting -Similar to the CUDA C++ WMMA API, [`WMMAFragment`](@ref)s have an `x` member that can be used to access individual elements. +Similar to the CUDA C++ WMMA API, [`Fragment`](@ref)s have an `x` member that can be used to access individual elements. Note that, in contrast to the values returned by the LLVM intrinsics, the `x` member is flattened. For example, while the `Float16` variants of the `load_a` instrinsics return `NTuple{8, NTuple{2, VecElement{Float16}}}`, the `x` member has type `NTuple{16, Float16}`. diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl index f97bed4f..63f16c34 100644 --- a/examples/wmma/high-level.jl +++ b/examples/wmma/high-level.jl @@ -24,15 +24,15 @@ c_dev = CuArray(c) d_dev = similar(c_dev) function kernel(a_dev, b_dev, c_dev, d_dev) - conf = WMMAConfig{16, 16, 16, Float32} + conf = WMMA.Config{16, 16, 16, Float32} - a_frag = wmma_load_a(pointer(a_dev), 16, WMMAColMajor, conf) - b_frag = wmma_load_b(pointer(b_dev), 16, WMMAColMajor, conf) - c_frag = wmma_load_c(pointer(c_dev), 16, WMMAColMajor, conf) + a_frag = WMMA.load_a(pointer(a_dev), 16, WMMA.ColMajor, conf) + b_frag = WMMA.load_b(pointer(b_dev), 16, WMMA.ColMajor, conf) + c_frag = WMMA.load_c(pointer(c_dev), 16, WMMA.ColMajor, conf) - d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) + d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf) - wmma_store_d(pointer(d_dev), d_frag, 16, WMMAColMajor, conf) + WMMA.store_d(pointer(d_dev), d_frag, 16, WMMA.ColMajor, conf) return end diff --git a/examples/wmma/low-level.jl b/examples/wmma/low-level.jl index 0607e074..9eadd28d 100644 --- a/examples/wmma/low-level.jl +++ b/examples/wmma/low-level.jl @@ -27,13 +27,13 @@ d_dev = similar(c_dev) # Matrix multiply-accumulate kernel (D = A * B + C) function kernel(a_dev, b_dev, c_dev, d_dev) - a_frag = llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_dev), 16) - b_frag = llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_dev), 16) - c_frag = llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16) + a_frag = WMMA.llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_dev), 16) + b_frag = WMMA.llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_dev), 16) + c_frag = WMMA.llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16) - d_frag = llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag) + d_frag = WMMA.llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag) - llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), d_frag, 16) + WMMA.llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), d_frag, 16) return end diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 72152468..a8dda8cb 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -1,3 +1,8 @@ +export WMMA +module WMMA + +using CUDAnative: AS, DevicePtr + ################################################################################ # CONSTANTS ################################################################################ @@ -264,33 +269,33 @@ end # WMMA fragment # ------------- -export WMMAFragmentLayout, WMMARowMajor, WMMAColMajor, WMMAUnspecified +export FragmentLayout, RowMajor, ColMajor, Unspecified """ - WMMAFragmentLayout + FragmentLayout Abstract type that specifies the storage layout of a matrix. -Possible values are [`WMMARowMajor`](@ref), [`WMMAColMajor`](@ref) and [`WMMAUnspecified`](@ref). +Possible values are [`RowMajor`](@ref), [`ColMajor`](@ref) and [`Unspecified`](@ref). """ -abstract type WMMAFragmentLayout end +abstract type FragmentLayout end """ - WMMARowMajor + RowMajor Type that represents a matrix stored in row major (C style) order. """ -struct WMMARowMajor <: WMMAFragmentLayout end +struct RowMajor <: FragmentLayout end """ - WMMAColMajor + ColMajor Type that represents a matrix stored in column major (Julia style) order. """ -struct WMMAColMajor <: WMMAFragmentLayout end +struct ColMajor <: FragmentLayout end """ - WMMAUnspecified + Unspecified Type that represents a matrix stored in an unspecified order. @@ -298,27 +303,27 @@ Type that represents a matrix stored in an unspecified order. This storage format is not valid for all WMMA operations! """ -struct WMMAUnspecified <: WMMAFragmentLayout end +struct Unspecified <: FragmentLayout end -export WMMAMatrixA, WMMAMatrixB, WMMAAccumulator +export MatrixA, MatrixB, Accumulator -abstract type WMMAFragmentUse end -struct WMMAMatrixA <: WMMAFragmentUse end -struct WMMAMatrixB <: WMMAFragmentUse end -struct WMMAAccumulator <: WMMAFragmentUse end +abstract type FragmentUse end +struct MatrixA <: FragmentUse end +struct MatrixB <: FragmentUse end +struct Accumulator <: FragmentUse end -export WMMAFragment +export Fragment """ - WMMAFragment + Fragment Type that represents per-thread intermediate results of WMMA operations. You can access individual elements using the `x` member or [] operator, but beware that the exact ordering of elements is unspecified. """ -struct WMMAFragment{M, N, K, FS, T, L <: WMMAFragmentLayout, U <: WMMAFragmentUse} +struct Fragment{M, N, K, FS, T, L <: FragmentLayout, U <: FragmentUse} x::NTuple{FS, T} end @@ -327,17 +332,17 @@ end # ---------------------- for f in (:getindex, :setindex!, :firstindex, :lastindex) - @eval Base.$f(frag::WMMAFragment, args...) = $f(frag.x, args...) + @eval Base.$f(frag::Fragment, args...) = $f(frag.x, args...) end # ------------------ # WMMA configuration # ------------------ -export WMMAConfig +export Config """ - WMMAConfig{M, N, K, d_type} + Config{M, N, K, d_type} Type that contains all information for WMMA operations that cannot be inferred from the argument's types. @@ -346,15 +351,15 @@ WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\c `d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`. -All WMMA operations take a `WMMAConfig` as their final argument. +All WMMA operations take a `Config` as their final argument. # Examples ```jldoctest -julia> config = WMMAConfig{16, 16, 16, Float32} -WMMAConfig{16,16,16,Float32} +julia> config = Config{16, 16, 16, Float32} +Config{16,16,16,Float32} ``` """ -struct WMMAConfig{M, N, K, d_type} end +struct Config{M, N, K, d_type} end # --------- # Constants @@ -368,8 +373,8 @@ const map_as_ty_to_str = Dict(val => key for (key, val) in map_ptx_as_to_as_ty) # Maps layout types to string const map_layout_ty_to_str = Dict( - WMMARowMajor => "row", - WMMAColMajor => "col" + RowMajor => "row", + ColMajor => "col" ) # Maps matrix & type to number of elements (size after flattening) @@ -384,10 +389,10 @@ const map_num_elems = Dict( # Maps matrix to its use const map_matrix_to_use = Dict( - "a" => WMMAMatrixA, - "b" => WMMAMatrixB, - "c" => WMMAAccumulator, - "d" => WMMAAccumulator + "a" => MatrixA, + "b" => MatrixB, + "c" => Accumulator, + "d" => Accumulator ) # ---------------- @@ -443,51 +448,51 @@ end # WMMA load # --------- -export wmma_load_a, wmma_load_b, wmma_load_c +export load_a, load_b, load_c """ - wmma_load_a(addr, stride, layout, config) - wmma_load_b(addr, stride, layout, config) - wmma_load_c(addr, stride, layout, config) + load_a(addr, stride, layout, config) + load_b(addr, stride, layout, config) + load_c(addr, stride, layout, config) -Load the matrix `a`, `b` or `c` from the memory location indicated by `addr`, and return the resulting [`WMMAFragment`](@ref). +Load the matrix `a`, `b` or `c` from the memory location indicated by `addr`, and return the resulting [`Fragment`](@ref). # Arguments - `addr`: The address to load the matrix from. - `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. -- `layout`: The storage layout of the matrix. Possible values are [`WMMARowMajor`](@ref) and [`WMMAColMajor`](@ref). -- `config`: The WMMA configuration that should be used for loading this matrix. See [`WMMAConfig`](@ref). +- `layout`: The storage layout of the matrix. Possible values are [`RowMajor`](@ref) and [`ColMajor`](@ref). +- `config`: The WMMA configuration that should be used for loading this matrix. See [`Config`](@ref). -See also: [`WMMAFragment`](@ref), [`WMMAFragmentLayout`](@ref), [`WMMAConfig`](@ref) +See also: [`Fragment`](@ref), [`FragmentLayout`](@ref), [`Config`](@ref) !!! warning All threads in a warp **MUST** execute the load operation in lockstep, and have to use exactly the same arguments. Failure to do so will result in undefined behaviour. """ -wmma_load_a, wmma_load_b, wmma_load_c +load_a, load_b, load_c for mat in ["a", "b", "c"] - func_name = Symbol("wmma_load_$mat") + func_name = Symbol("load_$mat") @eval @generated function $func_name(addr::DevicePtr{T, AS}, stride::Number, layout::Type{L}, - config::Type{WMMAConfig{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} + config::Type{Config{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) shape = get_hl_shape(M, N, K) num_els, _, _, arr_str = get_hl_frag_info($mat, T) U = get_hl_mat_use($mat) - L_ret = ($mat == "c") ? WMMAUnspecified : L + L_ret = ($mat == "c") ? Unspecified : L # Name of the Julia wrapper wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_")) return quote x = flatten($wrapper(addr, stride)) - return WMMAFragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) + return Fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) end end end @@ -497,31 +502,31 @@ end # WMMA multiply-accumulate # ------------------------ -export wmma_mma +export mma """ - wmma_mma(a, b, c, conf) + mma(a, b, c, conf) Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. # Arguments -- `a`: The [`WMMAFragment`](@ref) corresponding to the matrix ``A``. -- `b`: The [`WMMAFragment`](@ref) corresponding to the matrix ``B``. -- `c`: The [`WMMAFragment`](@ref) corresponding to the matrix ``C``. -- `conf`: The [`WMMAConfig`](@ref) that should be used in this WMMA operation. +- `a`: The [`Fragment`](@ref) corresponding to the matrix ``A``. +- `b`: The [`Fragment`](@ref) corresponding to the matrix ``B``. +- `c`: The [`Fragment`](@ref) corresponding to the matrix ``C``. +- `conf`: The [`Config`](@ref) that should be used in this WMMA operation. !!! warning All threads in a warp **MUST** execute the `mma` operation in lockstep, and have to use exactly the same arguments. Failure to do so will result in undefined behaviour. """ -wmma_mma +mma -@generated function wmma_mma(a::WMMAFragment{M, N, K, A_SZ, A_T, A_L, WMMAMatrixA}, - b::WMMAFragment{M, N, K, B_SZ, B_T, B_L, WMMAMatrixB}, - c::WMMAFragment{M, N, K, C_SZ, C_T, WMMAUnspecified, WMMAAccumulator}, - config::Type{WMMAConfig{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} +@generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA}, + b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, + c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, + config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T) _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T) @@ -541,7 +546,7 @@ wmma_mma c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x) x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) - return WMMAFragment{$M, $N, $K, $d_num_els, $D_T, WMMAUnspecified, WMMAAccumulator}(x) + return Fragment{$M, $N, $K, $d_num_els, $D_T, Unspecified, Accumulator}(x) end end @@ -550,34 +555,34 @@ end # WMMA store # ---------- -export wmma_store_d +export store_d """ - wmma_store_d(addr, d, stride, layout, config) + store_d(addr, d, stride, layout, config) Store the result matrix `d` to the memory location indicated by `addr`. # Arguments - `addr`: The address to store the matrix to. -- `d`: The [`WMMAFragment`](@ref) corresponding to the `d` matrix. +- `d`: The [`Fragment`](@ref) corresponding to the `d` matrix. - `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. -- `layout`: The storage layout of the matrix. Possible values are [`WMMARowMajor`](@ref) and [`WMMAColMajor`](@ref). -- `config`: The WMMA configuration that should be used for storing this matrix. See [`WMMAConfig`](@ref). +- `layout`: The storage layout of the matrix. Possible values are [`RowMajor`](@ref) and [`ColMajor`](@ref). +- `config`: The WMMA configuration that should be used for storing this matrix. See [`Config`](@ref). -See also: [`WMMAFragment`](@ref), [`WMMAFragmentLayout`](@ref), [`WMMAConfig`](@ref) +See also: [`Fragment`](@ref), [`FragmentLayout`](@ref), [`Config`](@ref) !!! warning All threads in a warp **MUST** execute the `store` operation in lockstep, and have to use exactly the same arguments. Failure to do so will result in undefined behaviour. """ -wmma_store_d +store_d -@generated function wmma_store_d(addr::DevicePtr{T, AS}, - d::WMMAFragment{M, N, K, D_SZ, T, WMMAUnspecified, WMMAAccumulator}, +@generated function store_d(addr::DevicePtr{T, AS}, + d::Fragment{M, N, K, D_SZ, T, Unspecified, Accumulator}, stride::Number, layout::Type{L}, - config::Type{WMMAConfig{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} + config::Type{Config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) @@ -599,23 +604,23 @@ end # WMMA fill fragment # ------------------ -export wmma_fill_c +export fill_c """ - wmma_fill_c(value, config) + fill_c(value, config) -Return a [`WMMAFragment`](@ref) filled with the value `value`. +Return a [`Fragment`](@ref) filled with the value `value`. This operation is useful if you want to implement a matrix multiplication (and thus want to set ``C = O``). # Arguments - `value`: The value used to fill the fragment. Can be a `Float16` or `Float32`. -- `config`: The WMMA configuration that should be used for this WMMA operation. See [`WMMAConfig`](@ref). +- `config`: The WMMA configuration that should be used for this WMMA operation. See [`Config`](@ref). """ -wmma_fill_c +fill_c -@generated function wmma_fill_c(value::T, - config::Type{WMMAConfig{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} +@generated function fill_c(value::T, + config::Type{Config{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} # We can't use closures in @generated functions, so we'll have to do it this way instead of # ntuple(i -> val, $num_els) @@ -625,7 +630,7 @@ wmma_fill_c expr = :(tuple($(args...))) return quote - return WMMAFragment{$M, $N, $K, $num_els, $T, WMMAUnspecified, WMMAAccumulator}($expr) + return Fragment{$M, $N, $K, $num_els, $T, Unspecified, Accumulator}($expr) end end @@ -637,33 +642,33 @@ end # https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl -# Custom broadcast style for WMMAFragments -struct WMMAFragmentBroadcastStyle <: Broadcast.BroadcastStyle end +# Custom broadcast style for Fragments +struct FragmentBroadcastStyle <: Broadcast.BroadcastStyle end -# Use this broadcasting style for WMMAFragments -Base.BroadcastStyle(::Type{<:WMMAFragment}) = WMMAFragmentBroadcastStyle() +# Use this broadcasting style for Fragments +Base.BroadcastStyle(::Type{<:Fragment}) = FragmentBroadcastStyle() # Broadcast style precedence rules -# If we broadcast a fragment with a scalar, we want the WMMAFragment style to take precedence -Base.BroadcastStyle(s::WMMAFragmentBroadcastStyle, t::Broadcast.DefaultArrayStyle{0}) = s +# If we broadcast a fragment with a scalar, we want the Fragment style to take precedence +Base.BroadcastStyle(s::FragmentBroadcastStyle, t::Broadcast.DefaultArrayStyle{0}) = s # We don't want to convert fragments before broadcasting -Base.broadcastable(frag::WMMAFragment) = frag +Base.broadcastable(frag::Fragment) = frag # Needed for broadcast machinery -Base.axes(frag::WMMAFragment) = axes(frag.x) +Base.axes(frag::Fragment) = axes(frag.x) # Helper functions to get element at specified index -@inline get_index(x, i) = x # scalar -@inline get_index(frag::WMMAFragment, i) = frag[i] # WMMAFragment +@inline get_index(x, i) = x # scalar +@inline get_index(frag::Fragment, i) = frag[i] # Fragment # Helper functions to get first fragment in broadcast call @inline find_first_fragment(args::Tuple) = find_first_fragment(args[1], Base.tail(args)) -@inline find_first_fragment(a::WMMAFragment, tail) = a +@inline find_first_fragment(a::Fragment, tail) = a @inline find_first_fragment(::Any, tail) = find_first_fragment(tail) -# Custom broadcast implementation that returns a WMMAFragment -@inline function Base.copy(bc::Broadcast.Broadcasted{WMMAFragmentBroadcastStyle}) +# Custom broadcast implementation that returns a Fragment +@inline function Base.copy(bc::Broadcast.Broadcasted{FragmentBroadcastStyle}) dim = Broadcast.combine_axes(bc.args...) if length(dim) != 1 @@ -677,3 +682,5 @@ Base.axes(frag::WMMAFragment) = axes(frag.x) frag_ty = typeof(find_first_fragment(bc.args)) return frag_ty(tuple) end + +end diff --git a/test/device/wmma.jl b/test/device/wmma.jl index 9193419c..412fcc69 100644 --- a/test/device/wmma.jl +++ b/test/device/wmma.jl @@ -1,6 +1,9 @@ # Need https://github.com/JuliaLang/julia/pull/33970 # and https://github.com/JuliaLang/julia/pull/34043 if VERSION >= v"1.4.0-DEV.666" && CUDAnative.current_capability() >= v"7.0" + +using CUDAnative.WMMA + @testset "WMMA" begin ################################################################################ @@ -152,23 +155,23 @@ if VERSION >= v"1.4.0-DEV.666" && CUDAnative.current_capability() >= v"7.0" @testset "Flattening/unflattening" begin @testset "Flattening" begin - @test CUDAnative.flatten(5) == (5,) - @test CUDAnative.flatten(5.0) == (5.0,) - @test CUDAnative.flatten(VecElement{Float16}(5)) == (Float16(5),) - @test CUDAnative.flatten(ntuple(i -> i, 8)) == ntuple(i -> i, 8) - @test CUDAnative.flatten(ntuple(i -> VecElement{Float16}(i), 8)) == ntuple(i -> Float16(i), 8) - @test CUDAnative.flatten(ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8)) == ntuple(i -> i, 2 * 8) - @test CUDAnative.flatten(ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)) == ntuple(i -> Float16(i), 2 * 8) + @test CUDAnative.WMMA.flatten(5) == (5,) + @test CUDAnative.WMMA.flatten(5.0) == (5.0,) + @test CUDAnative.WMMA.flatten(VecElement{Float16}(5)) == (Float16(5),) + @test CUDAnative.WMMA.flatten(ntuple(i -> i, 8)) == ntuple(i -> i, 8) + @test CUDAnative.WMMA.flatten(ntuple(i -> VecElement{Float16}(i), 8)) == ntuple(i -> Float16(i), 8) + @test CUDAnative.WMMA.flatten(ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8)) == ntuple(i -> i, 2 * 8) + @test CUDAnative.WMMA.flatten(ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)) == ntuple(i -> Float16(i), 2 * 8) end @testset "Unflattening" begin - @test CUDAnative.unflatten(Int64, (5,)) == 5 - @test CUDAnative.unflatten(Float64, (5.0,)) == 5.0 - @test CUDAnative.unflatten(VecElement{Float16}, (Float16(5),)) == VecElement{Float16}(5) - @test CUDAnative.unflatten(NTuple{8, Int64}, ntuple(i -> i, 8)) == ntuple(i -> i, 8) - @test CUDAnative.unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) - @test CUDAnative.unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8) - @test CUDAnative.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8) + @test CUDAnative.WMMA.unflatten(Int64, (5,)) == 5 + @test CUDAnative.WMMA.unflatten(Float64, (5.0,)) == 5.0 + @test CUDAnative.WMMA.unflatten(VecElement{Float16}, (Float16(5),)) == VecElement{Float16}(5) + @test CUDAnative.WMMA.unflatten(NTuple{8, Int64}, ntuple(i -> i, 8)) == ntuple(i -> i, 8) + @test CUDAnative.WMMA.unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) + @test CUDAnative.WMMA.unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8) + @test CUDAnative.WMMA.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8) end end @@ -176,18 +179,18 @@ if VERSION >= v"1.4.0-DEV.666" && CUDAnative.current_capability() >= v"7.0" @testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5], ty = [Float16, Float32] - @test ty(5) .* WMMAFragment{16, 16, 16, sz, ty, WMMARowMajor, WMMAMatrixA}(ntuple(i -> ty(i), sz)) == WMMAFragment{16, 16, 16, sz, ty, WMMARowMajor, WMMAMatrixA}(ntuple(i -> ty(5 * i), sz)) - @test ty(5) .+ WMMAFragment{16, 16, 16, sz, ty, WMMARowMajor, WMMAMatrixA}(ntuple(i -> ty(i), sz)) == WMMAFragment{16, 16, 16, sz, ty, WMMARowMajor, WMMAMatrixA}(ntuple(i -> ty(5 + i), sz)) + @test ty(5) .* Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 * i), sz)) + @test ty(5) .+ Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 + i), sz)) end ################################################################################ @testset "CUDA C-style API" begin - @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [WMMAColMajor, WMMARowMajor], - b_layout in [WMMAColMajor, WMMARowMajor], - c_layout in [WMMAColMajor, WMMARowMajor], - d_layout in [WMMAColMajor, WMMARowMajor], + @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [ColMajor, RowMajor], + b_layout in [ColMajor, RowMajor], + c_layout in [ColMajor, RowMajor], + d_layout in [ColMajor, RowMajor], c_type in [Float16, Float32], d_type in [Float16, Float32], do_mac in [true, false] @@ -206,23 +209,23 @@ if VERSION >= v"1.4.0-DEV.666" && CUDAnative.current_capability() >= v"7.0" beta = rand(c_type) @eval function kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) - conf = WMMAConfig{16, 16, 16, $d_type} + conf = Config{16, 16, 16, $d_type} - a_frag = wmma_load_a(pointer(a_dev), 16, $a_layout, conf) - b_frag = wmma_load_b(pointer(b_dev), 16, $b_layout, conf) + a_frag = load_a(pointer(a_dev), 16, $a_layout, conf) + b_frag = load_b(pointer(b_dev), 16, $b_layout, conf) if $do_mac - c_frag = wmma_load_c(pointer(c_dev), 16, $c_layout, conf) + c_frag = load_c(pointer(c_dev), 16, $c_layout, conf) else - c_frag = wmma_fill_c($c_type(0), conf) + c_frag = fill_c($c_type(0), conf) end a_frag = alpha .* a_frag c_frag = beta .* c_frag - d_frag = wmma_mma(a_frag, b_frag, c_frag, conf) + d_frag = mma(a_frag, b_frag, c_frag, conf) - wmma_store_d(pointer(d_dev), d_frag, 16, $d_layout, conf) + store_d(pointer(d_dev), d_frag, 16, $d_layout, conf) return end @@ -230,10 +233,10 @@ if VERSION >= v"1.4.0-DEV.666" && CUDAnative.current_capability() >= v"7.0" @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) d = Array(d_dev) - new_a = (a_layout == WMMAColMajor) ? a : transpose(a) - new_b = (b_layout == WMMAColMajor) ? b : transpose(b) - new_c = (c_layout == WMMAColMajor) ? c : transpose(c) - new_d = (d_layout == WMMAColMajor) ? d : transpose(d) + new_a = (a_layout == ColMajor) ? a : transpose(a) + new_b = (b_layout == ColMajor) ? b : transpose(b) + new_c = (c_layout == ColMajor) ? c : transpose(c) + new_d = (d_layout == ColMajor) ? d : transpose(d) if do_mac @test all(isapprox.(alpha * new_a * new_b + beta * new_c, new_d; rtol=sqrt(eps(Float16)))) From 3c95dcf6cda68db0144029518c21862a413170d9 Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 1 Feb 2020 01:59:43 +0100 Subject: [PATCH 78/81] Fix docs --- docs/src/device/wmma.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/src/device/wmma.md b/docs/src/device/wmma.md index b4dcfc1a..bfc0041b 100644 --- a/docs/src/device/wmma.md +++ b/docs/src/device/wmma.md @@ -106,11 +106,11 @@ Note that, in CUDA C++, the fragment is responsible for both the storage of inte All CUDA C++ WMMA calls are function templates that take the resultant fragment as a by-reference argument. As a result, the type of this argument can be used during overload resolution to select the correct WMMA instruction to call. -In contrast, the API in Julia separates the WMMA storage ([`Fragment`](@ref)) and configuration ([`Config`](@ref)). +In contrast, the API in Julia separates the WMMA storage ([`WMMA.Fragment`](@ref)) and configuration ([`WMMA.Config`](@ref)). Instead of taking the resultant fragment by reference, the Julia functions just return it. This makes the dataflow clearer, but it also means that the type of that fragment cannot be used for selection of the correct WMMA instruction. Thus, there is still a limited amount of information that cannot be inferred from the argument types, but must nonetheless match for all WMMA operations, such as the overall shape of the MMA. -This is accomplished by a separate "WMMA configuration" (see [`Config`](@ref)) that you create once, and then give as an argument to all intrinsics. +This is accomplished by a separate "WMMA configuration" (see [`WMMA.Config`](@ref)) that you create once, and then give as an argument to all intrinsics. ### Fragment ```@docs @@ -128,29 +128,29 @@ CUDAnative.WMMA.Config ### Load matrix ```@docs -CUDAnative.WMMA.wmma_load_a -CUDAnative.WMMA.wmma_load_b -CUDAnative.WMMA.wmma_load_c +CUDAnative.WMMA.load_a +CUDAnative.WMMA.load_b +CUDAnative.WMMA.load_c ``` ### Perform multiply-accumulate ```@docs -CUDAnative.WMMA.wmma_mma +CUDAnative.WMMA.mma ``` ### Store matrix ```@docs -CUDAnative.WMMA.wmma_store_d +CUDAnative.WMMA.store_d ``` ### Fill fragment ```@docs -CUDAnative.WMMA.wmma_fill_c +CUDAnative.WMMA.fill_c ``` ### Element access and broadcasting -Similar to the CUDA C++ WMMA API, [`Fragment`](@ref)s have an `x` member that can be used to access individual elements. +Similar to the CUDA C++ WMMA API, [`WMMA.Fragment`](@ref)s have an `x` member that can be used to access individual elements. Note that, in contrast to the values returned by the LLVM intrinsics, the `x` member is flattened. For example, while the `Float16` variants of the `load_a` instrinsics return `NTuple{8, NTuple{2, VecElement{Float16}}}`, the `x` member has type `NTuple{16, Float16}`. From 94ab44225d98fd7d8382267aece11455e7cad51c Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 1 Feb 2020 10:59:43 +0100 Subject: [PATCH 79/81] Fix indenting --- src/device/cuda/wmma.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index a8dda8cb..8154bc34 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -524,9 +524,9 @@ Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. mma @generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA}, - b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, - c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, - config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} + b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, + c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, + config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T) _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T) @@ -579,10 +579,10 @@ See also: [`Fragment`](@ref), [`FragmentLayout`](@ref), [`Config`](@ref) store_d @generated function store_d(addr::DevicePtr{T, AS}, - d::Fragment{M, N, K, D_SZ, T, Unspecified, Accumulator}, - stride::Number, - layout::Type{L}, - config::Type{Config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} + d::Fragment{M, N, K, D_SZ, T, Unspecified, Accumulator}, + stride::Number, + layout::Type{L}, + config::Type{Config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} as_str = get_hl_as_info(AS) layout = get_hl_layout(L) @@ -620,7 +620,7 @@ This operation is useful if you want to implement a matrix multiplication (and t fill_c @generated function fill_c(value::T, - config::Type{Config{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} + config::Type{Config{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} # We can't use closures in @generated functions, so we'll have to do it this way instead of # ntuple(i -> val, $num_els) From 3579373ff6d19af5acd2010e0062fa712a4dbc4a Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 1 Feb 2020 11:13:28 +0100 Subject: [PATCH 80/81] Add broadcasting to example --- examples/wmma/high-level.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl index 63f16c34..932c0420 100644 --- a/examples/wmma/high-level.jl +++ b/examples/wmma/high-level.jl @@ -30,6 +30,8 @@ function kernel(a_dev, b_dev, c_dev, d_dev) b_frag = WMMA.load_b(pointer(b_dev), 16, WMMA.ColMajor, conf) c_frag = WMMA.load_c(pointer(c_dev), 16, WMMA.ColMajor, conf) + c_frag = 0.5f0 .* c_frag + d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf) WMMA.store_d(pointer(d_dev), d_frag, 16, WMMA.ColMajor, conf) @@ -40,5 +42,5 @@ end @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) d = Array(d_dev) -@test all(isapprox.(a * b + c, d; rtol=0.01)) +@test all(isapprox.(a * b + 0.5 * c, d; rtol=0.01)) ### END From 93c77bc97731a974841821fa5a74d207688b09ac Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Sat, 1 Feb 2020 11:36:29 +0100 Subject: [PATCH 81/81] Small doc fix --- src/device/cuda/wmma.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl index 8154bc34..8ec95288 100644 --- a/src/device/cuda/wmma.jl +++ b/src/device/cuda/wmma.jl @@ -321,7 +321,7 @@ export Fragment Type that represents per-thread intermediate results of WMMA operations. -You can access individual elements using the `x` member or [] operator, but beware that the exact ordering of elements is unspecified. +You can access individual elements using the `x` member or `[]` operator, but beware that the exact ordering of elements is unspecified. """ struct Fragment{M, N, K, FS, T, L <: FragmentLayout, U <: FragmentUse} x::NTuple{FS, T}