#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

// Each entry pairs a block width with an [n x 4] int tensor whose rows are
// (head, block_row, block_col, block_index) for the blocks grouped at that width.
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;

// Greedily extracts max_width x max_width squares of nonzero blocks from `layout`
// (independently per head), records their (head, row, col, block_index) rows in
// `scratch`, zeroes those blocks out of `layout`, and appends one
// {max_width, concatenated rows} entry to `ret` if any square was found.
void segment_blocks(torch::Tensor layout, torch::Tensor idx, torch::Tensor scratch, int max_width, ret_t& ret){
  size_t H = layout.size(0);
  size_t M = layout.size(1);
  size_t N = layout.size(2);
  torch::Tensor tmp = torch::zeros_like(layout);
  auto _tmp = tmp.accessor<int, 3>();
  auto _layout = layout.accessor<int, 3>();
  auto _idx = idx.accessor<int, 3>();
  auto _scratch = scratch.accessor<int, 3>();
  std::vector<int> current(H, 0);
#ifdef _OPENMP
#pragma omp parallel for
#endif
  for(size_t h = 0; h < H; h++){
    // surrounding indices
    std::vector<int> ii_left(max_width, -1);
    std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
    for(size_t m = 0; m < M; m++){
      for(size_t n = 0; n < N; n++){
        int v = _layout[h][m][n];
        if(v == 0)
          continue;
        int n_left = ii_left[max_width-1];
        int m_top = ii_top[max_width-1][n];
        int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
        int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
        int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
        int width = std::min(left, std::min(top, topleft)) + 1;
        // reset width if blocks cannot be
        // packed together (i.e., there's a 1 "in the middle")
        for(int nn = n_left + 1; nn < n; nn++)
          if(ii_top[max_width-1][nn] > ii_top[max_width-1][n])
            width = 1;
        _tmp[h][m][n] = width;
        // update ii_left ring buffer
        for(int k = 0; k < max_width-1; k++)
          ii_left[k] = ii_left[k+1];
        ii_left[max_width-1] = n;
        // update ii_top ring buffer
        for(int k = 0; k < max_width-1; k++)
          ii_top[k][n] = ii_top[k+1][n];
        ii_top[max_width-1][n] = m;
        // block is too small -- skip
        if(width != max_width)
          continue;
        // retained blocks are set to zeros
        for(size_t km = 0; km < max_width; km++)
        for(size_t kn = 0; kn < max_width; kn++)
        {
          int mm = ii_top[km][n];
          int nn = ii_left[kn];
          if(mm < 0 || nn < 0)
            continue;
          _layout[h][mm][nn] = 0;
          _tmp[h][mm][nn] = 0;
          _scratch[h][current[h]][0] = (int)h;
          _scratch[h][current[h]][1] = (int)mm;
          _scratch[h][current[h]][2] = (int)nn;
          _scratch[h][current[h]][3] = _idx[h][mm][nn];
          current[h]++;
        }
      }
    }
  }
  std::vector<torch::Tensor> to_cat;
  for(size_t h = 0; h < H; h++)
    if(current[h] > 0)
      to_cat.push_back(scratch[h].slice(0, 0, current[h]));
  if(!to_cat.empty())
    ret.push_back({max_width, torch::cat(to_cat)});
}
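
// Informal example (not from the original sources): for a single head whose
// layout is the 2x2 all-ones matrix, a call with max_width == 2 detects the
// full square when it reaches (m, n) == (1, 1), writes the four
// (h, m, n, idx) rows into scratch, zeroes those layout entries, and appends
// one {2, tensor of shape [4, 4]} entry to ret.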

// Builds the full segmentation: assigns each nonzero block a linear index,
// then repeatedly extracts squares of decreasing width
// (start_width, start_width/2, ..., 1) until every block has been covered.
ret_t sdd_segment(torch::Tensor layout, int start_width) {
  ret_t ret;
  // block index
  torch::Tensor idx = torch::zeros_like(layout);
  int current = 0;
  size_t H = layout.size(0);
  size_t M = layout.size(1);
  size_t N = layout.size(2);
  auto _layout = layout.accessor<int, 3>();
  auto _idx = idx.accessor<int, 3>();
  for(size_t h = 0; h < H; h++)
  for(size_t m = 0; m < M; m++)
  for(size_t n = 0; n < N; n++){
    if(_layout[h][m][n] == 0)
      continue;
    _idx[h][m][n] = current++;
  }
  // scratch memory
  torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
  for(int max_width = start_width; max_width > 0; max_width /= 2)
    segment_blocks(layout, idx, scratch, max_width, ret);
  return ret;
}

// Builds a [num_heads, num_blocks, num_blocks] block-sparsity layout.
// mode == "dense" returns an all-ones layout; any other mode produces a
// block-diagonal local pattern (causal within each block if unidirectional)
// plus per-head vertical stripes controlled by numverts and vertsize.
torch::Tensor make_layout(int num_heads, int num_blocks, const std::string& mode,
                          int block_stride, bool unidirectional, int numverts, int vertsize)
{
  std::vector<long> shape = {{num_heads, num_blocks, num_blocks}};
  if(mode == "dense")
    return torch::ones(shape, torch::kLong);
  torch::Tensor ret = torch::zeros(shape, torch::kLong);
  auto _ret = ret.accessor<long, 3>();
  for(int h = 0; h < num_heads; h++){
    // set first part of layout
    for(int i = 0; i < num_blocks; i += block_stride)
    for(int j = i; j < i + block_stride; j++)
    for(int k = i; k < (unidirectional ? j+1 : i + block_stride); k++)
      _ret[h][j][k] = 1;
    // set second part of layout
    int start = block_stride - (1 + h % numverts) * vertsize;
    for(int i = 0; i < num_blocks; i++){
      int end = unidirectional ? i : num_blocks;
      for(int j = start; j < end; j += block_stride)
      for(int k = j; k < j + vertsize; k += num_blocks)
        _ret[h][i][k] = 1;
    }
  }
  return ret;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("make_layout", &make_layout, "make sparsity layout");
  m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}
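
// ---------------------------------------------------------------------------
// Minimal usage sketch (not part of the extension). The guard name
// UTILS_STANDALONE_EXAMPLE is hypothetical (the build scripts never define it),
// so this block is never compiled into the module; it only illustrates how
// make_layout and sdd_segment compose when this file is built as a standalone
// program linked against libtorch.
// ---------------------------------------------------------------------------
#ifdef UTILS_STANDALONE_EXAMPLE
#include <iostream>

int main() {
  // 2 heads, 16x16 block grid, causal pattern with a stride of 4 blocks and
  // one vertical stripe of width 1 per head (any mode other than "dense"
  // selects the strided path above).
  torch::Tensor layout = make_layout(2, 16, "fixed", 4, true, 1, 1);
  // sdd_segment reads the layout through int accessors, so convert from kLong.
  ret_t segments = sdd_segment(layout.to(torch::kInt), /*start_width=*/4);
  for (auto& seg : segments)
    std::cout << "width " << std::get<0>(seg) << ": "
              << std::get<1>(seg).size(0) << " block(s)\n";
  return 0;
}
#endif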