Skip to content

Commit

Permalink
[lrn] add the updated lrn examples; this can reproduce the issue reported in intel/llvm#8292
Browse files Browse the repository at this point in the history
  • Loading branch information
Jin Z committed Feb 14, 2023
1 parent 0420f78 commit 028e2bf
Show file tree
Hide file tree
Showing 9 changed files with 1,413 additions and 0 deletions.
55 changes: 55 additions & 0 deletions lrn-cuda/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#===============================================================================
# User Options
#===============================================================================

# Compiler can be changed below, or on the make command line (e.g. `make CC=...`).
# A plain environment variable does not override this assignment unless
# make is invoked with -e.
CC        = nvcc
OPTIMIZE  = yes
DEBUG     = no
ARCH      = sm_60

#===============================================================================
# Program name & source code list
#===============================================================================

program = main

source = main.cu

obj = $(source:.cu=.o)

#===============================================================================
# Sets Flags
#===============================================================================

# Standard Flags
CFLAGS := $(EXTRA_CFLAGS) -std=c++14 -Xcompiler -Wall -arch=$(ARCH)

# Linker Flags
LDFLAGS =

# Debug Flags
ifeq ($(DEBUG),yes)
  CFLAGS += -g -DDEBUG
  LDFLAGS += -g
endif

# Optimization Flags
ifeq ($(OPTIMIZE),yes)
  CFLAGS += -O3
endif

#===============================================================================
# Targets to Build
#===============================================================================

# `clean` and `run` do not produce files of the same name; mark them phony so a
# stray file called "clean" or "run" cannot suppress them.
.PHONY: clean run

# Link step. Depends on the Makefile so that flag changes trigger a relink.
$(program): $(obj) Makefile
	$(CC) $(CFLAGS) $(obj) -o $@ $(LDFLAGS)

# Every object depends on kernels.h because main.cu includes the kernels.
%.o: %.cu kernels.h Makefile
	$(CC) $(CFLAGS) -c $< -o $@

clean:
	rm -rf $(program) $(obj)

# Runs the benchmark with the single argument 2.
run: $(program)
	./$(program) 2
238 changes: 238 additions & 0 deletions lrn-cuda/kernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
#ifndef KERNELS
#define KERNELS

