diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index f6f6b0a1..d2c66151 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -26,6 +26,9 @@ julia:nightly:
     - .test
   tags:
     - nvidia
+    - sm_75
+  variables:
+    CI_THOROUGH: 'true'
   allow_failure: true
diff --git a/docs/make.jl b/docs/make.jl
index 36c0427c..33b428af 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -25,6 +25,7 @@ function main()
             ],
             "Device" => [
                 "device/cuda.md",
+                "device/wmma.md",
                 "device/array.md"
             ]
         ]
diff --git a/docs/src/device/wmma.md b/docs/src/device/wmma.md
new file mode 100644
index 00000000..bfc0041b
--- /dev/null
+++ b/docs/src/device/wmma.md
@@ -0,0 +1,178 @@
+# WMMA
+
+This section details CUDAnative's interface to CUDA's warp matrix multiply-accumulate (WMMA) operations.
+This interface enables programmatic access to Tensor Cores, a hardware feature introduced with the Volta architecture that performs mixed-precision matrix multiply-accumulate operations.
+
+Access to WMMA using CUDAnative is available at two levels: low-level wrappers around the LLVM intrinsics, and a higher-level API similar to that of CUDA C.
+
+Note that to use the WMMA intrinsics, you need a sufficiently recent version of Julia: `v1.4.0-DEV.666` or later.
+You can check this by running the following in the REPL:
+```julia
+VERSION >= v"1.4.0-DEV.666"
+```
+
+!!! note
+
+    If you're running into any of the following errors while using the WMMA interfaces:
+    ```
+    LLVM error: Do not know how to split the result of this operator!
+    ```
+    or
+    ```
+    CUDA error: a PTX JIT compilation failed (code 218, ERROR_INVALID_PTX)
+    ptxas application ptx input, line ; error : .aligned modifier required for instruction ''
+    ```
+    then make sure you are running Julia v1.4.0-DEV.666 or later!
+
+## Terminology
+
+The WMMA operations perform a matrix multiply-accumulate.
+More concretely, they calculate ``D = A \cdot B + C``, where ``A`` is an ``M \times K`` matrix, ``B`` is a ``K \times N`` matrix, and ``C`` and ``D`` are ``M \times N`` matrices.
+
+Note that not all values of ``M``, ``N`` and ``K`` are allowed.
+The tuple ``(M, N, K)`` is often called the "shape" of the multiply-accumulate operation.
+
+The multiply-accumulate consists of the following steps:
+- Load the matrices ``A``, ``B`` and ``C`` from memory to registers using a WMMA load operation.
+- Perform the matrix multiply-accumulate of ``A``, ``B`` and ``C`` to obtain ``D`` using a WMMA MMA operation. ``D`` is stored in hardware registers after this step.
+- Store the result ``D`` back to memory using a WMMA store operation.
+
+Note that WMMA is a warp-wide operation, which means that all threads in a warp must cooperate, and execute the WMMA operations in lockstep.
+Failure to do so will result in undefined behaviour.
+
+Each thread in a warp will hold a part of the matrix in its registers.
+In WMMA parlance, this part is referred to as a "fragment".
+Note that the exact mapping between matrix elements and fragments is unspecified, and subject to change in future versions.
+
+Finally, it is important to note that the resultant ``D`` matrix can be used as a ``C`` matrix for a subsequent multiply-accumulate.
+This is useful if one needs to calculate a sum of the form ``\sum_{i=0}^{n} A_i B_i``, where ``A_i`` and ``B_i`` are matrices of the correct dimensions.
+
+## LLVM Intrinsics
+
+The LLVM intrinsics are accessible by using the one-to-one Julia wrappers.
+The return type of each wrapper is the Julia type that corresponds closest to the return type of the LLVM intrinsic.
+For example, LLVM's `[8 x <2 x half>]` becomes `NTuple{8, NTuple{2, VecElement{Float16}}}` in Julia.
+In essence, these wrappers return the SSA values returned by the LLVM intrinsic.
+Currently, all intrinsics that are available in LLVM 6, PTX 6.0 and SM 70 are implemented.
+
+These LLVM intrinsics are then lowered to the correct PTX instructions by the LLVM NVPTX backend.
+For more information about the PTX instructions, please refer to the [PTX Instruction Set Architecture Manual](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions).
+
+The LLVM intrinsics are subdivided into three categories: load, store and multiply-accumulate.
+In what follows, each of these will be discussed.
+
+### Load matrix
+```@docs
+CUDAnative.WMMA.llvm_wmma_load
+```
+
+### Perform multiply-accumulate
+```@docs
+CUDAnative.WMMA.llvm_wmma_mma
+```
+
+### Store matrix
+```@docs
+CUDAnative.WMMA.llvm_wmma_store
+```
+
+### Example
+
+````@eval
+lines = readlines("../../../examples/wmma/low-level.jl")
+start = findfirst(x -> x == "### START", lines) + 1
+stop = findfirst(x -> x == "### END", lines) - 1
+example = join(lines[start:stop], '\n')
+
+using Markdown
+Markdown.parse("""
+```julia
+$(example)
+```
+""")
+````
+
+## CUDA C-like API
+
+The main difference between the CUDA C-like API and the lower-level wrappers is that the former enforces several constraints when working with WMMA.
+For example, it ensures that the ``A`` fragment argument to the MMA instruction was obtained by a `load_a` call, and not by a `load_b` or `load_c`.
+Additionally, it makes sure that the data type and storage layout of the load/store operations and the MMA operation match.
+
+The CUDA C-like API heavily uses Julia's dispatch mechanism.
+As such, the method names are much shorter than the LLVM intrinsic wrappers, as most information is baked into the type of the arguments rather than the method name.
+
+Note that, in CUDA C++, the fragment is responsible for both the storage of intermediate results and the WMMA configuration.
+All CUDA C++ WMMA calls are function templates that take the resultant fragment as a by-reference argument.
+As a result, the type of this argument can be used during overload resolution to select the correct WMMA instruction to call.
+
+In contrast, the API in Julia separates the WMMA storage ([`WMMA.Fragment`](@ref)) and configuration ([`WMMA.Config`](@ref)).
+Instead of taking the resultant fragment by reference, the Julia functions just return it.
+This makes the dataflow clearer, but it also means that the type of that fragment cannot be used for selection of the correct WMMA instruction.
+Thus, there is still a limited amount of information that cannot be inferred from the argument types, but must nonetheless match for all WMMA operations, such as the overall shape of the MMA.
+This is accomplished by a separate "WMMA configuration" (see [`WMMA.Config`](@ref)) that you create once, and then give as an argument to all WMMA operations.
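+
+As a minimal sketch of the resulting dataflow (assuming `a_dev`, `b_dev`, `c_dev` and `d_dev` are 16 × 16 device arrays of the element types used in the full example below), the configuration is created once and passed to every operation, and each operation returns its fragment instead of mutating an argument:
+
+```julia
+conf = WMMA.Config{16, 16, 16, Float32}
+
+a_frag = WMMA.load_a(pointer(a_dev), 16, WMMA.ColMajor, conf)
+b_frag = WMMA.load_b(pointer(b_dev), 16, WMMA.ColMajor, conf)
+c_frag = WMMA.load_c(pointer(c_dev), 16, WMMA.ColMajor, conf)
+
+d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf)
+
+WMMA.store_d(pointer(d_dev), d_frag, 16, WMMA.ColMajor, conf)
+```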
+
+### Fragment
+```@docs
+CUDAnative.WMMA.FragmentLayout
+CUDAnative.WMMA.RowMajor
+CUDAnative.WMMA.ColMajor
+CUDAnative.WMMA.Unspecified
+CUDAnative.WMMA.Fragment
+```
+
+### WMMA configuration
+```@docs
+CUDAnative.WMMA.Config
+```
+
+### Load matrix
+```@docs
+CUDAnative.WMMA.load_a
+CUDAnative.WMMA.load_b
+CUDAnative.WMMA.load_c
+```
+
+### Perform multiply-accumulate
+```@docs
+CUDAnative.WMMA.mma
+```
+
+### Store matrix
+```@docs
+CUDAnative.WMMA.store_d
+```
+
+### Fill fragment
+```@docs
+CUDAnative.WMMA.fill_c
+```
+
+### Element access and broadcasting
+
+Similar to the CUDA C++ WMMA API, [`WMMA.Fragment`](@ref)s have an `x` member that can be used to access individual elements.
+Note that, in contrast to the values returned by the LLVM intrinsics, the `x` member is flattened.
+For example, while the `Float16` variants of the `load_a` intrinsics return `NTuple{8, NTuple{2, VecElement{Float16}}}`, the `x` member has type `NTuple{16, Float16}`.
+
+Typically, you will only need to access the `x` member to perform elementwise operations.
+This can be more succinctly expressed using Julia's broadcast mechanism.
+For example, to double each element in a fragment, you can simply use:
+```julia
+frag = 2.0f0 .* frag
+```
+
+### Example
+
+````@eval
+lines = readlines("../../../examples/wmma/high-level.jl")
+start = findfirst(x -> x == "### START", lines) + 1
+stop = findfirst(x -> x == "### END", lines) - 1
+example = join(lines[start:stop], '\n')
+
+using Markdown
+Markdown.parse("""
+```julia
+$(example)
+```
+""")
+````
diff --git a/examples/wmma/high-level.jl b/examples/wmma/high-level.jl
new file mode 100644
index 00000000..932c0420
--- /dev/null
+++ b/examples/wmma/high-level.jl
@@ -0,0 +1,46 @@
+# Need https://github.com/JuliaLang/julia/pull/33970
+# and https://github.com/JuliaLang/julia/pull/34043
+if VERSION < v"1.4.0-DEV.666"
+    exit()
+end
+
+using CUDAnative
+if CUDAnative.current_capability() < v"7.0"
+    exit()
+end
+
+### START
+using CUDAnative
+using CuArrays
+using Test
+
+a = rand(Float16, (16, 16))
+b = rand(Float16, (16, 16))
+c = rand(Float32, (16, 16))
+
+a_dev = CuArray(a)
+b_dev = CuArray(b)
+c_dev = CuArray(c)
+d_dev = similar(c_dev)
+
+function kernel(a_dev, b_dev, c_dev, d_dev)
+    conf = WMMA.Config{16, 16, 16, Float32}
+
+    a_frag = WMMA.load_a(pointer(a_dev), 16, WMMA.ColMajor, conf)
+    b_frag = WMMA.load_b(pointer(b_dev), 16, WMMA.ColMajor, conf)
+    c_frag = WMMA.load_c(pointer(c_dev), 16, WMMA.ColMajor, conf)
+
+    c_frag = 0.5f0 .* c_frag
+
+    d_frag = WMMA.mma(a_frag, b_frag, c_frag, conf)
+
+    WMMA.store_d(pointer(d_dev), d_frag, 16, WMMA.ColMajor, conf)
+
+    return
+end
+
+@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev)
+d = Array(d_dev)
+
+@test all(isapprox.(a * b + 0.5 * c, d; rtol=0.01))
+### END
diff --git a/examples/wmma/low-level.jl b/examples/wmma/low-level.jl
new file mode 100644
index 00000000..9eadd28d
--- /dev/null
+++ b/examples/wmma/low-level.jl
@@ -0,0 +1,42 @@
+# Need https://github.com/JuliaLang/julia/pull/33970
+# and https://github.com/JuliaLang/julia/pull/34043
+if VERSION < v"1.4.0-DEV.666"
+    exit()
+end
+
+using CUDAnative
+if CUDAnative.current_capability() < v"7.0"
+    exit()
+end
+
+### START
+using CUDAnative
+using CuArrays
+using Test
+
+# Generate input matrices
+a = rand(Float16, (16, 16))
+a_dev = CuArray(a)
+b = rand(Float16, (16, 16))
+b_dev = CuArray(b)
+c = rand(Float32, (16, 16))
+c_dev = CuArray(c)
+
+# Allocate space for result
+d_dev = similar(c_dev)
+
+# Matrix multiply-accumulate kernel (D = A * B + C)
+function
kernel(a_dev, b_dev, c_dev, d_dev) + a_frag = WMMA.llvm_wmma_load_a_col_m16n16k16_stride_f16(pointer(a_dev), 16) + b_frag = WMMA.llvm_wmma_load_b_col_m16n16k16_stride_f16(pointer(b_dev), 16) + c_frag = WMMA.llvm_wmma_load_c_col_m16n16k16_stride_f32(pointer(c_dev), 16) + + d_frag = WMMA.llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag) + + WMMA.llvm_wmma_store_d_col_m16n16k16_stride_f32(pointer(d_dev), d_frag, 16) + return +end + +@cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) +@test all(isapprox.(a * b + c, Array(d_dev); rtol=0.01)) +### END diff --git a/src/device/cuda.jl b/src/device/cuda.jl index 13833fd6..4e2772ad 100644 --- a/src/device/cuda.jl +++ b/src/device/cuda.jl @@ -11,6 +11,7 @@ include("cuda/assertion.jl") include("cuda/memory_dynamic.jl") include("cuda/atomics.jl") include("cuda/misc.jl") +include("cuda/wmma.jl") # functionality from libdevice # diff --git a/src/device/cuda/memory_shared.jl b/src/device/cuda/memory_shared.jl index 37113db1..23f760f0 100644 --- a/src/device/cuda/memory_shared.jl +++ b/src/device/cuda/memory_shared.jl @@ -83,8 +83,9 @@ end initializer!(gv, null(gv_typ)) end # by requesting a larger-than-datatype alignment, we might be able to vectorize. - # we pick 16 bytes since this is the largest transaction size as supported by PTX. - alignment!(gv, Base.max(16, datatype_align(T))) + # we pick 32 bytes here, since WMMA instructions require 32-byte alignment. + # TODO: Make the alignment configurable + alignment!(gv, Base.max(32, datatype_align(T))) # generate IR Builder(JuliaContext()) do builder diff --git a/src/device/cuda/wmma.jl b/src/device/cuda/wmma.jl new file mode 100644 index 00000000..8ec95288 --- /dev/null +++ b/src/device/cuda/wmma.jl @@ -0,0 +1,686 @@ +export WMMA +module WMMA + +using CUDAnative: AS, DevicePtr + +################################################################################ +# CONSTANTS +################################################################################ + +# Maps PTX types to Julia array types +const map_ptx_to_jl_array = Dict( + "f16" => Float16, + "f32" => Float32 + ) + +# Maps PTX types to Julia fragment types +const map_ptx_to_jl_frag = Dict( + "f16" => NTuple{2, VecElement{Float16}}, + "f32" => Float32 + ) + +# Maps matrix & PTX types to fragment sizes +const map_frag_sizes = Dict( + "a.f16" => 8, + "b.f16" => 8, + "c.f16" => 4, + "c.f32" => 8, + "d.f16" => 4, + "d.f32" => 8 + ) + +# Maps PTX AS to CUDAnative.AS +const map_ptx_as_to_as_ty = Dict( + "" => AS.Generic, + "shared" => AS.Shared, + "global" => AS.Global + ) + +################################################################################ +# HELPER FUNCTIONS +################################################################################ + +# Returns (Julia array type, Julia fragment type, fragment size) +get_frag_info(matrix, ptx_el_type) = ( + map_ptx_to_jl_array[ptx_el_type], + map_ptx_to_jl_frag[ptx_el_type], + map_frag_sizes["$matrix.$ptx_el_type"] + ) + +get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space]) + +################################################################################ +# LOW LEVEL API +################################################################################ + +# ----------- +# Matrix load +# ----------- + +@doc """ + llvm_wmma_load_{matrix}_{layout}_{shape}_{addr_space}_stride_{elem_type}(src_addr, stride) + +Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}`. 
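+For example, `llvm_wmma_load_a_col_m16n16k16_global_stride_f16` loads an ``A`` fragment of `f16` elements, stored in column-major order in global memory.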
+
+# Arguments
+- `src_addr`: The memory address to load from.
+- `stride`: The leading dimension of the matrix, in number of elements.
+
+# Placeholders
+- `{matrix}`: The matrix to load. Can be `a`, `b` or `c`.
+- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
+- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`.
+- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`.
+- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). Note that `f32` is only valid for the matrix ``C``.
+"""
+llvm_wmma_load() = error("Cannot call llvm_wmma_load without values for placeholders!")
+export llvm_wmma_load
+
+for mat in ["a", "b", "c"],
+    layout in ["col", "row"],
+    shape in ["m16n16k16"],
+    addr_space in ["", "shared", "global"],
+    stride in ["stride"],
+    elem_type in ["f16", "f32"]
+
+    # TODO: Non-stride versions?
+
+    # Float32 is only supported for C
+    if (elem_type == "f32") && (mat != "c")
+        continue
+    end
+
+    addr_space_int = get_addrspace_info(addr_space)
+
+    # Name of the Julia wrapper function
+    func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", mat, layout, shape, addr_space, stride, elem_type]), "_"))
+
+    # Name of the LLVM intrinsic
+    llvm_intr = "llvm.nvvm.wmma.$shape.load.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8"
+
+    # Determine types + size for this (matrix, elem_type) combination
+    arr_ty, frag_ty, sz = get_frag_info(mat, elem_type)
+
+    ccall_name = "extern $llvm_intr"
+
+    @eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $frag_ty}, (Ref{$arr_ty}, Int32), src_addr, stride)
+    @eval export $func_name
+    @eval @doc (@doc llvm_wmma_load) $func_name
+end
+
+# ------------
+# Matrix store
+# ------------
+
+@doc """
+    llvm_wmma_store_d_{layout}_{shape}_{addr_space}_stride_{elem_type}(dst_addr, data, stride)
+
+Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape}.{addr_space}.stride.{elem_type}`.
+
+# Arguments
+- `dst_addr`: The memory address to store to.
+- `data`: The ``D`` fragment to store.
+- `stride`: The leading dimension of the matrix, in number of elements.
+
+# Placeholders
+- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
+- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`.
+- `{addr_space}`: The address space of `dst_addr`. Can be empty (generic addressing), `shared` or `global`.
+- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point).
+"""
+llvm_wmma_store() = error("Cannot call llvm_wmma_store without values for placeholders!")
+export llvm_wmma_store
+
+for mat in ["d"],
+    layout in ["col", "row"],
+    shape in ["m16n16k16"],
+    addr_space in ["", "shared", "global"],
+    stride in ["stride"],
+    elem_type in ["f16", "f32"]
+
+    # TODO: Non-stride versions?
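+
+    # As with the load wrappers, each iteration of this loop generates one method;
+    # for example, (layout = "col", addr_space = "global", elem_type = "f32") yields
+    # `llvm_wmma_store_d_col_m16n16k16_global_stride_f32`, which wraps the intrinsic
+    # `llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1i8`.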
+ + addr_space_int = get_addrspace_info(addr_space) + + # Name of the Julia wrapper function + func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", mat, layout, shape, addr_space, stride, elem_type]), "_")) + + # Name of the LLVM intrinsic + llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8" + + # Determine types + size for this (matrix, elem_type) combination + arr_ty, frag_ty, sz = get_frag_info(mat, elem_type) + + ccall_name = "extern $llvm_intr" + frag_types = ntuple(i -> frag_ty, sz) + frag_vars = ntuple(i -> :(data[$i]), sz) + + @eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, (Ref{$arr_ty}, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride) + @eval export $func_name + @eval @doc (@doc llvm_wmma_store) $func_name +end + +# -------------------------- +# Matrix multiply accumulate +# -------------------------- + +@doc """ + llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) + +Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`. + +# Arguments +- `a`: The WMMA fragment corresponding to the matrix ``A``. +- `b`: The WMMA fragment corresponding to the matrix ``B``. +- `c`: The WMMA fragment corresponding to the matrix ``C``. + +# Placeholders +- `{a_layout}`: The storage layout for matrix ``A``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. +- `{b_layout}`: The storage layout for matrix ``B``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation. +- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`. +- `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). +- `{c_elem_type}`: The type of each element in the ``C`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). + +!!! warning + + Remember that the shape, type and layout of all operations (be it MMA, load or store) **MUST** match. + Otherwise, the behaviour is undefined! 
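+
+As a concrete example of a fully-instantiated wrapper, `llvm_wmma_mma_col_col_m16n16k16_f32_f32(a_frag, b_frag, c_frag)` takes ``A`` and ``B`` fragments loaded with column-major layout together with a `Float32` ``C`` fragment, and returns the `Float32` ``D`` fragment.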
+""" +llvm_wmma_mma() = error("Cannot call llvm_wmma_mma without values for placeholders!") +export llvm_wmma_mma + +for a_layout in ["col", "row"], + b_layout in ["col", "row"], + shape in ["m16n16k16"], + d_elem_type in ["f16", "f32"], + c_elem_type in ["f16", "f32"], + b_elem_type in ["f16"], + a_elem_type in ["f16"] + + # Name of the Julia wrapper function + func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_")) + + # Name of the LLVM intrinsic + llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type" + + # Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D + a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type) + b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type) + c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type) + d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type) + + ccall_name = "extern $llvm_intr" + + a_types = ntuple(i -> a_frag_ty, a_sz) + b_types = ntuple(i -> b_frag_ty, b_sz) + c_types = ntuple(i -> c_frag_ty, c_sz) + + a_vars = ntuple(i -> :(a[$i]), a_sz) + b_vars = ntuple(i -> :(b[$i]), b_sz) + c_vars = ntuple(i -> :(c[$i]), c_sz) + + @eval $func_name(a, b, c) = ccall($ccall_name, llvmcall, NTuple{$d_sz, $d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)) + @eval export $func_name + @eval @doc (@doc llvm_wmma_mma) $func_name +end + +################################################################################ +# FLATTENING/UNFLATTENING LOGIC +################################################################################ + +# Base case (Float16, Float32, ...) +flatten_recurse(typ, e) = [:($e)] +unflatten_recurse(typ, e, idx) = :($e[$idx]), idx + 1 + +# VecElements +flatten_recurse(typ::Type{VecElement{T}}, e) where T = [:($e.value)] +unflatten_recurse(typ::Type{VecElement{T}}, e, idx) where T = :(VecElement{$T}($e[$idx])), idx + 1 + +# NTuples +function flatten_recurse(typ::Type{NTuple{N, T}}, e) where {N, T} + ret = Expr[] + + for (i, eltyp) in enumerate(typ.types) + append!(ret, flatten_recurse(eltyp, :($e[$i]))) + end + + return ret +end + +function unflatten_recurse(typ::Type{NTuple{N, T}}, e, idx) where {N, T} + ret = Expr(:tuple) + + for (i, eltyp) in enumerate(typ.types) + arg, idx = unflatten_recurse(eltyp, e, idx) + push!(ret.args, arg) + end + + return ret, idx +end + +@generated flatten(x::typ) where typ = Expr(:tuple, flatten_recurse(typ, :x)...) +@generated unflatten(::Type{typ}, x) where typ = unflatten_recurse(typ, :x, 1)[1] + +################################################################################ +# HIGH LEVEL (CUDA-STYLE API) +################################################################################ + +# ------------- +# WMMA fragment +# ------------- + +export FragmentLayout, RowMajor, ColMajor, Unspecified + +""" + FragmentLayout + +Abstract type that specifies the storage layout of a matrix. + +Possible values are [`RowMajor`](@ref), [`ColMajor`](@ref) and [`Unspecified`](@ref). +""" +abstract type FragmentLayout end + +""" + RowMajor + +Type that represents a matrix stored in row major (C style) order. +""" +struct RowMajor <: FragmentLayout end + +""" + ColMajor + +Type that represents a matrix stored in column major (Julia style) order. +""" +struct ColMajor <: FragmentLayout end + +""" + Unspecified + +Type that represents a matrix stored in an unspecified order. + +!!! 
warning + + This storage format is not valid for all WMMA operations! +""" +struct Unspecified <: FragmentLayout end + + +export MatrixA, MatrixB, Accumulator + +abstract type FragmentUse end +struct MatrixA <: FragmentUse end +struct MatrixB <: FragmentUse end +struct Accumulator <: FragmentUse end + + +export Fragment + +""" + Fragment + +Type that represents per-thread intermediate results of WMMA operations. + +You can access individual elements using the `x` member or `[]` operator, but beware that the exact ordering of elements is unspecified. +""" +struct Fragment{M, N, K, FS, T, L <: FragmentLayout, U <: FragmentUse} + x::NTuple{FS, T} +end + +# ---------------------- +# WMMA fragment indexing +# ---------------------- + +for f in (:getindex, :setindex!, :firstindex, :lastindex) + @eval Base.$f(frag::Fragment, args...) = $f(frag.x, args...) +end + +# ------------------ +# WMMA configuration +# ------------------ + +export Config + +""" + Config{M, N, K, d_type} + +Type that contains all information for WMMA operations that cannot be inferred from the argument's types. + +WMMA instructions calculate the matrix multiply-accumulate operation ``D = A \\cdot B + C``, where ``A`` is a ``M \\times K`` matrix, +``B`` a ``K \\times N`` matrix, and ``C`` and ``D`` are ``M \\times N`` matrices. + +`d_type` refers to the type of the elements of matrix ``D``, and can be either `Float16` or `Float32`. + +All WMMA operations take a `Config` as their final argument. + +# Examples +```jldoctest +julia> config = Config{16, 16, 16, Float32} +Config{16,16,16,Float32} +``` +""" +struct Config{M, N, K, d_type} end + +# --------- +# Constants +# --------- + +# Maps Julia array types to string +const map_jl_array_to_str = Dict(val => key for (key, val) in map_ptx_to_jl_array) + +# Maps CUDAnative.AS types to string +const map_as_ty_to_str = Dict(val => key for (key, val) in map_ptx_as_to_as_ty) + +# Maps layout types to string +const map_layout_ty_to_str = Dict( + RowMajor => "row", + ColMajor => "col" + ) + +# Maps matrix & type to number of elements (size after flattening) +const map_num_elems = Dict( + ("a", Float16) => 16, + ("b", Float16) => 16, + ("c", Float16) => 8, + ("c", Float32) => 8, + ("d", Float16) => 8, + ("d", Float32) => 8 + ) + +# Maps matrix to its use +const map_matrix_to_use = Dict( + "a" => MatrixA, + "b" => MatrixB, + "c" => Accumulator, + "d" => Accumulator + ) + +# ---------------- +# Helper functions +# ---------------- + +function get_hl_as_info(AS) + try + return map_as_ty_to_str[AS] + catch + error("Invalid address space for WMMA: $AS") + end +end + +function get_hl_layout(L) + try + return map_layout_ty_to_str[L] + catch + error("Invalid layout for WMMA: $L") + end +end + +function get_hl_shape(M, N, K) + if (M, N, K) != (16, 16, 16) + error("Invalid shape for WMMA: (M, N, K) = ($M, $N, $K)") + end + + return "m$(M)n$(N)k$(K)" +end + +get_hl_mat_use(mat) = map_matrix_to_use[mat] + +function get_hl_frag_info(matrix, T) + ptx_ty = nothing + + try + ptx_ty = map_jl_array_to_str[T] + catch + error("Invalid element type for WMMA: $T") + end + + try + return (map_num_elems[(matrix, T)], + map_frag_sizes["$matrix.$ptx_ty"], + map_ptx_to_jl_frag[ptx_ty], + ptx_ty) + catch + error("Invalid type $T for matrix $matrix") + end +end + +# --------- +# WMMA load +# --------- + +export load_a, load_b, load_c + +""" + load_a(addr, stride, layout, config) + load_b(addr, stride, layout, config) + load_c(addr, stride, layout, config) + +Load the matrix `a`, `b` or `c` from the memory location 
indicated by `addr`, and return the resulting [`Fragment`](@ref). + +# Arguments +- `addr`: The address to load the matrix from. +- `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. +- `layout`: The storage layout of the matrix. Possible values are [`RowMajor`](@ref) and [`ColMajor`](@ref). +- `config`: The WMMA configuration that should be used for loading this matrix. See [`Config`](@ref). + +See also: [`Fragment`](@ref), [`FragmentLayout`](@ref), [`Config`](@ref) + +!!! warning + + All threads in a warp **MUST** execute the load operation in lockstep, and have to use exactly the same arguments. + Failure to do so will result in undefined behaviour. +""" +load_a, load_b, load_c + +for mat in ["a", "b", "c"] + func_name = Symbol("load_$mat") + + @eval @generated function $func_name(addr::DevicePtr{T, AS}, + stride::Number, + layout::Type{L}, + config::Type{Config{M, N, K, D_TYPE}}) where {T, AS, L, M, N, K, D_TYPE} + + as_str = get_hl_as_info(AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) + num_els, _, _, arr_str = get_hl_frag_info($mat, T) + U = get_hl_mat_use($mat) + L_ret = ($mat == "c") ? Unspecified : L + + # Name of the Julia wrapper + wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_")) + + return quote + x = flatten($wrapper(addr, stride)) + return Fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x) + end + end +end + + +# ------------------------ +# WMMA multiply-accumulate +# ------------------------ + +export mma + +""" + mma(a, b, c, conf) + +Perform the matrix multiply-accumulate operation ``D = A \\cdot B + C``. + +# Arguments + +- `a`: The [`Fragment`](@ref) corresponding to the matrix ``A``. +- `b`: The [`Fragment`](@ref) corresponding to the matrix ``B``. +- `c`: The [`Fragment`](@ref) corresponding to the matrix ``C``. +- `conf`: The [`Config`](@ref) that should be used in this WMMA operation. + +!!! warning + + All threads in a warp **MUST** execute the `mma` operation in lockstep, and have to use exactly the same arguments. + Failure to do so will result in undefined behaviour. +""" +mma + +@generated function mma(a::Fragment{M, N, K, A_SZ, A_T, A_L, MatrixA}, + b::Fragment{M, N, K, B_SZ, B_T, B_L, MatrixB}, + c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator}, + config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T} + + _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T) + _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T) + _, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T) + d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T) + + a_layout = get_hl_layout(A_L) + b_layout = get_hl_layout(B_L) + shape = get_hl_shape(M, N, K) + + # Name of the Julia wrapper + wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_")) + + return quote + a_unfl = unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x) + b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x) + c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x) + + x = flatten($wrapper(a_unfl, b_unfl, c_unfl)) + return Fragment{$M, $N, $K, $d_num_els, $D_T, Unspecified, Accumulator}(x) + end +end + + +# ---------- +# WMMA store +# ---------- + +export store_d + +""" + store_d(addr, d, stride, layout, config) + +Store the result matrix `d` to the memory location indicated by `addr`. + +# Arguments +- `addr`: The address to store the matrix to. 
+- `d`: The [`Fragment`](@ref) corresponding to the `d` matrix. +- `stride`: The leading dimension of the matrix pointed to by `addr`, specified in number of elements. +- `layout`: The storage layout of the matrix. Possible values are [`RowMajor`](@ref) and [`ColMajor`](@ref). +- `config`: The WMMA configuration that should be used for storing this matrix. See [`Config`](@ref). + +See also: [`Fragment`](@ref), [`FragmentLayout`](@ref), [`Config`](@ref) + +!!! warning + + All threads in a warp **MUST** execute the `store` operation in lockstep, and have to use exactly the same arguments. + Failure to do so will result in undefined behaviour. +""" +store_d + +@generated function store_d(addr::DevicePtr{T, AS}, + d::Fragment{M, N, K, D_SZ, T, Unspecified, Accumulator}, + stride::Number, + layout::Type{L}, + config::Type{Config{M, N, K, T}}) where {T, AS, M, N, K, D_SZ, L} + + as_str = get_hl_as_info(AS) + layout = get_hl_layout(L) + shape = get_hl_shape(M, N, K) + num_els, frag_sz, frag_ty, arr_str = get_hl_frag_info("d", T) + + # Name of the Julia wrapper + wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", "d", layout, shape, as_str, "stride", arr_str]), "_")) + + return quote + d_unfl = unflatten(NTuple{$frag_sz, $frag_ty}, d.x) + $wrapper(addr, d_unfl, stride) + return nothing + end +end + + +# ------------------ +# WMMA fill fragment +# ------------------ + +export fill_c + +""" + fill_c(value, config) + +Return a [`Fragment`](@ref) filled with the value `value`. + +This operation is useful if you want to implement a matrix multiplication (and thus want to set ``C = O``). + +# Arguments +- `value`: The value used to fill the fragment. Can be a `Float16` or `Float32`. +- `config`: The WMMA configuration that should be used for this WMMA operation. See [`Config`](@ref). 
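+
+# Example
+To create an all-zero ``C`` fragment for a 16 × 16 × 16 `Float32` multiply-accumulate:
+```julia
+c_frag = fill_c(Float32(0), Config{16, 16, 16, Float32})
+```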
+""" +fill_c + +@generated function fill_c(value::T, + config::Type{Config{M, N, K, D_TYPE}}) where {T, M, N, K, D_TYPE} + + # We can't use closures in @generated functions, so we'll have to do it this way instead of + # ntuple(i -> val, $num_els) + num_els, _, _ = get_hl_frag_info("c", T) + + args = [:value for i=1:num_els] + expr = :(tuple($(args...))) + + return quote + return Fragment{$M, $N, $K, $num_els, $T, Unspecified, Accumulator}($expr) + end +end + +################################################################################ +# BROADCASTING OVER WMMA FRAGMENTS +################################################################################ + +# Based on broadcasting implementation of Tuples in +# https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl + + +# Custom broadcast style for Fragments +struct FragmentBroadcastStyle <: Broadcast.BroadcastStyle end + +# Use this broadcasting style for Fragments +Base.BroadcastStyle(::Type{<:Fragment}) = FragmentBroadcastStyle() + +# Broadcast style precedence rules +# If we broadcast a fragment with a scalar, we want the Fragment style to take precedence +Base.BroadcastStyle(s::FragmentBroadcastStyle, t::Broadcast.DefaultArrayStyle{0}) = s + +# We don't want to convert fragments before broadcasting +Base.broadcastable(frag::Fragment) = frag + +# Needed for broadcast machinery +Base.axes(frag::Fragment) = axes(frag.x) + +# Helper functions to get element at specified index +@inline get_index(x, i) = x # scalar +@inline get_index(frag::Fragment, i) = frag[i] # Fragment + +# Helper functions to get first fragment in broadcast call +@inline find_first_fragment(args::Tuple) = find_first_fragment(args[1], Base.tail(args)) +@inline find_first_fragment(a::Fragment, tail) = a +@inline find_first_fragment(::Any, tail) = find_first_fragment(tail) + +# Custom broadcast implementation that returns a Fragment +@inline function Base.copy(bc::Broadcast.Broadcasted{FragmentBroadcastStyle}) + dim = Broadcast.combine_axes(bc.args...) + + if length(dim) != 1 + throw(DimensionMismatch("WMMA fragment broadcast only supports one dimension!")) + end + + N = length(dim[1]) + + tuple = ntuple(i -> bc.f(map(arg -> get_index(arg, i), bc.args)...), Val(N)) + + frag_ty = typeof(find_first_fragment(bc.args)) + return frag_ty(tuple) +end + +end diff --git a/test/device/wmma.jl b/test/device/wmma.jl new file mode 100644 index 00000000..412fcc69 --- /dev/null +++ b/test/device/wmma.jl @@ -0,0 +1,252 @@ +# Need https://github.com/JuliaLang/julia/pull/33970 +# and https://github.com/JuliaLang/julia/pull/34043 +if VERSION >= v"1.4.0-DEV.666" && CUDAnative.current_capability() >= v"7.0" + +using CUDAnative.WMMA + +@testset "WMMA" begin + +################################################################################ + + @testset "LLVM intrinsics" begin + + @testset "llvm_wmma_load" begin + @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["a", "b", "c"], + layout in ["row", "col"], + shape in ["m16n16k16"], + addr_space in ["", "_global", "_shared"], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # Float32 is only supported for C + if (elem_type == "f32") && (mat != "c") + continue + end + + # Type-dependent variables + array_ty = elem_type == "f16" ? Float16 : Float32 + expected = elem_type == "f16" ? 
ntuple(i -> VecElement{Float16}(42), 2) : Float32(42) + + # Address-space dependent variables + do_shared_test = (addr_space == "_shared") + + # Get the function name + func = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") + + input = 42 * ones(array_ty, (16, 16)) + input_dev = CuArray(input) + result = Array{Bool}(undef, 1) + result_dev = CuArray(result) + + @eval @inbounds function kernel(input_dev, result_dev) + if $do_shared_test + input_shared = @cuStaticSharedMem($array_ty, 256) + fill!(input_shared, 42) + + data = $func(pointer(input_shared), 16) + else + data = $func(pointer(input_dev), 16) + end + + result_dev[1] = all(val -> val == $expected, data) + + return + end + + @cuda threads=32 kernel(input_dev, result_dev) + @test all(Array(result_dev)) + end + end + + @testset "llvm_wmma_store" begin + @testset "$(mat)_$(layout)_$(shape)_$(addr_space)_$(elem_type)" for mat in ["d"], + layout in ["row", "col"], + shape in ["m16n16k16"], + addr_space in ["", "_global", "_shared"], + stride in ["stride"], + elem_type in ["f16", "f32"] + + # Type-dependent variables + array_ty = elem_type == "f16" ? Float16 : Float32 + data = elem_type == "f16" ? ntuple(i -> ntuple(j -> VecElement{Float16}(42), 2), 4) : ntuple(i -> 42, 8) + + # Get the function name + func = Symbol("llvm_wmma_store_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)") + + # Address-space dependent variables + do_shared_test = (addr_space == "_shared") + + output = Array{array_ty}(undef, (16, 16)) + output_dev = CuArray(output) + + @eval function kernel(output_dev) + if $do_shared_test + shared_mem = @cuStaticSharedMem($array_ty, 256) + $func(pointer(shared_mem), $data, 16) + + for i = 1:256 + @inbounds output_dev[i] = shared_mem[i] + end + else + $func(pointer(output_dev), $data, 16) + end + + return + end + + @cuda threads=32 kernel(output_dev) + @test all(Array(output_dev) .== 42.0) + end + end + + @testset "llvm_wmma_mma" begin + @testset "$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)" for a_layout in ["row", "col"], + b_layout in ["row", "col"], + shape in ["m16n16k16"], + d_elem_type in ["f16", "f32"], + c_elem_type in ["f16", "f32"] + + # Type-dependent variables + d_ty = d_elem_type == "f16" ? Float16 : Float32 + c_ty = c_elem_type == "f16" ? Float16 : Float32 + + # Get the function names + lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_m16n16k16_stride_f16")) + ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_m16n16k16_stride_f16")) + ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_m16n16k16_stride_$(c_elem_type)")) + mma_func = getfield(Main, Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_m16n16k16_$(d_elem_type)_$(c_elem_type)")) + std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_m16n16k16_stride_$(d_elem_type)")) + + # Generate input matrices + a = rand(Float16, (16, 16)) + a_dev = CuArray(a) + b = rand(Float16, (16, 16)) + b_dev = CuArray(b) + c = rand(c_ty, (16, 16)) + c_dev = CuArray(c) + + # Reserve space for result + d = Array{d_ty}(undef, (16, 16)) + d_dev = CuArray(d) + + # Matrix MAC kernel (D = A * B + C) + function kernel(a_dev, b_dev, c_dev, d_dev) + a_frag = lda_func(pointer(a_dev), 16) + b_frag = ldb_func(pointer(b_dev), 16) + c_frag = ldc_func(pointer(c_dev), 16) + + d_frag = mma_func(a_frag, b_frag, c_frag) + + std_func(pointer(d_dev), d_frag, 16) + return + end + + @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev) + + new_a = (a_layout == "col" ? 
a : transpose(a)) + new_b = (b_layout == "col" ? b : transpose(b)) + + @test all(isapprox.(new_a * new_b + c, Array(d_dev); rtol=sqrt(eps(Float16)))) + end + end + end + +################################################################################ + + @testset "Flattening/unflattening" begin + @testset "Flattening" begin + @test CUDAnative.WMMA.flatten(5) == (5,) + @test CUDAnative.WMMA.flatten(5.0) == (5.0,) + @test CUDAnative.WMMA.flatten(VecElement{Float16}(5)) == (Float16(5),) + @test CUDAnative.WMMA.flatten(ntuple(i -> i, 8)) == ntuple(i -> i, 8) + @test CUDAnative.WMMA.flatten(ntuple(i -> VecElement{Float16}(i), 8)) == ntuple(i -> Float16(i), 8) + @test CUDAnative.WMMA.flatten(ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8)) == ntuple(i -> i, 2 * 8) + @test CUDAnative.WMMA.flatten(ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)) == ntuple(i -> Float16(i), 2 * 8) + end + + @testset "Unflattening" begin + @test CUDAnative.WMMA.unflatten(Int64, (5,)) == 5 + @test CUDAnative.WMMA.unflatten(Float64, (5.0,)) == 5.0 + @test CUDAnative.WMMA.unflatten(VecElement{Float16}, (Float16(5),)) == VecElement{Float16}(5) + @test CUDAnative.WMMA.unflatten(NTuple{8, Int64}, ntuple(i -> i, 8)) == ntuple(i -> i, 8) + @test CUDAnative.WMMA.unflatten(NTuple{8, VecElement{Float16}}, ntuple(i -> Float16(i), 8)) == ntuple(i -> VecElement{Float16}(i), 8) + @test CUDAnative.WMMA.unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8) + @test CUDAnative.WMMA.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8) + end + end + +################################################################################ + + @testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5], + ty = [Float16, Float32] + @test ty(5) .* Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 * i), sz)) + @test ty(5) .+ Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 + i), sz)) + end + +################################################################################ + + @testset "CUDA C-style API" begin + + @testset "$(do_mac ? 
"MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout, C type: $c_type, D type: $d_type" for a_layout in [ColMajor, RowMajor], + b_layout in [ColMajor, RowMajor], + c_layout in [ColMajor, RowMajor], + d_layout in [ColMajor, RowMajor], + c_type in [Float16, Float32], + d_type in [Float16, Float32], + do_mac in [true, false] + + a = rand(Float16, (16, 16)) + b = rand(Float16, (16, 16)) + c = rand(c_type, (16, 16)) + d = Array{d_type}(undef, (16, 16)) + + a_dev = CuArray(a) + b_dev = CuArray(b) + c_dev = CuArray(c) + d_dev = CuArray(d) + + alpha = rand(Float16) + beta = rand(c_type) + + @eval function kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) + conf = Config{16, 16, 16, $d_type} + + a_frag = load_a(pointer(a_dev), 16, $a_layout, conf) + b_frag = load_b(pointer(b_dev), 16, $b_layout, conf) + + if $do_mac + c_frag = load_c(pointer(c_dev), 16, $c_layout, conf) + else + c_frag = fill_c($c_type(0), conf) + end + + a_frag = alpha .* a_frag + c_frag = beta .* c_frag + + d_frag = mma(a_frag, b_frag, c_frag, conf) + + store_d(pointer(d_dev), d_frag, 16, $d_layout, conf) + + return + end + + @cuda threads=32 kernel(a_dev, b_dev, c_dev, d_dev, alpha, beta) + d = Array(d_dev) + + new_a = (a_layout == ColMajor) ? a : transpose(a) + new_b = (b_layout == ColMajor) ? b : transpose(b) + new_c = (c_layout == ColMajor) ? c : transpose(c) + new_d = (d_layout == ColMajor) ? d : transpose(d) + + if do_mac + @test all(isapprox.(alpha * new_a * new_b + beta * new_c, new_d; rtol=sqrt(eps(Float16)))) + else + @test all(isapprox.(alpha * new_a * new_b, new_d; rtol=sqrt(eps(Float16)))) + end + end + + end + +################################################################################ +end +end diff --git a/test/runtests.jl b/test/runtests.jl index 1c4d72a9..e7fca7ac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -87,6 +87,7 @@ else include("device/pointer.jl") include("device/array.jl") include("device/cuda.jl") + include("device/wmma.jl") include("examples.jl") end