[Kernel] Add per-tensor and per-token AZP epilogues #5941

Merged (58 commits), Aug 6, 2024
Changes from all commits
Commits (58)
9780ed0
working version of folding bias into azp
ProExpertProg Jun 20, 2024
09a2796
per-token quantization
ProExpertProg Jun 21, 2024
6a83723
fixed the quantized zp case
ProExpertProg Jun 25, 2024
f8641c8
removed TODO
ProExpertProg Jun 25, 2024
e0230d8
bias is out dtype now
ProExpertProg Jun 25, 2024
9b08220
naive_mm extracted
ProExpertProg Jun 26, 2024
fad210f
Added comments to BiasEpilogue
ProExpertProg Jun 26, 2024
c028abf
added ScaledEpilogueBiasAzp with per-token azp
ProExpertProg Jun 27, 2024
3f6c73b
per-token azp epilogue for other sm types (not just sm80)
ProExpertProg Jul 8, 2024
3dcf9ef
refactored w8a8 benchmarks and added bias and azp cases
ProExpertProg Jul 8, 2024
caeeea7
format
ProExpertProg Jul 8, 2024
2f13e48
w8a8 for fp8
ProExpertProg Jul 8, 2024
2d76ddb
PR comments, tests more robust
ProExpertProg Jul 9, 2024
f3d4cc4
using float for compute type in per-token azp
ProExpertProg Jul 9, 2024
c811236
refactored epilogues by extracting utilities
ProExpertProg Jul 10, 2024
e91c1aa
per-token azp done in integers before scaling
ProExpertProg Jul 10, 2024
db3550a
renamed epilogue, fixed comments
ProExpertProg Jul 11, 2024
92239aa
Added per-tensor azp epilogue (int32) and added a test
ProExpertProg Jul 11, 2024
f57ed79
Fixed assert for de-quantizing in tests
ProExpertProg Jul 11, 2024
cc81863
format
ProExpertProg Jul 11, 2024
8ad8b10
added bias to azp tests
ProExpertProg Jul 12, 2024
0665754
visual load reduced, azp i8/i32 TODO, azp-fold test skipped
ProExpertProg Jul 12, 2024
1117bd2
merged per-token and per-tensor azp tests into one function to reduce…
ProExpertProg Jul 12, 2024
567389a
AZP epilogue fixed for sm75 and sm89
ProExpertProg Jul 12, 2024
7be8c97
AZP epilogues for c3x
ProExpertProg Jul 12, 2024
762c64b
Expanded benchmarking cases
ProExpertProg Jul 12, 2024
6abd9f5
Cache zero bias tensor
ProExpertProg Jul 16, 2024
5a93712
nullptr bias working
ProExpertProg Jul 16, 2024
54bf6eb
added nullptr support to ColOrScalar
ProExpertProg Jul 16, 2024
b3ed8c4
If nullptr enabled, no scalar load
ProExpertProg Jul 16, 2024
53a6565
Tightened tolerance for AZP tests
ProExpertProg Jul 17, 2024
7c6dfd7
Merge remote-tracking branch 'refs/remotes/upstream/main' into luka/a…
ProExpertProg Jul 17, 2024
7c533bd
Cleanup: azp tensors can be 2d, removed BiasCache, comments
ProExpertProg Jul 17, 2024
818411a
comments
ProExpertProg Jul 17, 2024
96aaf8f
Added epilogues doc
ProExpertProg Jul 24, 2024
6e4006c
Math fixes
ProExpertProg Jul 24, 2024
9af72d6
Math fixes
ProExpertProg Jul 24, 2024
0f96e7c
Reordered parameters for consistency
ProExpertProg Jul 24, 2024
4322d6d
Changed zero-points to be negative
ProExpertProg Jul 24, 2024
7f4c4d9
Arg order benchmarks
ProExpertProg Jul 24, 2024
1af6354
Unrefactored c2x visitor to stay more similar to original
ProExpertProg Jul 27, 2024
2acaead
fixed test case azp_adj dtype
ProExpertProg Jul 27, 2024
c669d51
Merge remote-tracking branch 'refs/remotes/upstream/main' into luka/a…
ProExpertProg Jul 30, 2024
845e62a
PR comments
ProExpertProg Jul 30, 2024
8a0e670
PR comments: md
ProExpertProg Jul 30, 2024
21cc383
revert visitor
ProExpertProg Jul 30, 2024
17f06b9
added roworzero visitor, c2x uses it for bias in AZP cases, also asse…
ProExpertProg Jul 30, 2024
310d1f4
PR comments: equations
ProExpertProg Jul 30, 2024
717d015
format
ProExpertProg Jul 30, 2024
77f1765
AZP per-token comment
ProExpertProg Jul 31, 2024
8c76f7e
c3x azptoken comment
ProExpertProg Jul 31, 2024
62e4790
moved TORCH_CHECKs to top function
ProExpertProg Jul 31, 2024
509ce0e
Merge remote-tracking branch 'refs/remotes/upstream/main' into luka/a…
ProExpertProg Jul 31, 2024
8de55f4
fixed type annotation
ProExpertProg Aug 1, 2024
5ab9f00
Merge remote-tracking branch 'refs/remotes/upstream/main' into luka/a…
ProExpertProg Aug 1, 2024
fcca6c7
w8a8_benchmarks fix
ProExpertProg Aug 1, 2024
30077ba
Merge remote-tracking branch 'upstream/main' into luka/aq-azp-test
ProExpertProg Aug 2, 2024
4f3cabd
Merge remote-tracking branch 'refs/remotes/upstream/main' into luka/a…
ProExpertProg Aug 6, 2024
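For readers skimming the commit list above: these epilogues handle asymmetric activation quantization, where the int8 activations carry a zero point (AZP) that has to be corrected for inside the GEMM epilogue. The sketch below is a rough torch reference of what I understand the per-tensor and per-token variants to compute, inferred from the commit messages and the benchmark/ops.h changes further down. The helper name, the exact sign convention (one commit notes the zero points were later made negative), and the dtypes are assumptions, not the kernel's definitive contract.

import torch

def scaled_mm_azp_ref(a_q, b_q, scale_a, scale_b, azp_adj, azp=None, bias=None):
    # Assumed convention: a = scale_a * (a_q - azp), b = scale_b * b_q, so
    # a @ b = scale_a * scale_b * (a_q @ b_q - azp * colsum(b_q)).
    # "azp_adj" is therefore built from the precomputed column sums of b_q.
    # float64 keeps this reference exact; the real kernel accumulates in int32.
    acc = a_q.to(torch.float64) @ b_q.to(torch.float64)
    if azp is None:
        # Per-tensor: the scalar zero point is assumed to be folded into
        # azp_adj ahead of time, i.e. azp_adj = azp * colsum(b_q).
        acc = acc - azp_adj.to(torch.float64).view(1, -1)
    else:
        # Per-token: one zero point per row of a_q, applied to the raw column sums.
        acc = acc - azp.to(torch.float64).view(-1, 1) * \
            azp_adj.to(torch.float64).view(1, -1)
    # scale_a may be a scalar or per-token (m, 1); scale_b a scalar or per-channel (1, n).
    out = scale_a.to(torch.float64) * scale_b.to(torch.float64) * acc
    if bias is not None:
        out = out + bias.to(torch.float64)
    return out  # caller casts to the requested out_dtype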
185 changes: 107 additions & 78 deletions benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:

def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:

a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5

@@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
raise ValueError("unsupported dtype")


# impl


def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return torch.mm(a, b)


def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype)


def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
use_fast_accum=True)


def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)


# bench
def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
sub_label: str, fn: Callable, description: str) -> TMeasurement:

def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
min_run_time = 1

globals = {
"a": a,
"b": b,
"scale_a": scale_a,
"scale_b": scale_b,
"out_dtype": out_dtype,
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
@@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)

timers = []
# pytorch impl - bfloat16
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16)))

# pytorch impl - float16
timers.append(
bench_fn(a.to(dtype=torch.float16, device="cuda"),
b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
torch.float16, label, sub_label, pytorch_mm_impl,
"pytorch_fp16_fp16_fp16_matmul-no-scales"))
bench_fn(label, sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))

# cutlass impl
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))

# cutlass with bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))

# cutlass with azp per-tensor
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj))

# cutlass with azp per-tensor + bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, None, bias))

# cutlass with azp per-token
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, azp))

# cutlass with azp per-token + bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, azp, bias))

return timers

@@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

timers = []

# pytorch impl w. bf16
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda")))

# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16))

# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True))

# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16))

# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
use_fast_accum=True))

# cutlass impl: bf16 output
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass impl: fp16 output
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))

# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))

# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)))

return timers


@@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]):

def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:

results = []
for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
@@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None):

print(f"== All Results {base_description} ====")
print_timers(data)

@@ -251,7 +281,6 @@ def run_range_bench(args):


def run_model_bench(args):

print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -119,6 +119,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);

torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,
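Based on how the new benchmark cases above call it, the Python-side binding for this entry point appears to take the AZP arguments after out_dtype, with azp and bias optional: per-tensor AZP passes only azp_adj (with the scalar zero point presumably folded in ahead of time), while per-token AZP passes both azp_adj and a per-row azp vector. A hedged sketch of the call shapes, with the signature inferred from the benchmark rather than from the binding itself, and assuming the wrapper allocates and returns the output:

# Per-tensor AZP: azp_adj is int32 of shape (n,), assumed to already include the zero point.
out = ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.bfloat16, azp_adj)

# Per-tensor AZP with bias: the azp slot is explicitly None, bias is in the output dtype.
out = ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.bfloat16, azp_adj,
                                None, bias)

# Per-token AZP (optionally with bias): azp holds one int32 zero point per row of a.
out = ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.bfloat16, azp_adj,
                                azp, bias)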