// Forward local response normalization.
//
// Work distribution: the thread with flat id
//   t = blockIdx.x * 32 + threadIdx.x
// processes the 16 consecutive flat element indices [16*t, 16*t + 16).
// Each flat index is decomposed into (n, c, d, h, w) with a modulo per
// dimension, so indices past the end of the tensor wrap around instead of
// going out of bounds.
// NOTE(review): the block offset uses a fixed factor of 32 rather than
// blockDim.x, so blocks only tile the index space without overlap when the
// kernel is launched with 32 threads per block -- confirm against the caller.
//
// For every element the kernel computes
//   omega = k_ + alpha_ * (sum of src^2 over the window) / size_
//   dst   = src * omega^(-3/4)
// The exponent is hard-coded through sqrtf(1 / (sqrtf(omega) * omega)).
//
// Parameters:
//   src_        input tensor (read only)
//   dst_        output tensor, same element offsets as src_
//   N_..W_      logical extents used to decompose the flat index
//   stride_mb_  element stride between consecutive mini-batch entries
//   size_       window size; the window spans (size_ - 1) / 2 neighbours on
//               each side, clamped to the channel range
//   alpha_, k_  normalization constants (see formula above)
//   ndims_, wk_size_, beta_  accepted for interface compatibility but not
//               read by this kernel
__global__
void lrn_fwd_kernel(
const float* __restrict__ src_,
float* __restrict__ dst_,
int64_t N_,
int64_t C_,
int64_t D_,
int64_t H_,
int64_t W_,
int64_t stride_mb_,
int64_t ndims_,
int64_t wk_size_,
int64_t size_,
float alpha_,
float beta_,
float k_)
{
  // An empty tensor has nothing to normalize, and the index decomposition
  // below would otherwise divide / take a modulo by zero.
  if (N_ <= 0 || C_ <= 0 || D_ <= 0 || H_ <= 0 || W_ <= 0) return;

  // Widen blockIdx.x *before* multiplying: the built-in is a 32-bit unsigned
  // value, so `blockIdx.x * 32` could wrap for very large grids.
  // (threadIdx.x / 32 * 32 + threadIdx.x % 32 is simply threadIdx.x.)
  const int64_t thread_id =
      static_cast<int64_t>(blockIdx.x) * 32 + static_cast<int64_t>(threadIdx.x);
  const int64_t first_idx = thread_id * 16;

  // Element offset of (mb, c, d, h, w). Layout 0 is channel-major within a
  // mini-batch entry, layout 1 is channel-minor. Neither formula uses d.
  auto data_off = [=](int64_t mb, int64_t c, int64_t d, int64_t h, int64_t w) {
    const int64_t tag = 0;
    switch (tag) {
      case 0 : return mb * stride_mb_ + c * H_ * W_ + h * W_ + w;
      case 1 : return mb * stride_mb_ + h * W_ * C_ + w * C_ + c;
      default:
        return (int64_t)1;
    }
  };

  // Normalized value for the element at (mb, oc, od, oh, ow).
  auto ker = [=](int64_t mb, int64_t oc, int64_t od, int64_t oh, int64_t ow) {
    const int64_t half_size = (size_ - 1) / 2;
    const bool across_channel = true;
    float sum = 0;

    if (across_channel) {
      // Window over neighbouring channels, clamped to [0, C_).
      const int64_t c_st = max(oc - half_size + 0, (int64_t)0);
      const int64_t c_en = min(oc + half_size + 1, C_);

      for (int64_t c = c_st; c < c_en; ++c) {
        const float s = src_[data_off(mb, c, od, oh, ow)];
        sum += s * s;
      }
    } else {
      // Window over spatial neighbours within one channel.
      const int64_t d_st = max(od - half_size + 0, (int64_t)0);
      const int64_t d_en = min(od + half_size + 1, D_);
      const int64_t h_st = max(oh - half_size + 0, (int64_t)0);
      const int64_t h_en = min(oh + half_size + 1, H_);
      const int64_t w_st = max(ow - half_size + 0, (int64_t)0);
      const int64_t w_en = min(ow + half_size + 1, W_);
      for (int64_t d = d_st; d < d_en; ++d) {
        for (int64_t h = h_st; h < h_en; ++h) {
          for (int64_t w = w_st; w < w_en; ++w) {
            const float s = src_[data_off(mb, oc, d, h, w)];
            sum += s * s;
          }
        }
      }
    }

    sum = k_ + alpha_ * sum / size_;
    const float center = src_[data_off(mb, oc, od, oh, ow)];
    // sqrtf(1 / (sqrtf(x) * x)) == x^(-3/4)
    return (center * sqrtf(1.0f / (sqrtf(sum) * sum)));
  };

  // Computes one output element and stores it.
  auto Operation = [=](int64_t mb, int64_t c, int64_t d, int64_t h, int64_t w) {
    const bool channel = false;
    if (channel) {
      const int64_t off = mb * stride_mb_ + h * W_ * C_ + w * C_ + c;
      dst_[off] = ker(mb, c, 0, h, w);
    } else {
      const int64_t off = data_off(mb, c, d, h, w);
      dst_[off] = ker(mb, c, d, h, w);
    }
  };

  // Loop-invariant divisors for the flat-index decomposition.
  const int64_t hw   = H_ * W_;
  const int64_t dhw  = D_ * hw;
  const int64_t cdhw = C_ * dhw;

  for (int64_t blk_idx = 0; blk_idx < 16; blk_idx++) {
    const int64_t idx = first_idx + blk_idx;
    const int64_t n = (idx / cdhw) % N_;
    const int64_t c = (idx / dhw) % C_;
    const int64_t d = (idx / hw) % D_;
    const int64_t h = (idx / W_) % H_;
    const int64_t w = idx % W_;

    Operation(n, c, d, h, w);
  }
}

