Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Profiler] Allow user to flush L2 cache in time_evalutor function for profiling CUDA kernels #13726

Merged
merged 12 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions src/runtime/cuda/l2_cache_flush.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
// Acknowledgement: l2flush struct in nvbench project.
// Reference:
// https://github.com/NVIDIA/nvbench/blob/1a13a2e724b8aa8aee27649ac6878babb63862a6/nvbench/detail/l2flush.cuh
#include <cuda.h>
#include <cuda_runtime.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "cuda_common.h"

namespace tvm {

namespace runtime {

class L2Flush {
public:
explicit L2Flush() : initialized_(false), l2_size_(0), l2_buffer_(nullptr) {}

~L2Flush() {
if (l2_size_ > 0) {
echuraev marked this conversation as resolved.
Show resolved Hide resolved
CUDA_CALL(cudaFree(l2_buffer_));
}
}

void Flush() {
if (!initialized_) {
echuraev marked this conversation as resolved.
Show resolved Hide resolved
// initialize l2_buffer_ and l2_size_
initialized_ = true;
int device_id;
CUDA_CALL(cudaGetDevice(&device_id));
CUDA_CALL(cudaDeviceGetAttribute(&l2_size_, cudaDevAttrL2CacheSize, device_id));
if (l2_size_ > 0) {
void* buffer = l2_buffer_;
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
CUDA_CALL(cudaMalloc(&buffer, l2_size_));
l2_buffer_ = reinterpret_cast<int*>(buffer);
}
}
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
if (l2_size_ > 0) {
CUDA_CALL(cudaMemsetAsync(l2_buffer_, 0, l2_size_, stream));
}
}

static L2Flush* ThreadLocal();

private:
bool initialized_ = false;
int l2_size_;
int* l2_buffer_;
};

typedef dmlc::ThreadLocalStore<L2Flush> L2FlushStore;

L2Flush* L2Flush::ThreadLocal() { return L2FlushStore::Get(); }

TVM_REGISTER_GLOBAL("l2_cache_flush_cuda").set_body([](TVMArgs args, TVMRetValue* rv) {
ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist.";
L2Flush::ThreadLocal()->Flush();
});

} // namespace runtime
} // namespace tvm
19 changes: 11 additions & 8 deletions src/runtime/profiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,9 +882,6 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
DeviceAPI::Get(dev)->StreamSync(dev, nullptr);

for (int i = 0; i < repeat; ++i) {
if (f_preproc != nullptr) {
f_preproc.CallPacked(args, &temp);
}
double duration_ms = 0.0;
int absolute_zero_times = 0;
do {
Expand All @@ -894,15 +891,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, Device dev, int number, int repeat,
std::max((min_repeat_ms / (duration_ms / number) + 1), number * golden_ratio));
}

int64_t accum_t_nanos = 0;
// start timing
Timer t = Timer::Start(dev);
for (int j = 0; j < number; ++j) {
// call preprocessing function
if (f_preproc != nullptr) {
f_preproc.CallPacked(args, &temp);
yzh119 marked this conversation as resolved.
Show resolved Hide resolved
}
Timer t = Timer::Start(dev);
pf.CallPacked(args, &temp);
t->Stop();
int64_t t_nanos = t->SyncAndGetElapsedNanos();
accum_t_nanos += t_nanos;
}
t->Stop();
int64_t t_nanos = t->SyncAndGetElapsedNanos();
if (t_nanos == 0) absolute_zero_times++;
duration_ms = t_nanos / 1e6;
if (accum_t_nanos == 0) absolute_zero_times++;
duration_ms = accum_t_nanos / 1e6;
} while (duration_ms < min_repeat_ms && absolute_zero_times < limit_zero_time_iterations);

double speed = duration_ms / 1e3 / number;
Expand Down
63 changes: 63 additions & 0 deletions tests/python/unittest/test_evaluator_flush_l2_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import te
from tvm.script import tir as T
import tvm.testing
import numpy as np


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i, j, k in T.grid(128, 128, 128):
with T.block("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


@tvm.testing.requires_cuda
def test_evaluator_flush_l2_cache():
mod = tvm.IRModule.from_expr(matmul)
sch = tvm.tir.Schedule(mod)
blk = sch.get_block("matmul")
i, j, k = sch.get_loops(blk)
sch.bind(i, "blockIdx.x")
sch.bind(k, "threadIdx.x")
f = tvm.build(sch.mod["main"], target=tvm.target.cuda(arch="sm_86"))
dev = tvm.cuda(0)
evaluator_no_flush = f.time_evaluator(f.entry_name, dev, number=100)

a = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
b = tvm.nd.array(np.random.rand(128, 128).astype("float32"), device=dev)
c = tvm.nd.array(np.zeros((128, 128)).astype("float32"), device=dev)
args = [a, b, c]
print("Evaluator (w/o L2 flush):\t{:.5f}ms".format(evaluator_no_flush(*args).mean * 1000))

evaluator_with_flush = f.time_evaluator(
f.entry_name, dev, number=100, f_preproc="l2_cache_flush_cuda"
)
print("Evaluator (w/ L2 flush):\t{:.5f}ms".format(evaluator_with_flush(*args).mean * 1000))
echuraev marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
test_evaluator_flush_l2_cache()