// Backward local response normalization.
//
// Work distribution is the same as in the forward kernel: the thread with
// flat id t = blockIdx.x * 32 + threadIdx.x processes the flat element
// indices [16*t, 16*t + 16), each decomposed into (n, c, d, h, w) with a
// modulo per dimension (out-of-range indices wrap around).
// NOTE(review): the fixed factor 32 matches blockDim.x only for 32-thread
// blocks -- confirm against the caller.
//
// For every element, with
//   omega(p) = k_ + alpha_ * (sum of src^2 over the window around p) / size_
//   t(p)     = dst_[p] * omega(p)^(-3/4)
// the kernel stores
//   diff_src_mem_[e] = t(e) - (2 * alpha_ * beta_ * src_[e] / size_)
//                             * sum over window p of (src_[p] * t(p) / omega(p))
//
// Parameters:
//   src_           forward input (read only)
//   dst_           only read by this kernel.
//                  NOTE(review): presumably the incoming gradient (diff_dst);
//                  verify against the caller.
//   diff_src_mem_  output gradient, same element offsets as src_
//   N_..W_         logical extents used to decompose the flat index
//   stride_mb_     element stride between consecutive mini-batch entries
//   size_          window size ((size_ - 1) / 2 neighbours on each side)
//   alpha_, beta_, k_  normalization constants (see formula above)
//   ndims_, wk_size_   accepted for interface compatibility but not read
__global__
void lrn_bwd_kernel(
const float* __restrict__ src_,
float* __restrict__ dst_,
float* __restrict__ diff_src_mem_,
int64_t N_,
int64_t C_,
int64_t D_,
int64_t H_,
int64_t W_,
int64_t stride_mb_,
int64_t ndims_,
int64_t wk_size_,
int64_t size_,
float alpha_,
float beta_,
float k_)
{
  // An empty tensor has nothing to differentiate, and the index decomposition
  // below would otherwise divide / take a modulo by zero.
  if (N_ <= 0 || C_ <= 0 || D_ <= 0 || H_ <= 0 || W_ <= 0) return;

  // Widen blockIdx.x *before* multiplying: the built-in is a 32-bit unsigned
  // value, so `blockIdx.x * 32` could wrap for very large grids.
  // (threadIdx.x / 32 * 32 + threadIdx.x % 32 is simply threadIdx.x.)
  const int64_t thread_id =
      static_cast<int64_t>(blockIdx.x) * 32 + static_cast<int64_t>(threadIdx.x);
  const int64_t first_idx = thread_id * 16;

  // Element offset of (mb, c, d, h, w). Layout 0 is channel-major within a
  // mini-batch entry, layout 1 is channel-minor. Neither formula uses d.
  auto data_off = [=](int64_t mb, int64_t c, int64_t d, int64_t h, int64_t w) {
    const int64_t tag = 0;
    switch (tag) {
      case 0 : return mb * stride_mb_ + c * H_ * W_ + h * W_ + w;
      case 1 : return mb * stride_mb_ + h * W_ * C_ + w * C_ + c;
      default:
        return (int64_t)1;
    }
  };

  // omega = k_ + alpha_ * (windowed sum of src^2) / size_ at the given point.
  auto get_omega = [=](int64_t mb, int64_t oc, int64_t od, int64_t oh, int64_t ow) {
    const int64_t half_size = (size_ - 1) / 2;
    const bool across_channel = true;
    // Must be a float accumulator: `auto sum = 0` would deduce int and
    // truncate every partial sum of squares.
    float sum = 0.0f;

    if (across_channel) {
      const int64_t c_st = max(oc - half_size + 0, (int64_t)0);
      const int64_t c_en = min(oc + half_size + 1, C_);

      for (int64_t c = c_st; c < c_en; ++c) {
        const float s = src_[data_off(mb, c, od, oh, ow)];
        sum += s * s;
      }
    } else {
      const int64_t d_st = max(od - half_size + 0, (int64_t)0);
      const int64_t d_en = min(od + half_size + 1, D_);
      const int64_t h_st = max(oh - half_size + 0, (int64_t)0);
      const int64_t h_en = min(oh + half_size + 1, H_);
      const int64_t w_st = max(ow - half_size + 0, (int64_t)0);
      const int64_t w_en = min(ow + half_size + 1, W_);
      for (int64_t d = d_st; d < d_en; ++d)
        for (int64_t h = h_st; h < h_en; ++h)
          for (int64_t w = w_st; w < w_en; ++w) {
            const float s = src_[data_off(mb, oc, d, h, w)];
            sum += s * s;
          }
    }
    return (k_ + alpha_ * sum / size_);
  };

  // Gradient value for the element at (mb, oc, od, oh, ow).
  auto ker = [=](int64_t mb, int64_t oc, int64_t od, int64_t oh, int64_t ow) {
    const int64_t half_size = (size_ - 1) / 2;
    const bool across_channel = true;
    float A = 0, B = 0;   // A: direct term at the centre, B: windowed term

    if (across_channel) {
      const int64_t c_st = max(oc - half_size + 0, (int64_t)0);
      const int64_t c_en = min(oc + half_size + 1, C_);

      for (int64_t c = c_st; c < c_en; ++c) {
        const int64_t off = data_off(mb, c, od, oh, ow);
        const float omega = get_omega(mb, c, od, oh, ow);
        // sqrtf(1 / (sqrtf(x) * x)) == x^(-3/4)
        const float omega_in_beta = sqrtf(1.0f / (sqrtf(omega) * omega));

        const float tmp = omega_in_beta * dst_[off];
        if (c == oc) A = tmp;
        B += (src_[off] * tmp / omega);
      }
    } else {
      const int64_t d_st = max(od - half_size + 0, (int64_t)0);
      const int64_t d_en = min(od + half_size + 1, D_);
      const int64_t h_st = max(oh - half_size + 0, (int64_t)0);
      const int64_t h_en = min(oh + half_size + 1, H_);
      const int64_t w_st = max(ow - half_size + 0, (int64_t)0);
      const int64_t w_en = min(ow + half_size + 1, W_);
      for (int64_t d = d_st; d < d_en; ++d)
        for (int64_t h = h_st; h < h_en; ++h)
          for (int64_t w = w_st; w < w_en; ++w) {
            const int64_t off = data_off(mb, oc, d, h, w);
            const float omega = get_omega(mb, oc, d, h, w);
            const float omega_in_beta = sqrtf(1.0f / (sqrtf(omega) * omega));

            const float tmp = omega_in_beta * dst_[off];
            if (d == od && h == oh && w == ow) A = tmp;
            B += (src_[off] * tmp / omega);
          }
    }

    const float center = src_[data_off(mb, oc, od, oh, ow)];
    B *= (2.0f * alpha_ * beta_ * center / size_);
    return (A - B);
  };

  // Computes one gradient element and stores it into diff_src_mem_.
  auto Operation = [=](int64_t mb, int64_t c, int64_t d, int64_t h, int64_t w) {
    const bool channel = false;
    if (channel) {
      const int64_t off = mb * stride_mb_ + h * W_ * C_ + w * C_ + c;
      // The result belongs in the gradient buffer, not in dst_ (which this
      // kernel only reads).
      diff_src_mem_[off] = ker(mb, c, 0, h, w);
    } else {
      const int64_t off = data_off(mb, c, d, h, w);
      diff_src_mem_[off] = ker(mb, c, d, h, w);
    }
  };

  // Loop-invariant divisors for the flat-index decomposition.
  const int64_t hw   = H_ * W_;
  const int64_t dhw  = D_ * hw;
  const int64_t cdhw = C_ * dhw;

  for (int64_t blk_idx = 0; blk_idx < 16; blk_idx++) {
    const int64_t idx = first_idx + blk_idx;
    const int64_t n = (idx / cdhw) % N_;
    const int64_t c = (idx / dhw) % C_;
    const int64_t d = (idx / hw) % D_;
    const int64_t h = (idx / W_) % H_;
    const int64_t w = idx % W_;

    Operation(n, c, d, h, w);
  }
}
#endif
Loading

0 comments on commit 028e2bf

Please sign in to comment.