diff --git a/doc/guide/myelin.md b/doc/guide/myelin.md
index d90b84af..357212ae 100644
--- a/doc/guide/myelin.md
+++ b/doc/guide/myelin.md
@@ -21,10 +21,10 @@ Build system: Bazel
## Using Myelin in Python
-Myelin represents a computation graph using a _flow_. The graph is divivded into
+Myelin represents a computation graph using a _flow_. The graph is divided into
_functions_ which can be computed independently. A function is a set of
_operations_ with tensor inputs and outputs. The tensor inputs and outputs are
-_variables_ in the flow. Variables can either be global constant tensor, e.g.
+_variables_ in the flow. Variables can either be global constant tensors, e.g.
learned weights in a neural network, or parameter tensors, which are local to
the function.
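The flow/function/variable split maps directly onto the Python builder API exercised by the test changes later in this patch. A minimal sketch, assuming `myelin` and NumPy are imported the same way as in `opcheck.py`:

```python
import numpy as np
import myelin  # assumed import; the opcheck.py tests below use the same `myelin`/`np` names

flow = myelin.Flow()                # the flow holds the whole computation graph
f = flow.define("lookup")           # a function that can be computed independently

# A global constant tensor (e.g. learned weights) and a local parameter tensor.
emb = f.array("emb", np.random.ranf((32, 64)).astype(np.float32))
ind = f.var("ind", myelin.DT_INT32, [1, 5])

v = f.gather(emb, ind)              # an operation with tensor inputs and outputs
```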
diff --git a/python/myelin/builder.py b/python/myelin/builder.py
index 22854774..09cce185 100644
--- a/python/myelin/builder.py
+++ b/python/myelin/builder.py
@@ -35,7 +35,7 @@
"f": DT_FLOAT32,
"d": DT_FLOAT64,
"i": DT_INT32,
- "l": DT_INT32,
+ "l": DT_INT64,
"B": DT_INT8,
"h": DT_INT16,
"b": DT_INT8,
@@ -592,8 +592,8 @@ def assign(self, x, y, name=None):
op.add_input(x)
op.add_input(y)
- def scatter_add(self, m, f, v, ref=False, name=None):
- op = self.rawop("ScatterAdd", name)
+ def assign_add_scatter(self, m, f, v, ref=False, name=None):
+ op = self.rawop("AssignAddScatter", name)
op.add_input(m)
op.add_input(f)
op.add_input(v)
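The typemap correction above matters because NumPy reports `'l'` as the dtype character for 64-bit integers on LP64 platforms, so the old entry silently mapped int64 arrays to DT_INT32. A quick check (platform-dependent; Windows uses `'q'` for int64):

```python
import numpy as np

print(np.dtype(np.int32).char)  # 'i'
print(np.dtype(np.int64).char)  # 'l' on Linux/macOS x86-64, 'q' on Windows
```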
diff --git a/python/myelin/simulator.py b/python/myelin/simulator.py
index 002a4b2d..39d6694a 100644
--- a/python/myelin/simulator.py
+++ b/python/myelin/simulator.py
@@ -229,7 +229,13 @@ def compute(flow, f, data):
for k in range(len(splits)): v[o[k]] = splits[k]
elif op.type == "Gather":
v[o[0]] = gather(v[i[0]], v[i[1]])
- elif op.type == "ScatterAdd":
+ elif op.type == "GatherSum":
+ v[o[0]] = np.sum(gather(v[i[0]], v[i[1]]), axis=1)
+ elif op.type == "GatherMax":
+ v[o[0]] = np.max(gather(v[i[0]], v[i[1]]), axis=1)
+ elif op.type == "GatherAvg":
+ v[o[0]] = np.sum(gather(v[i[0]], v[i[1]]), axis=1) / v[i[1]].shape[1]
+ elif op.type == "AssignAddScatter":
m = v[i[0]]
f = v[i[1]]
x = v[i[2]]
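The new simulator cases just reduce over the gathered rows. A NumPy sketch of the reference semantics, assuming the module's `gather()` helper behaves like fancy indexing `M[f]`:

```python
import numpy as np

M = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 embeddings of width 3
f = np.array([[0, 2, 3]], dtype=np.int32)          # one row of feature indices

rows = M[f]                                     # shape (1, 3, 3), like gather(M, f)
gather_sum = np.sum(rows, axis=1)               # shape (1, 3)
gather_max = np.max(rows, axis=1)               # shape (1, 3)
gather_avg = np.sum(rows, axis=1) / f.shape[1]  # divides by the number of index columns
```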
diff --git a/sling/myelin/builder.h b/sling/myelin/builder.h
index 5683f8a6..51c9ee24 100644
--- a/sling/myelin/builder.h
+++ b/sling/myelin/builder.h
@@ -336,6 +336,10 @@ class FlowBuilder : public Scope {
int n = f->rank() == 0 ? 1 : f->dim(1);
return Op("Gather", {M, f}, M->type, {n, M->dim(1)});
}
+ Variable *Gather(Variable *M, Variable *f, Variable *oov) {
+ int n = f->rank() == 0 ? 1 : f->dim(1);
+ return Op("Gather", {M, f, oov}, M->type, {n, M->dim(1)});
+ }
Variable *GatherSum(Variable *M, Variable *f) {
return Op("GatherSum", {M, f}, M->type, {1, M->dim(1)});
}
@@ -350,6 +354,9 @@ class FlowBuilder : public Scope {
Variable *Scatter(Variable *f, Variable *v, int size) {
return Op("Scatter", {f, v}, v->type, {size, v->dim(1)});
}
+ Variable *Scatter(Variable *f, Variable *v, int size, Variable *oov) {
+ return Op("Scatter", {f, v, oov}, v->type, {size, v->dim(1)});
+ }
// Assignment.
Operation *Assign(Variable *var, Variable *value) {
@@ -364,8 +371,8 @@ class FlowBuilder : public Scope {
return Op("Assign", {var, value})->set_ref();
}
- Operation *ScatterAdd(Variable *M, Variable *f, Variable *v) {
- return RawOp("ScatterAdd", {M, f, v});
+ Operation *AssignAddScatter(Variable *M, Variable *f, Variable *v) {
+ return RawOp("AssignAddScatter", {M, f, v});
}
// Concatenation.
diff --git a/sling/myelin/flow.cc b/sling/myelin/flow.cc
index e3bd6eb2..2e2c835b 100644
--- a/sling/myelin/flow.cc
+++ b/sling/myelin/flow.cc
@@ -1790,6 +1790,7 @@ Flow::Operation *Flow::AddOperation(Function *func,
const string &type) {
Operation *op = AddOperation(name, type);
func->AddOperation(op);
+ if (op->name.empty()) op->name = OpName(func->name + "/" + type);
return op;
}
@@ -1798,8 +1799,7 @@ Flow::Operation *Flow::AddOperation(Function *func,
const string &type,
                                    const std::vector<Variable *> &inputs,
                                    const std::vector<Variable *> &outputs) {
- Operation *op = AddOperation(name, type);
- func->AddOperation(op);
+ Operation *op = AddOperation(func, name, type);
for (auto *input : inputs) op->AddInput(input);
for (auto *output : outputs) op->AddOutput(output);
return op;
@@ -2148,7 +2148,10 @@ Flow::Blob *Flow::DataBlock(const string &name) {
string Flow::VarName(const string &prefix) {
for (int n = 0;; ++n) {
string name = prefix;
- if (n > 0) name.append(std::to_string(n));
+ if (n > 0) {
+ name.push_back('_');
+ name.append(std::to_string(n));
+ }
if (Var(name) == nullptr) return name;
}
}
@@ -2156,7 +2159,10 @@ string Flow::VarName(const string &prefix) {
string Flow::OpName(const string &prefix) {
for (int n = 0;; ++n) {
string name = prefix;
- if (n > 0) name.append(std::to_string(n));
+ if (n > 0) {
+ name.push_back('_');
+ name.append(std::to_string(n));
+ }
if (Op(name) == nullptr) return name;
}
}
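With this change, generated names run `prefix`, `prefix_1`, `prefix_2`, ... instead of `prefix`, `prefix1`, `prefix2`, .... A hypothetical Python rendering of the same loop (the `taken` set stands in for the `Var`/`Op` lookup):

```python
def unique_name(prefix, taken):
  n = 0
  while True:
    name = prefix if n == 0 else "%s_%d" % (prefix, n)
    if name not in taken:
      return name
    n += 1

unique_name("gather", {"gather", "gather_1"})  # -> 'gather_2'
```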
diff --git a/sling/myelin/kernel/array.cc b/sling/myelin/kernel/array.cc
index 4f5298f8..bf28e05e 100644
--- a/sling/myelin/kernel/array.cc
+++ b/sling/myelin/kernel/array.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <set>
+
#include "sling/myelin/compute.h"
#include "sling/myelin/macro-assembler.h"
#include "sling/myelin/simd-assembler.h"
@@ -23,38 +25,6 @@ namespace myelin {
using namespace jit;
-// Allocate registers for unrolling.
-static int SIMDUnrolls(int size, int vecsize, int max_unrolls) {
- int unrolls = 0;
- for (int i = 1; i <= max_unrolls; ++i) {
- int batch_size = i * vecsize;
- if (size >= batch_size && size % batch_size == 0) unrolls = i;
- }
- return unrolls;
-}
-
-static int AllocateYMMUnrolls(MacroAssembler *masm,
- int size,
- int max_unrolls,
-                              std::vector<YMMRegister> *regs) {
- int unrolls = SIMDUnrolls(size, 8, max_unrolls);
- for (int i = 0; i < std::max(unrolls, 1); ++i) {
- regs->push_back(masm->mm().allocy());
- }
- return unrolls;
-}
-
-static int AllocateZMMUnrolls(MacroAssembler *masm,
- int size,
- int max_unrolls,
-                              std::vector<ZMMRegister> *regs) {
- int unrolls = SIMDUnrolls(size, 16, max_unrolls);
- for (int i = 0; i < std::max(unrolls, 1); ++i) {
- regs->push_back(masm->mm().allocz());
- }
- return unrolls;
-}
-
// Reshape tensor while preserving the underlying data.
class Reshape : public Kernel {
public:
@@ -564,7 +534,7 @@ class GeneralConcat : public Kernel {
}
};
-// Split input tensors input tensor into chunks along a dimension.
+// Split input tensors into chunks along a dimension.
class Split : public Kernel {
public:
string Name() override { return "Split"; }
@@ -654,10 +624,11 @@ class SingleGather : public Kernel {
Tensor *f = step->input(1);
Tensor *oov = step->indegree() == 3 ? step->input(2) : nullptr;
Tensor *v = step->output(0);
+ Type type = M->type();
if (f->type() != DT_INT32) return false;
- if (M->type() != DT_FLOAT || M->rank() != 2) return false;
- if (v->type() != DT_FLOAT) return false;
- if (oov != nullptr && oov->type() != DT_FLOAT) return false;
+ if (M->rank() != 2) return false;
+ if (v->type() != type) return false;
+ if (oov != nullptr && oov->type() != type) return false;
int n = f->elements();
int d = M->dim(1);
int r = v->rank() - 1;
@@ -755,10 +726,11 @@ class MultiGather : public Kernel {
Tensor *f = step->input(1);
Tensor *oov = step->indegree() == 3 ? step->input(2) : nullptr;
Tensor *v = step->output(0);
+ Type type = M->type();
if (f->type() != DT_INT32) return false;
- if (M->type() != DT_FLOAT || M->rank() != 2) return false;
- if (v->type() != DT_FLOAT) return false;
- if (oov != nullptr && oov->type() != DT_FLOAT) return false;
+ if (M->rank() != 2) return false;
+ if (v->type() != type) return false;
+ if (oov != nullptr && oov->type() != type) return false;
int n = f->elements();
int d = M->dim(1);
int r = v->rank() - 1;
@@ -860,9 +832,6 @@ class PoolingGather : public Kernel {
}
bool Supports(Step *step) override {
- // Requires SSE or AVX support.
- if (!CPU::Enabled(AVX) && !CPU::Enabled(SSE)) return false;
-
// Check inputs and outputs.
if (step->indegree() != 2 || step->outdegree() != 1) return false;
@@ -870,9 +839,13 @@ class PoolingGather : public Kernel {
Tensor *M = step->input(0);
Tensor *f = step->input(1);
Tensor *v = step->output(0);
- if (M->type() != DT_FLOAT || M->rank() != 2) return false;
+ if (!SIMDAssembler::Supports(M->type()) || M->rank() != 2) return false;
if (f->type() != DT_INT32 || f->rank() != 2) return false;
- if (v->type() != DT_FLOAT || v->elements() != M->dim(1)) return false;
+ if (v->type() != M->type() || v->elements() != M->dim(1)) return false;
+ if (pooling_ == AVG) {
+ if (M->type() != DT_FLOAT && M->type() != DT_DOUBLE) return false;
+ if (!CPU::Enabled(SSE2)) return false;
+ }
return true;
}
@@ -881,16 +854,18 @@ class PoolingGather : public Kernel {
Tensor *M = step->input(0);
Tensor *v = step->output(0);
- // Align to one ymm/xmm register.
- int align = 4;
- if (CPU::Enabled(AVX)) align = 8;
- if (CPU::Enabled(AVX512F)) align = 16;
- M->SetMiniumAlignment(align * sizeof(float));
- v->SetMiniumAlignment(align * sizeof(float));
+ // Align to one vector register.
+ Type type = M->type();
+ int vecbytes = SIMDAssembler::VectorBytes(type);
+ M->SetMiniumAlignment(vecbytes);
+ v->SetMiniumAlignment(vecbytes);
// Embedding matrix must be row-major.
M->RequireOrder(ROW_MAJOR);
- if (M->dim(1) >= align) M->MinAlign({1, align});
+
+ // Reserve registers.
+ int regs = SIMDAssembler::RegisterUsage(type) + 8;
+ step->SetRegisterUsage(regs);
}
void Generate(Step *step, MacroAssembler *masm) override {
@@ -898,10 +873,19 @@ class PoolingGather : public Kernel {
Tensor *M = step->input(0);
Tensor *f = step->input(1);
Tensor *v = step->output(0);
- CHECK(f->IsLocal()) << f->name();
- CHECK(v->IsLocal()) << v->name();
int n = v->elements();
+ // Create SIMD code generators.
+ Type type = M->type();
+ int dsize = TypeTraits::of(type).size();
+ int vecbytes = SIMDAssembler::VectorBytes(type);
+ bool aligned = M->stride(0) % vecbytes == 0;
+ SIMDAssembler sasm(masm, type, aligned);
+
+ // Compute vector processing strategy.
+ SIMDStrategy strategy(&sasm, n);
+ strategy.PreloadMasks();
+
// Allocate registers.
Register acc = masm->rr().alloc_fixed(rax);
Register src = masm->rr().alloc_fixed(rsi);
@@ -913,6 +897,7 @@ class PoolingGather : public Kernel {
Register embeddings = masm->rr().alloc();
Register input = masm->rr().alloc();
Register output = masm->rr().alloc();
+ auto elem = sasm.alloc(strategy.MaxUnrolls());
// Load tensor locations.
__ LoadTensorAddress(embeddings, M);
@@ -925,12 +910,6 @@ class PoolingGather : public Kernel {
__ xorq(fcnt, fcnt);
}
- // Set up mask.
- OpmaskRegister mask = masm->kk().alloc();
- if (CPU::Enabled(AVX512F) && n % 16 != 0) {
- __ LoadMask(n % 16, mask);
- }
-
// Find first (non-negative) feature.
Label l1, l2, done;
__ bind(&l1);
@@ -979,129 +958,45 @@ class PoolingGather : public Kernel {
__ addq(src, acc);
// Update output vector with embedding vector for feature.
- if (masm->Enabled(AVX512F)) {
- // Combine elements using AVX512 vectors.
-      std::vector<ZMMRegister> elem;
- int main = (n / 16) * 16;
- int unrolls = AllocateZMMUnrolls(masm, main, 4, &elem);
- if (unrolls > 0) {
- Label next;
- __ xorq(ofs, ofs);
- __ bind(&next);
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 16 * sizeof(float);
- __ vmovaps(elem[i], Operand(src, ofs, times_1, disp));
- }
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 16 * sizeof(float);
- if (pooling_ == MAX) {
- __ vmaxps(elem[i], elem[i], Operand(output, ofs, times_1, disp));
- } else {
- __ vaddps(elem[i], elem[i], Operand(output, ofs, times_1, disp));
- }
- }
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 16 * sizeof(float);
- __ vmovaps(Operand(output, ofs, times_1, disp), elem[i]);
- }
-
- if (main > 16 * unrolls) {
- __ addq(ofs, Immediate(16 * unrolls * sizeof(float)));
- __ cmpq(ofs, Immediate(main * sizeof(float)));
- __ j(less, &next);
- }
- }
-
- // Combine residual elements.
- if (n % 16 > 0) {
- int disp = main * sizeof(float);
- __ vmovaps(elem[0], Operand(src, disp), Mask(mask, zeroing));
- if (pooling_ == MAX) {
- __ vmaxps(elem[0], elem[0], Operand(output, disp),
- Mask(mask, zeroing));
+ Reduction op = pooling_ == MAX ? REDUCE_MAX : REDUCE_ADD;
+ for (auto &phase : strategy.phases()) {
+ auto *gen = phase.generator;
+ int vecsize = gen->VectorSize();
+ int blkstart = phase.offset * dsize;
+ int blksize = phase.unrolls * vecsize * dsize;
+
+ if (phase.repeat > 1) {
+ // Repeated phase.
+ Label lu;
+ if (blkstart == 0) {
+ __ xorq(ofs, ofs);
} else {
- __ vaddps(elem[0], elem[0], Operand(output, disp),
- Mask(mask, zeroing));
- }
- __ vmovaps(Operand(output, disp), elem[0], Mask(mask, merging));
- }
- } else if (masm->Enabled(AVX)) {
- // Combine elements using AVX vectors.
-      std::vector<YMMRegister> elem;
- int main = (n / 8) * 8;
- int unrolls = AllocateYMMUnrolls(masm, main, 4, &elem);
- if (unrolls > 0) {
- Label next;
- __ xorq(ofs, ofs);
- __ bind(&next);
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 8 * sizeof(float);
- __ vmovaps(elem[i], Operand(src, ofs, times_1, disp));
+ __ movq(ofs, Immediate(blkstart));
}
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 8 * sizeof(float);
- if (pooling_ == MAX) {
- __ vmaxps(elem[i], elem[i], Operand(output, ofs, times_1, disp));
- } else {
- __ vaddps(elem[i], elem[i], Operand(output, ofs, times_1, disp));
- }
- }
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 8 * sizeof(float);
- __ vmovaps(Operand(output, ofs, times_1, disp), elem[i]);
+ __ bind(&lu);
+ for (int i = 0; i < phase.unrolls; ++i) {
+ int disp = i * vecsize * dsize;
+ gen->Load(elem[i], Operand(src, ofs, times_1, disp));
+ gen->Accumulate(op, elem[i], Operand(output, ofs, times_1, disp));
+ gen->Store(Operand(output, ofs, times_1, disp), elem[i]);
}
-
- if (main > 8 * unrolls) {
- __ addq(ofs, Immediate(8 * unrolls * sizeof(float)));
- __ cmpq(ofs, Immediate(main * sizeof(float)));
- __ j(less, &next);
+ __ addq(ofs, Immediate(blksize));
+ __ cmpq(ofs, Immediate(blkstart + phase.repeat * blksize));
+ __ j(less, &lu);
+ } else if (phase.masked == 0) {
+ // Residual phase.
+ for (int i = 0; i < phase.unrolls; ++i) {
+ int disp = blkstart + i * vecsize * dsize;
+ gen->Load(elem[i], Operand(src, disp));
+ gen->Accumulate(op, elem[i], Operand(output, disp));
+ gen->Store(Operand(output, disp), elem[i]);
}
- }
-
- // Combine residual elements.
- int disp = main * sizeof(float);
- for (int i = 0; i < n % 8; ++i) {
- int r = i % std::max(unrolls, 1);
- __ vmovss(elem[r], Operand(src, disp));
- if (pooling_ == MAX) {
- __ vmaxss(elem[r], elem[r], Operand(output, disp));
- } else {
- __ vaddss(elem[r], elem[r], Operand(output, disp));
- }
- __ vmovss(Operand(output, disp), elem[r]);
- disp += sizeof(float);
- }
- } else {
- // Combine elements using SSE vectors.
- int main = (n / 4) * 4;
- XMMRegister elem = masm->mm().allocx();
- if (n >= 4) {
- Label next;
- __ xorq(ofs, ofs);
- __ bind(&next);
- __ movaps(elem, Operand(src, ofs));
- if (pooling_ == MAX) {
- __ maxps(elem, Operand(output, ofs));
- } else {
- __ addps(elem, Operand(output, ofs));
- }
- __ movaps(Operand(output, ofs), elem);
- __ addq(ofs, Immediate(4 * sizeof(float)));
- __ cmpq(ofs, Immediate(main * sizeof(float)));
- __ j(less, &next);
- }
-
- // Combine residual elements.
- int disp = main * sizeof(float);
- for (int i = 0; i < n % 4; ++i) {
- __ movss(elem, Operand(src, disp));
- if (pooling_ == MAX) {
- __ maxss(elem, Operand(output, disp));
- } else {
- __ addss(elem, Operand(output, disp));
- }
- __ movss(Operand(output, disp), elem);
- disp += sizeof(float);
+ } else {
+ // Masked phase.
+ CHECK_EQ(phase.unrolls, 1);
+ gen->MaskedLoad(elem[0], Operand(src, blkstart));
+ gen->MaskedAccumulate(op, elem[0], Operand(output, blkstart));
+ gen->MaskedStore(Operand(output, blkstart), elem[0]);
}
}
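The phase loop above replaces the hand-written AVX-512/AVX/SSE code paths. Purely as an illustration (this is not the real `SIMDStrategy`), the three phase kinds cover an `n`-element row roughly as follows, with offsets and sizes in elements rather than bytes:

```python
def sketch_phases(n, vecsize, max_unrolls=4):
  """Illustrative decomposition into (kind, offset, unrolls, repeat) tuples."""
  phases, offset = [], 0
  block = vecsize * max_unrolls
  if n // block:                        # repeated phase: unrolled block loop
    repeat = n // block
    phases.append(("repeat", offset, max_unrolls, repeat))
    offset += repeat * block
  if (n - offset) // vecsize:           # residual phase: remaining whole vectors
    unrolls = (n - offset) // vecsize
    phases.append(("residual", offset, unrolls, 1))
    offset += unrolls * vecsize
  if n - offset:                        # masked phase: tail shorter than one vector
    phases.append(("masked", offset, 1, 1))
  return phases

sketch_phases(70, vecsize=16)  # [('repeat', 0, 4, 1), ('masked', 64, 1, 1)]
```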
@@ -1111,95 +1006,62 @@ class PoolingGather : public Kernel {
// Compute average.
if (pooling_ == AVG) {
- if (masm->Enabled(AVX512F)) {
- // Compute 1/fcnt.
- ZMMRegister scalar = masm->mm().allocz();
- __ vcvtqsi2ss(scalar.xmm(), scalar.xmm(), fcnt);
- __ vrcpss(scalar.xmm(), scalar.xmm(), scalar.xmm());
- __ vbroadcastss(scalar, scalar);
-
- // Multiply all output elements with scalar to get the average.
-        std::vector<ZMMRegister> elem;
- int main = (n / 16) * 16;
- int unrolls = AllocateZMMUnrolls(masm, main, 4, &elem);
- if (unrolls > 0) {
- Label next;
- __ xorq(ofs, ofs);
- __ bind(&next);
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 16 * sizeof(float);
- __ vmulps(elem[i], scalar, Operand(output, ofs, times_1, disp));
- }
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 16 * sizeof(float);
- __ vmovaps(Operand(output, ofs, times_1, disp), elem[i]);
- }
- __ addq(ofs, Immediate(16 * unrolls * sizeof(float)));
- __ cmpq(ofs, Immediate(main * sizeof(float)));
- __ j(less, &next);
+ // Compute 1/fcnt.
+ int scalar = sasm.alloc();
+ XMMRegister sr = jit::XMMRegister::from_code(scalar);
+ if (masm->Enabled(AVX)) {
+ __ vcvtqsi2ss(sr, sr, fcnt);
+ __ vrcpss(sr, sr, sr);
+ if (type == DT_DOUBLE) {
+ __ vcvtss2sd(sr, sr, sr);
}
- if (n % 16 > 0) {
- int disp = main * sizeof(float);
- __ vmulps(elem[0], scalar, Operand(output, disp),
- Mask(mask, zeroing));
- __ vmovaps(Operand(output, disp), elem[0], Mask(mask, merging));
- }
- } else if (masm->Enabled(AVX)) {
- // Compute 1/fcnt.
- YMMRegister scalar = masm->mm().allocy();
- __ vcvtqsi2ss(scalar.xmm(), scalar.xmm(), fcnt);
- __ vrcpss(scalar.xmm(), scalar.xmm(), scalar.xmm());
- if (masm->Enabled(AVX2)) {
- __ vbroadcastss(scalar, scalar);
- } else {
- __ vshufps(scalar, scalar, scalar, 0);
- __ vperm2f128(scalar, scalar, scalar, 0);
+ } else {
+ __ cvtqsi2ss(sr, fcnt);
+ __ rcpss(sr, sr);
+ if (type == DT_DOUBLE) {
+ CHECK(masm->Enabled(SSE2));
+ __ cvtss2sd(sr, sr);
}
+ }
+ sasm.main()->Broadcast(scalar, scalar);
- // Multiply all output elements with scalar to get the average.
-        std::vector<YMMRegister> elem;
- int main = (n / 8) * 8;
- int unrolls = AllocateYMMUnrolls(masm, main, 4, &elem);
- if (unrolls > 0) {
- Label next;
- __ xorq(ofs, ofs);
- __ bind(&next);
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 8 * sizeof(float);
- __ vmulps(elem[i], scalar, Operand(output, ofs, times_1, disp));
+ // Multiply all output elements with scalar to get the average.
+ for (auto &phase : strategy.phases()) {
+ auto *gen = phase.generator;
+ int vecsize = gen->VectorSize();
+ int blkstart = phase.offset * dsize;
+ int blksize = phase.unrolls * vecsize * dsize;
+
+ if (phase.repeat > 1) {
+ // Repeated phase.
+ Label lu;
+ if (blkstart == 0) {
+ __ xorq(ofs, ofs);
+ } else {
+ __ movq(ofs, Immediate(blkstart));
}
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 8 * sizeof(float);
- __ vmovaps(Operand(output, ofs, times_1, disp), elem[i]);
+ __ bind(&lu);
+ for (int i = 0; i < phase.unrolls; ++i) {
+ int disp = i * vecsize * dsize;
+ gen->Mul(elem[i], scalar, Operand(output, ofs, times_1, disp));
+ gen->Store(Operand(output, ofs, times_1, disp), elem[i]);
}
- __ addq(ofs, Immediate(8 * unrolls * sizeof(float)));
- __ cmpq(ofs, Immediate(main * sizeof(float)));
- __ j(less, &next);
- }
- int disp = main * sizeof(float);
- for (int i = 0; i < n % 8; ++i) {
- int r = i % std::max(unrolls, 1);
- __ vmulss(elem[r].xmm(), scalar.xmm(), Operand(output, disp));
- __ vmovss(Operand(output, disp), elem[r].xmm());
- disp += sizeof(float);
+ __ addq(ofs, Immediate(blksize));
+ __ cmpq(ofs, Immediate(blkstart + phase.repeat * blksize));
+ __ j(less, &lu);
+ } else if (phase.masked == 0) {
+ // Residual phase.
+ for (int i = 0; i < phase.unrolls; ++i) {
+ int disp = blkstart + i * vecsize * dsize;
+ gen->Mul(elem[i], scalar, Operand(output, disp));
+ gen->Store(Operand(output, disp), elem[i]);
+ }
+ } else {
+ // Masked phase.
+ CHECK_EQ(phase.unrolls, 1);
+ gen->MaskedMul(elem[0], scalar, Operand(output, blkstart));
+ gen->MaskedStore(Operand(output, blkstart), elem[0]);
}
- } else {
- // Compute 1/fcnt.
- XMMRegister scalar = masm->mm().allocx();
- __ cvtqsi2ss(scalar, fcnt);
- __ rcpss(scalar, scalar);
-
- // Multiply all output elements with scalar to get the average.
- XMMRegister elem = masm->mm().allocx();
- Label next;
- __ xorq(ofs, ofs);
- __ bind(&next);
- __ movss(elem, Operand(output, ofs));
- __ mulss(elem, scalar);
- __ movss(Operand(output, ofs), elem);
- __ addq(ofs, Immediate(sizeof(float)));
- __ cmpq(ofs, Immediate(v->size()));
- __ j(less, &next);
}
}
@@ -1216,90 +1078,92 @@ class PoolingGather : public Kernel {
Pooling pooling_; // pooling operation for combining vectors
};
-// Add sparse (scaled) input to variable.
-class ScatterAdd : public Kernel {
+// Accumulate sparse (scaled) input.
+class AssignAddScatter : public Kernel {
public:
- ScatterAdd(bool scale) : scale_(scale) {}
+ AssignAddScatter(bool scale) : scale_(scale) {}
string Name() override { return Operation(); }
string Operation() override {
- return scale_ ? "ScatterMulAdd" : "ScatterAdd";
+ return scale_ ? "AssignAddMulScatter" : "AssignAddScatter";
}
bool Supports(Step *step) override {
- // Requires SSE or AVX support.
- if (!CPU::Enabled(AVX) && !CPU::Enabled(SSE)) return false;
-
// Check inputs and outputs.
- if (step->indegree() != (scale_ ? 4 : 3)) return false;
- if (step->outdegree() > 1) return false;
+ Args args(step, scale_);
+ if (!args.valid) return false;
- // Check types.
- Tensor *var = step->input(0);
- Tensor *indices = step->input(1);
- Tensor *value = step->input(2);
- Tensor *scaler = scale_ ? step->input(3) : nullptr;
- Tensor *ref = step->outdegree() > 0 ? step->output(0) : nullptr;
- if (var->type() != DT_FLOAT || var->rank() != 2) return false;
- if (var->constant()) return false;
- if (indices->type() != DT_INT32 || indices->rank() != 2) return false;
- if (value->type() != DT_FLOAT) return false;
- if (value->elements() != var->dim(1)) return false;
+ // Check arguments.
+ Type type = args.var->type();
+ if (!SIMDAssembler::Supports(type)) return false;
+ if (args.var->rank() != 2) return false;
+ if (args.var->constant()) return false;
+ if (args.indices->type() != DT_INT32) return false;
+ if (args.indices->rank() != 2) return false;
+ if (args.value->type() != type || args.value->rank() != 2) return false;
+ if (args.value->dim(1) != args.var->dim(1)) return false;
+ if (args.value->dim(0) != 1 &&
+ args.value->dim(0) != args.indices->dim(1)) {
+ return false;
+ }
if (scale_) {
- if (scaler->type() != DT_FLOAT) return false;
- if (scaler->elements() != 1) return false;
+ if (args.scaler->type() != type) return false;
+ if (args.scaler->elements() != 1) return false;
}
- if (ref) {
- if (ref->type() != var->type()) return false;
- if (ref->shape() != var->shape()) return false;
- if (!ref->ref()) return false;
+ if (args.ref) {
+ if (args.ref->type() != type) return false;
+ if (args.ref->shape() != args.var->shape()) return false;
+ if (!args.ref->ref()) return false;
}
return true;
}
void Adjust(Step *step, const Options &options) override {
- Tensor *var = step->input(0);
- Tensor *value = step->input(2);
- Tensor *ref = step->outdegree() > 0 ? step->output(0) : nullptr;
+ Args args(step, scale_);
// Add sparsity bitmap index.
if (options.sparse_threshold > 0 &&
- var->dim(0) >= options.sparse_threshold &&
+ args.var->dim(0) >= options.sparse_threshold &&
step->GetAttr("sparse", true)) {
- Tensor *sparse = var->MakeSparse();
- if (ref) ref->set_sparse(sparse);
+ Tensor *sparse = args.var->MakeSparse();
+ if (args.ref) args.ref->set_sparse(sparse);
}
// Link output reference to input variable.
- if (ref) var->Link(ref);
+ if (args.ref) args.var->Link(args.ref);
- // Align to one SIMD register.
- int align = 4;
- if (CPU::Enabled(AVX)) align = 8;
- if (CPU::Enabled(AVX512F)) align = 16;
- var->SetMiniumAlignment(align * sizeof(float));
- value->SetMiniumAlignment(align * sizeof(float));
+ // Align to one vector register.
+ Type type = args.var->type();
+ int vecbytes = SIMDAssembler::VectorBytes(type);
+ args.var->SetMiniumAlignment(vecbytes);
+ args.value->SetMiniumAlignment(vecbytes);
// Embedding matrix must be row-major.
- var->RequireOrder(ROW_MAJOR);
- int minalign = 1;
- if (var->dim(1) >= 4) minalign = 4;
- if (CPU::Enabled(AVX) && var->dim(1) >= 8) minalign = 8;
- if (CPU::Enabled(AVX512F) && var->dim(1) >= 16) minalign = 16;
- var->MinAlign({1, minalign});
+ args.var->RequireOrder(ROW_MAJOR);
+
+ // Reserve registers.
+ int regs = SIMDAssembler::RegisterUsage(type) + 8;
+ step->SetRegisterUsage(regs);
}
void Generate(Step *step, MacroAssembler *masm) override {
// Get inputs.
- Tensor *var = step->input(0);
- Tensor *indices = step->input(1);
- Tensor *value = step->input(2);
- Tensor *scaler = scale_ ? step->input(3) : nullptr;
- Tensor *ref = step->outdegree() > 0 ? step->output(0) : nullptr;
- Tensor *sparse = var->sparse();
- bool single = indices->elements() == 1;
- int n = value->elements();
+ Args args(step, scale_);
+ Tensor *sparse = args.var->sparse();
+ bool single = args.indices->elements() == 1;
+ int n = args.value->dim(1);
+
+ // Create SIMD code generators.
+ Type type = args.var->type();
+ int dsize = TypeTraits::of(type).size();
+ int vecbytes = SIMDAssembler::VectorBytes(type);
+ bool aligned = args.var->stride(0) % vecbytes == 0;
+ SIMDAssembler sasm(masm, type, aligned);
+
+ // Compute vector processing strategy.
+ SIMDStrategy strategy(&sasm, n);
+ strategy.PreloadMasks();
// Allocate registers.
Register bit = masm->rr().alloc_fixed(rcx);
@@ -1312,41 +1176,28 @@ class ScatterAdd : public Kernel {
Register ofs = masm->rr().alloc();
Register src = bit;
Register aux = ofs;
-
- ZMMRegister factor = masm->mm().allocz(false);
+ auto elem = sasm.alloc(strategy.MaxUnrolls());
+ int factor = args.scaler ? sasm.alloc() : -1;
// Load tensor locations.
- __ LoadTensorAddress(varaddr, var);
- __ LoadTensorAddress(idxaddr, indices);
- __ LoadTensorAddress(valaddr, value);
+ __ LoadTensorAddress(varaddr, args.var);
+ __ LoadTensorAddress(idxaddr, args.indices);
+ __ LoadTensorAddress(valaddr, args.value);
if (sparse) {
__ LoadTensorAddress(bmaddr, sparse);
}
// Optionally output reference to assigned variable.
- if (ref != nullptr) {
- CHECK(ref->IsLocal());
- CHECK(ref->ref());
- __ movq(Operand(masm->instance(), ref->offset()), varaddr);
+ if (args.ref != nullptr) {
+ CHECK(args.ref->IsLocal());
+ CHECK(args.ref->ref());
+ __ movq(Operand(masm->instance(), args.ref->offset()), varaddr);
}
// Load scaling value.
- if (scaler) {
- __ LoadTensorAddress(src, scaler);
- if (masm->Enabled(AVX512F)) {
- __ vbroadcastss(factor, Operand(src));
- } else if (masm->Enabled(AVX)) {
- __ vbroadcastss(factor.ymm(), Operand(src));
- } else {
- __ movss(factor.xmm(), Operand(src));
- __ shufps(factor.xmm(), factor.xmm(), 0);
- }
- }
-
- // Set up mask.
- OpmaskRegister mask = masm->kk().alloc();
- if (CPU::Enabled(AVX512F) && n % 16 != 0) {
- __ LoadMask(n % 16, mask);
+ if (args.scaler) {
+ __ LoadTensorAddress(src, args.scaler);
+ sasm.main()->Broadcast(factor, Operand(src));
}
// Loop over features.
@@ -1373,149 +1224,85 @@ class ScatterAdd : public Kernel {
}
// Look up address of index in embedding.
- __ Multiply(acc, var->stride(0));
+ __ Multiply(acc, args.var->stride(0));
__ addq(acc, varaddr);
+ // Update OOV vector for missing features.
+ if (args.oov) {
+ Label l3;
+ __ jmp(&l3);
+ __ bind(&l2);
+ __ LoadTensorAddress(acc, args.oov);
+ __ bind(&l3);
+ }
+
// Add (scaled) input vector for feature to embedding vector.
- if (masm->Enabled(AVX512F)) {
- // Update elements using AVX-512 vectors.
-      std::vector<ZMMRegister> elem;
- int main = (n / 16) * 16;
- int unrolls = AllocateZMMUnrolls(masm, main, 4, &elem);
- if (unrolls > 0) {
- Label next;
- __ xorq(ofs, ofs);
- __ bind(&next);
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 16 * sizeof(float);
+ for (auto &phase : strategy.phases()) {
+ auto *gen = phase.generator;
+ int vecsize = gen->VectorSize();
+ int blkstart = phase.offset * dsize;
+ int blksize = phase.unrolls * vecsize * dsize;
+
+ if (phase.repeat > 1) {
+ // Repeated phase.
+ Label lu;
+ if (blkstart == 0) {
+ __ xorq(ofs, ofs);
+ } else {
+ __ movq(ofs, Immediate(blkstart));
+ }
+ __ bind(&lu);
+ for (int i = 0; i < phase.unrolls; ++i) {
+ int disp = i * vecsize * dsize;
+ gen->Load(elem[i], Operand(acc, ofs, times_1, disp));
if (scale_) {
- __ vmulps(elem[i], factor, Operand(valaddr, ofs, times_1, disp));
+ gen->MulAdd(elem[i], factor, Operand(valaddr, ofs, times_1, disp),
+ true);
} else {
- __ vmovaps(elem[i], Operand(valaddr, ofs, times_1, disp));
+ gen->Add(elem[i], elem[i], Operand(valaddr, ofs, times_1, disp));
}
+ gen->Store(Operand(acc, ofs, times_1, disp), elem[i]);
}
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 16 * sizeof(float);
- __ vaddps(elem[i], elem[i], Operand(acc, ofs, times_1, disp));
- }
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 16 * sizeof(float);
- __ vmovaps(Operand(acc, ofs, times_1, disp), elem[i]);
- }
- if (main > 16 * unrolls) {
- __ addq(ofs, Immediate(16 * unrolls * sizeof(float)));
- __ cmpq(ofs, Immediate(main * sizeof(float)));
- __ j(less, &next);
- }
- }
-
- // Update residual elements.
- if (n % 16 != 0) {
- int disp = main * sizeof(float);
- if (scale_) {
- __ vmulps(elem[0], factor, Operand(valaddr, disp),
- Mask(mask, zeroing));
- } else {
- __ vmovups(elem[0], Operand(valaddr, disp), Mask(mask, zeroing));
- }
- __ vaddps(elem[0], elem[0], Operand(acc, disp),
- Mask(mask, zeroing));
- __ vmovups(Operand(acc, disp), elem[0], Mask(mask, merging));
- }
- } else if (masm->Enabled(AVX)) {
- // Update elements using AVX vectors.
-      std::vector<YMMRegister> elem;
- int main = (n / 8) * 8;
- int unrolls = AllocateYMMUnrolls(masm, main, 4, &elem);
- if (unrolls > 0) {
- Label next;
- __ xorq(ofs, ofs);
- __ bind(&next);
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 8 * sizeof(float);
+ __ addq(ofs, Immediate(blksize));
+ __ cmpq(ofs, Immediate(blkstart + phase.repeat * blksize));
+ __ j(less, &lu);
+ } else if (phase.masked == 0) {
+ // Residual phase.
+ for (int i = 0; i < phase.unrolls; ++i) {
+ int disp = blkstart + i * vecsize * dsize;
+ gen->Load(elem[i], Operand(acc, disp));
if (scale_) {
- __ vmulps(elem[i], factor.ymm(),
- Operand(valaddr, ofs, times_1, disp));
+ gen->MulAdd(elem[i], factor, Operand(valaddr, disp), true);
} else {
- __ vmovaps(elem[i], Operand(valaddr, ofs, times_1, disp));
+ gen->Add(elem[i], elem[i], Operand(valaddr, disp));
}
+ gen->Store(Operand(acc, disp), elem[i]);
}
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 8 * sizeof(float);
- __ vaddps(elem[i], elem[i], Operand(acc, ofs, times_1, disp));
- }
- for (int i = 0; i < unrolls; ++i) {
- int disp = i * 8 * sizeof(float);
- __ vmovaps(Operand(acc, ofs, times_1, disp), elem[i]);
- }
- if (main > 8 * unrolls) {
- __ addq(ofs, Immediate(8 * unrolls * sizeof(float)));
- __ cmpq(ofs, Immediate(main * sizeof(float)));
- __ j(less, &next);
- }
- }
-
- // Update residual elements.
- int disp = main * sizeof(float);
- if (n % 8 >= 4) {
- if (scale_) {
- __ vmulps(elem[0].xmm(), factor.xmm(), Operand(valaddr, disp));
- } else {
- __ vmovaps(elem[0].xmm(), Operand(valaddr, disp));
- }
- __ vaddps(elem[0].xmm(), elem[0].xmm(), Operand(acc, disp));
- __ vmovaps(Operand(acc, disp), elem[0].xmm());
- disp += 4 * sizeof(float);
- }
- for (int i = 0; i < n % 4; ++i) {
- int r = i % std::max(unrolls, 1);
+ } else {
+ // Masked phase.
+ CHECK_EQ(phase.unrolls, 1);
+ gen->MaskedLoad(elem[0], Operand(acc, blkstart));
if (scale_) {
- __ vmulss(elem[r].xmm(), factor.xmm(), Operand(valaddr, disp));
+ gen->MaskedMulAdd(elem[0], factor, Operand(valaddr, blkstart));
} else {
- __ vmovss(elem[r].xmm(), Operand(valaddr, disp));
+ gen->MaskedAdd(elem[0], elem[0], Operand(valaddr, blkstart));
}
- __ vaddss(elem[r].xmm(), elem[r].xmm(), Operand(acc, disp));
- __ vmovss(Operand(acc, disp), elem[r].xmm());
- disp += sizeof(float);
- }
- } else {
- // Update elements using SSE vectors.
- XMMRegister elem = masm->mm().allocx();
- int main = (n / 4) * 4;
- if (n >= 4) {
- Label next;
- __ xorq(ofs, ofs);
- __ bind(&next);
- __ movaps(elem, Operand(valaddr, ofs));
- if (scale_) {
- __ mulps(elem, factor.xmm());
- }
- __ addps(elem, Operand(acc, ofs));
- __ movaps(Operand(acc, ofs), elem);
- __ addq(ofs, Immediate(4 * sizeof(float)));
- __ cmpq(ofs, Immediate(main * sizeof(float)));
- __ j(less, &next);
+ gen->MaskedStore(Operand(acc, blkstart), elem[0]);
}
+ }
- // Update residual elements.
- int disp = main * sizeof(float);
- for (int i = 0; i < n % 4; ++i) {
- __ movss(elem, Operand(valaddr, disp));
- if (scale_) {
- __ mulss(elem, factor.xmm());
- }
- __ addss(elem, Operand(acc, disp));
- __ movss(Operand(acc, disp), elem);
- disp += sizeof(float);
- }
+ if (args.value->dim(0) != 1) {
+ __ addq(valaddr, Immediate(args.value->stride(0)));
}
if (!single) {
__ incq(fidx);
- __ cmpq(fidx, Immediate(indices->elements()));
+ __ cmpq(fidx, Immediate(args.indices->elements()));
__ j(less, &l1);
}
- __ bind(&l2);
+ if (args.oov == nullptr) {
+ __ bind(&l2);
+ }
}
int64 Complexity(const Step *step) override {
@@ -1525,6 +1312,36 @@ class ScatterAdd : public Kernel {
}
private:
+ // Arguments to scatter op.
+ struct Args {
+ Args(Step *step, bool scale) {
+ if (step->indegree() < 3) return;
+ if (step->outdegree() > 1) return;
+ var = step->input(0);
+ indices = step->input(1);
+ value = step->input(2);
+ if (step->outdegree() > 0) ref = step->output(0);
+
+ if (scale) {
+ if (step->indegree() != 4 && step->indegree() != 5) return;
+ if (step->indegree() > 3) scaler = step->input(3);
+ if (step->indegree() > 4) oov = step->input(4);
+ } else {
+ if (step->indegree() != 3 && step->indegree() != 4) return;
+ if (step->indegree() > 3) oov = step->input(3);
+ }
+ valid = true;
+ }
+
+ bool valid = false;
+ Tensor *var = nullptr;
+ Tensor *indices = nullptr;
+ Tensor *value = nullptr;
+ Tensor *scaler = nullptr;
+ Tensor *ref = nullptr;
+ Tensor *oov = nullptr;
+ };
+
bool scale_; // scale input
};
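For reference, the accumulating scatter implemented by this kernel, written out in NumPy form. This is a sketch of the assumed semantics only: one value row per feature when `value.dim(0) == indices.dim(1)`, otherwise a single shared row; negative indices end the feature list unless an OOV input is present, in which case they accumulate into it:

```python
import numpy as np

def assign_add_scatter(M, indices, value, scaler=1.0, oov=None):
  # Mirrors the kernel's arguments, not its generated code.
  for k, idx in enumerate(indices[0]):
    row = value[0] if value.shape[0] == 1 else value[k]
    if idx < 0:
      if oov is None:
        break                  # feature lists are assumed to be negative-padded
      oov += scaler * row      # missing features update the OOV vector
    else:
      M[idx] += scaler * row

M = np.zeros((4, 3), dtype=np.float32)
assign_add_scatter(M, np.array([[1, 3]], np.int32), np.ones((1, 3), np.float32))
```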
@@ -1986,9 +1803,33 @@ class UpdateTransformer : public Transformer {
string Name() override { return "UpdateTransformer"; }
bool Transform(Flow *flow) override {
- int updates = 0;
+ bool updated = false;
+ bool again = true;
+ while (again) {
+ again = false;
+ if (TransformMatMul(flow)) {
+ again = true;
+ updated = true;
+ }
+ if (TransformDistributiveUpdate(flow)) {
+ again = true;
+ updated = true;
+ }
+ if (TransformSparseUpdate(flow)) {
+ again = true;
+ updated = true;
+ }
+ if (TransformScaledSparseUpdate(flow)) {
+ again = true;
+ updated = true;
+ }
+ }
+ return updated;
+ }
- // Transform matrix multiplication update.
+ // Transform matrix multiplication updates.
+ bool TransformMatMul(Flow *flow) {
+ int updates = 0;
for (Flow::Operation *op : flow->Find("MatMul|1:Add|1:Assign")) {
Flow::Operation *assign = op;
Flow::Operation *add = assign->inputs[1]->producer;
@@ -2001,45 +1842,48 @@ class UpdateTransformer : public Transformer {
flow->Fuse(assign, flow->Fuse(add, matmul, ""), "AssignAddMatMul", true);
updates++;
}
+ return updates > 0;
+ }
- // Transform double sparse update.
- for (Flow::Operation *op : flow->Find("Scatter|1:Add|1:Add|1:Assign")) {
- Flow::Operation *assign = op;
- Flow::Operation *add1 = assign->inputs[1]->producer;
+  // Transform distributive scatter updates.
+ bool TransformDistributiveUpdate(Flow *flow) {
+ // Find assignments for scatter operations.
+    std::set<Flow::Operation *> scatter_assigns;
+ for (Flow::Operation *op : flow->Find("Scatter")) {
+ while (op->outdegree() == 1 && op->outputs[0]->usages() == 1) {
+ op = op->outputs[0]->consumers[0];
+ }
+ if (op->type == "Assign") scatter_assigns.insert(op);
+ }
+
+ // Split additive updates.
+ int updates = 0;
+ for (Flow::Operation *op : flow->Find("Add|1:Add|1:Assign")) {
+ Flow::Operation *assign1 = op;
+ Flow::Operation *add1 = assign1->inputs[1]->producer;
Flow::Operation *add2 = add1->inputs[1]->producer;
- Flow::Operation *scatter1 = add2->inputs[1]->producer;
- Flow::Operation *scatter2 = add2->inputs[0]->producer;
+ Flow::Variable *target = assign1->inputs[0];
- if (assign->inputs[0] != add1->inputs[0]) continue;
if (add1->outputs[0]->usages() != 1) continue;
if (add2->outputs[0]->usages() != 1) continue;
- if (scatter2->type != "Scatter") continue;
-
- Flow::Variable *target = assign->inputs[0];
- if (target != add1->inputs[0]) continue;
-
- Flow::Variable *s1 = scatter1->outputs[0];
- Flow::Variable *s2 = scatter2->outputs[0];
- Flow::Variable *a2 = add2->outputs[0];
- if (s1->usages() != 1) continue;
- if (s2->usages() != 1) continue;
-
- // Decompose scatter updates.
- add1->RemoveInput(a2);
- add2->RemoveInput(s1);
- add2->RemoveInput(s2);
- add1->AddInput(s1);
- add2->AddInput(target);
- add2->AddInput(s2);
- string name = flow->OpName(assign->name);
- auto *a = flow->AddOperation(add2->func, name, "Assign");
- a->AddInput(target);
- a->AddInput(a2);
-
+ if (add1->inputs[0] != target) continue;
+ if (scatter_assigns.count(assign1) == 0) continue;
+
+ // Split into two accumulative updates.
+ Flow::Function *func = assign1->func;
+ Flow::Operation *assign2 = flow->AddOperation(func, "", "Assign");
+ assign2->AddInput(target);
+ assign2->AddInput(add2->outputs[0]);
+ add1->ReplaceInput(add1->inputs[1], add2->inputs[0]);
+ add2->ReplaceInput(add2->inputs[0], target);
updates++;
}
+ return updates > 0;
+ }
- // Transform sparse update.
+ // Transform sparse updates.
+ bool TransformSparseUpdate(Flow *flow) {
+ int updates = 0;
for (Flow::Operation *op : flow->Find("Scatter|1:Add|1:Assign")) {
Flow::Operation *assign = op;
Flow::Operation *add = assign->inputs[1]->producer;
@@ -2048,20 +1892,24 @@ class UpdateTransformer : public Transformer {
if (add->outputs[0]->usages() != 1) continue;
if (scatter->outputs[0]->usages() != 1) continue;
- flow->Fuse(assign, flow->Fuse(add, scatter, ""), "ScatterAdd", true);
+ Flow::Operation *add_scatter = flow->Fuse(add, scatter, "");
+ flow->Fuse(assign, add_scatter, "AssignAddScatter", true);
updates++;
}
+ return updates > 0;
+ }
- // Transform sparse update scaling.
- for (Flow::Operation *op : flow->Find("Mul|2:ScatterAdd")) {
+ // Transform sparse update scalings.
+ bool TransformScaledSparseUpdate(Flow *flow) {
+ int updates = 0;
+ for (Flow::Operation *op : flow->Find("Mul|2:AssignAddScatter")) {
Flow::Operation *scatter = op;
Flow::Operation *mul = scatter->inputs[2]->producer;
if (scatter->indegree() != 3) continue;
if (mul->outputs[0]->usages() != 1) continue;
- flow->Fuse(scatter, mul, "ScatterMulAdd");
+ flow->Fuse(scatter, mul, "AssignAddMulScatter");
updates++;
}
-
return updates > 0;
}
};
@@ -2109,8 +1957,8 @@ void RegisterArrayKernels(Library *library) {
library->Register(new PoolingGather(PoolingGather::SUM));
library->Register(new PoolingGather(PoolingGather::AVG));
library->Register(new PoolingGather(PoolingGather::MAX));
- library->Register(new ScatterAdd(false));
- library->Register(new ScatterAdd(true));
+ library->Register(new AssignAddScatter(false));
+ library->Register(new AssignAddScatter(true));
library->Register(new Reduce("Sum", REDUCE_ADD));
library->Register(new Reduce("Product", REDUCE_MUL));
diff --git a/sling/myelin/kernel/gradients.cc b/sling/myelin/kernel/gradients.cc
index e7c2417d..5e859570 100644
--- a/sling/myelin/kernel/gradients.cc
+++ b/sling/myelin/kernel/gradients.cc
@@ -472,13 +472,26 @@ void identity_grad(Flow::Operation *op, Gradients *g) {
g->add(x, g->d(y));
}
+// y = reshape(x, shape)
+// dx = reshape(dy, shape(x))
+void reshape_grad(Flow::Operation *op, Gradients *g) {
+ auto x = op->inputs[0];
+ auto y = op->outputs[0];
+ g->add(x, g->Reshape(g->d(y), g->v(x)->shape));
+}
+
// v = gather(M, f)
// dM = scatter(dv, f)
void gather_grad(Flow::Operation *op, Gradients *g) {
auto M = op->inputs[0];
auto f = op->inputs[1];
auto v = op->outputs[0];
- g->add(M, g->Scatter(g->v(f), g->d(v), M->dim(0)));
+ if (op->indegree() == 3) {
+ auto oov = op->inputs[2];
+ g->add(M, g->Scatter(g->v(f), g->d(v), M->dim(0), g->d(oov)));
+ } else {
+ g->add(M, g->Scatter(g->v(f), g->d(v), M->dim(0)));
+ }
}
// v = gather_sum(M, f)
@@ -607,6 +620,7 @@ void RegisterStandardGradients(Transformations *library) {
library->RegisterGradient("Relu", relu_grad);
library->RegisterGradient("Norm", norm_grad);
library->RegisterGradient("Identity", identity_grad);
+ library->RegisterGradient("Reshape", reshape_grad);
library->RegisterGradient("Gather", gather_grad);
library->RegisterGradient("GatherSum", gathersum_grad);
library->RegisterGradient("ConcatV2", concat_grad);
diff --git a/sling/myelin/simd-assembler.cc b/sling/myelin/simd-assembler.cc
index b4d44bfb..3a50f76c 100644
--- a/sling/myelin/simd-assembler.cc
+++ b/sling/myelin/simd-assembler.cc
@@ -84,6 +84,12 @@ bool SIMDGenerator::SupportsUnroll() {
return true;
}
+void SIMDGenerator::Broadcast(int dst, int src) {
+ // Broadcast is just a move for scalars.
+ CHECK_EQ(VectorSize(), 1);
+ if (dst != src) Move(dst, src);
+}
+
void SIMDGenerator::Broadcast(int dst, const Operand &src) {
// Broadcast is just a load for scalars.
CHECK_EQ(VectorSize(), 1);
@@ -160,6 +166,10 @@ class AVX512FloatGenerator : public SIMDGenerator {
int VectorSize() override { return 16; }
int Alloc() override { return masm_->mm().alloc(true); }
+ void Move(int dst, int src) override {
+ masm_->vmovaps(zmm(dst), zmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
if (aligned_) {
masm_->vmovaps(zmm(dst), src);
@@ -176,6 +186,10 @@ class AVX512FloatGenerator : public SIMDGenerator {
}
}
+ void Broadcast(int dst, int src) override {
+ masm_->vbroadcastss(zmm(dst), zmm(src));
+ }
+
void Broadcast(int dst, const Operand &src) override {
masm_->vbroadcastss(zmm(dst), src);
}
@@ -212,7 +226,7 @@ class AVX512FloatGenerator : public SIMDGenerator {
if (neutral == nullptr) {
Zero(r);
} else {
- masm_->vbroadcastss(zmm(r), neutral->address());
+ Broadcast(r, neutral->address());
}
}
@@ -286,6 +300,10 @@ class AVX256FloatGenerator : public SIMDGenerator {
int VectorSize() override { return 8; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->vmovaps(ymm(dst), ymm(src));
+ }
+
void Load(int dst, const Operand &src) override {
if (aligned_) {
masm_->vmovaps(ymm(dst), src);
@@ -302,6 +320,10 @@ class AVX256FloatGenerator : public SIMDGenerator {
}
}
+ void Broadcast(int dst, int src) override {
+ masm_->vbroadcastss(ymm(dst), ymm(src));
+ }
+
void Broadcast(int dst, const Operand &src) override {
masm_->vbroadcastss(ymm(dst), src);
}
@@ -348,7 +370,7 @@ class AVX256FloatGenerator : public SIMDGenerator {
if (neutral == nullptr) {
Zero(r);
} else {
- masm_->vbroadcastss(ymm(r), neutral->address());
+ Broadcast(r, neutral->address());
}
}
@@ -378,6 +400,10 @@ class AVX128FloatGenerator : public SIMDGenerator {
int VectorSize() override { return 4; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->vmovaps(xmm(dst), xmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
if (aligned_) {
masm_->vmovaps(xmm(dst), src);
@@ -394,6 +420,10 @@ class AVX128FloatGenerator : public SIMDGenerator {
}
}
+ void Broadcast(int dst, int src) override {
+ masm_->vbroadcastss(xmm(dst), xmm(src));
+ }
+
void Broadcast(int dst, const Operand &src) override {
masm_->vbroadcastss(xmm(dst), src);
}
@@ -440,7 +470,7 @@ class AVX128FloatGenerator : public SIMDGenerator {
if (neutral == nullptr) {
Zero(r);
} else {
- masm_->vbroadcastss(xmm(r), neutral->address());
+ Broadcast(r, neutral->address());
}
}
@@ -470,6 +500,10 @@ class SSE128FloatGenerator : public SIMDGenerator {
int VectorSize() override { return 4; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->movaps(xmm(dst), xmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
if (aligned_) {
masm_->movaps(xmm(dst), src);
@@ -486,6 +520,11 @@ class SSE128FloatGenerator : public SIMDGenerator {
}
}
+ void Broadcast(int dst, int src) override {
+ if (dst != src) masm_->movss(xmm(dst), xmm(src));
+ masm_->shufps(xmm(dst), xmm(dst), 0);
+ }
+
void Broadcast(int dst, const Operand &src) override {
masm_->movss(xmm(dst), src);
masm_->shufps(xmm(dst), xmm(dst), 0);
@@ -610,6 +649,10 @@ class AVX512ScalarFloatGenerator : public SIMDGenerator {
int VectorSize() override { return 1; }
int Alloc() override { return masm_->mm().alloc(true); }
+ void Move(int dst, int src) override {
+ masm_->vmovaps(zmm(dst), zmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
masm_->vmovss(zmm(dst), src);
}
@@ -684,6 +727,10 @@ class AVXScalarFloatGenerator : public SIMDGenerator {
int VectorSize() override { return 1; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->vmovaps(xmm(dst), xmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
masm_->vmovss(xmm(dst), src);
}
@@ -778,6 +825,10 @@ class SSEScalarFloatGenerator : public SIMDGenerator {
int VectorSize() override { return 1; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->movss(xmm(dst), xmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
masm_->movss(xmm(dst), src);
}
@@ -879,6 +930,10 @@ class AVX512DoubleGenerator : public SIMDGenerator {
int VectorSize() override { return 8; }
int Alloc() override { return masm_->mm().alloc(true); }
+ void Move(int dst, int src) override {
+ masm_->vmovapd(zmm(dst), zmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
if (aligned_) {
masm_->vmovapd(zmm(dst), src);
@@ -895,6 +950,10 @@ class AVX512DoubleGenerator : public SIMDGenerator {
}
}
+ void Broadcast(int dst, int src) override {
+ masm_->vbroadcastsd(zmm(dst), zmm(src));
+ }
+
void Broadcast(int dst, const Operand &src) override {
masm_->vbroadcastsd(zmm(dst), src);
}
@@ -931,7 +990,7 @@ class AVX512DoubleGenerator : public SIMDGenerator {
if (neutral == nullptr) {
Zero(r);
} else {
- masm_->vbroadcastsd(zmm(r), neutral->address());
+ Broadcast(r, neutral->address());
}
}
@@ -1005,6 +1064,10 @@ class AVX256DoubleGenerator : public SIMDGenerator {
int VectorSize() override { return 4; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->vmovapd(ymm(dst), ymm(src));
+ }
+
void Load(int dst, const Operand &src) override {
if (aligned_) {
masm_->vmovapd(ymm(dst), src);
@@ -1021,6 +1084,10 @@ class AVX256DoubleGenerator : public SIMDGenerator {
}
}
+ void Broadcast(int dst, int src) override {
+ masm_->vbroadcastsd(ymm(dst), ymm(src));
+ }
+
void Broadcast(int dst, const Operand &src) override {
masm_->vbroadcastsd(ymm(dst), src);
}
@@ -1067,7 +1134,7 @@ class AVX256DoubleGenerator : public SIMDGenerator {
if (neutral == nullptr) {
Zero(r);
} else {
- masm_->vbroadcastsd(ymm(r), neutral->address());
+ Broadcast(r, neutral->address());
}
}
@@ -1097,6 +1164,10 @@ class AVX128DoubleGenerator : public SIMDGenerator {
int VectorSize() override { return 2; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->vmovapd(xmm(dst), xmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
if (aligned_) {
masm_->vmovapd(xmm(dst), src);
@@ -1113,6 +1184,11 @@ class AVX128DoubleGenerator : public SIMDGenerator {
}
}
+ void Broadcast(int dst, int src) override {
+ if (dst != src) masm_->vmovapd(xmm(dst), xmm(src));
+ masm_->vshufpd(xmm(dst), xmm(dst), xmm(dst), 0);
+ }
+
void Broadcast(int dst, const Operand &src) override {
masm_->vmovsd(xmm(dst), src);
masm_->vshufpd(xmm(dst), xmm(dst), xmm(dst), 0);
@@ -1190,6 +1266,10 @@ class SSE128DoubleGenerator : public SIMDGenerator {
int VectorSize() override { return 2; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->movapd(xmm(dst), xmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
if (aligned_) {
masm_->movapd(xmm(dst), src);
@@ -1206,6 +1286,11 @@ class SSE128DoubleGenerator : public SIMDGenerator {
}
}
+ void Broadcast(int dst, int src) override {
+ if (dst != src) masm_->movapd(xmm(dst), xmm(src));
+ masm_->shufpd(xmm(dst), xmm(dst), 0);
+ }
+
void Broadcast(int dst, const Operand &src) override {
masm_->movsd(xmm(dst), src);
masm_->shufpd(xmm(dst), xmm(dst), 0);
@@ -1330,6 +1415,10 @@ class AVX512ScalarDoubleGenerator : public SIMDGenerator {
int VectorSize() override { return 1; }
int Alloc() override { return masm_->mm().alloc(true); }
+ void Move(int dst, int src) override {
+ masm_->vmovapd(zmm(dst), zmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
masm_->vmovsd(zmm(dst), src);
}
@@ -1404,6 +1493,10 @@ class AVXScalarDoubleGenerator : public SIMDGenerator {
int VectorSize() override { return 1; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->vmovapd(xmm(dst), xmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
masm_->vmovsd(xmm(dst), src);
}
@@ -1498,6 +1591,10 @@ class SSEScalarDoubleGenerator : public SIMDGenerator {
int VectorSize() override { return 1; }
int Alloc() override { return masm_->mm().alloc(false); }
+ void Move(int dst, int src) override {
+ masm_->movsd(xmm(dst), xmm(src));
+ }
+
void Load(int dst, const Operand &src) override {
masm_->movsd(xmm(dst), src);
}
@@ -1595,6 +1692,10 @@ class ScalarIntSIMDGenerator : public SIMDGenerator {
int Alloc() override { return masm_->rr().alloc().code(); }
bool SupportsUnroll() override { return false; }
+ void Move(int dst, int src) override {
+ masm_->movq(reg(dst), reg(src));
+ }
+
void Load(int dst, const Operand &src) override {
switch (type_) {
case DT_INT8: masm_->movsxbq(reg(dst), src); break;
@@ -1800,9 +1901,12 @@ SIMDAssembler::SIMDAssembler(MacroAssembler *masm, Type type, bool aligned) {
add(new AVX256DoubleGenerator(masm, aligned));
add(new AVX128DoubleGenerator(masm, aligned));
add(new AVXScalarDoubleGenerator(masm, aligned));
+ } else if (masm->Enabled(SSE2)) {
+ name_ = "SSE2Dbl";
+ add(new SSE128DoubleGenerator(masm, aligned));
+ add(new SSEScalarDoubleGenerator(masm, aligned));
} else if (masm->Enabled(SSE)) {
name_ = "SSEDbl";
- add(new SSE128DoubleGenerator(masm, aligned));
add(new SSEScalarDoubleGenerator(masm, aligned));
}
break;
diff --git a/sling/myelin/simd-assembler.h b/sling/myelin/simd-assembler.h
index 4afa4657..f007a781 100644
--- a/sling/myelin/simd-assembler.h
+++ b/sling/myelin/simd-assembler.h
@@ -39,13 +39,17 @@ class SIMDGenerator {
// Allocate SIMD register.
virtual int Alloc() = 0;
+  // Move value from one register to another.
+ virtual void Move(int dst, int src) = 0;
+
// Load memory into register.
virtual void Load(int dst, const jit::Operand &src) = 0;
// Store register into memory.
virtual void Store(const jit::Operand &dst, int src) = 0;
- // Broadcast memory into register.
+ // Broadcast value to all elements of register.
+ virtual void Broadcast(int dst, int src);
virtual void Broadcast(int dst, const jit::Operand &src);
// Clear register.
@@ -61,7 +65,7 @@ class SIMDGenerator {
// Multiply src1 and src2 and store it in dst.
virtual void Mul(int dst, int src1, const jit::Operand &src2) = 0;
- // Multiply src1 and src2 and add it to dst. If the keep flag is false the
+ // Multiply src1 and src2 and add it to dst. If the retain flag is false the
// contents of src1 can possibly be destroyed.
virtual void MulAdd(int dst, int src1, const jit::Operand &src2,
bool retain) = 0;
diff --git a/sling/myelin/tests/opcheck.py b/sling/myelin/tests/opcheck.py
index 071b805f..9b5e3c4e 100644
--- a/sling/myelin/tests/opcheck.py
+++ b/sling/myelin/tests/opcheck.py
@@ -763,13 +763,37 @@ def gather_test(n, d, s):
v = f.gather(emb, ind)
check(flow, (n, d, s), 0, n)
+def gather_sum_test(n, d, s):
+ flow = myelin.Flow()
+ f = flow.define("gather_sum")
+ emb = f.array("emb", np.random.ranf((n, d)).astype(simulator.nptypes[dt]))
+ ind = f.var("ind", myelin.DT_INT32, [1, s])
+ v = f.gather_sum(emb, ind)
+ check(flow, (n, d, s), 0, n)
+
+def gather_max_test(n, d, s):
+ flow = myelin.Flow()
+ f = flow.define("gather_max")
+ emb = f.array("emb", np.random.ranf((n, d)).astype(simulator.nptypes[dt]))
+ ind = f.var("ind", myelin.DT_INT32, [1, s])
+ v = f.gather_max(emb, ind)
+ check(flow, (n, d, s), 0, n)
+
+def gather_avg_test(n, d, s):
+ flow = myelin.Flow()
+ f = flow.define("gather_avg")
+ emb = f.array("emb", np.random.ranf((n, d)).astype(simulator.nptypes[dt]))
+ ind = f.var("ind", myelin.DT_INT32, [1, s])
+ v = f.gather_avg(emb, ind)
+ check(flow, (n, d, s), 0, n, rtol=1e-3)
+
def scatter_add_test(n, d, s):
flow = myelin.Flow()
f = flow.define("scatter_add")
m = f.var("m", dt, [n, d])
ind = f.var("ind", myelin.DT_INT32, [1, s])
v = f.var("v", dt, [1, d])
- f.scatter_add(m, ind, v)
+ f.assign_add_scatter(m, ind, v)
check(flow, (n, d, s), 0, n, check=[m])
def negfold_test(n):
@@ -835,6 +859,15 @@ def acc_matmul_test(m, k, n):
size_test(i)
if i < 32: rank_test(i)
+ embsize = 32
+ for f in [1, 2, 5]:
+ gather_test(embsize, i, f)
+ gather_sum_test(embsize, i, f)
+ gather_max_test(embsize, i, f)
+ if dt == myelin.DT_FLOAT or dt == myelin.DT_DOUBLE:
+ gather_avg_test(embsize, i, f)
+ scatter_add_test(embsize, i, f)
+
for c in [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8]:
add_const_test(i, c)
sub_const_test(i, c)
diff --git a/sling/nlp/embedding/embedding-model.cc b/sling/nlp/embedding/embedding-model.cc
index 4c19d9d4..7a6e7618 100644
--- a/sling/nlp/embedding/embedding-model.cc
+++ b/sling/nlp/embedding/embedding-model.cc
@@ -68,7 +68,7 @@ void MikolovFlow::BuildLayer1() {
// Backprop layer 1.
tf.AssignAdd(error, tf.Mul(embed, eta));
- tf.ScatterAdd(W1, target, tf.Mul(h, eta));
+ tf.AssignAddScatter(W1, target, tf.Mul(h, eta));
}
void MikolovFlow::BuildLayer0Back() {
@@ -77,7 +77,7 @@ void MikolovFlow::BuildLayer0Back() {
l0b_l0 = tf.Instance(layer0);
l0b_l1 = tf.Instance(layer1);
- tf.ScatterAdd(W0, tf.Ref(l0b_l0, fv), tf.Ref(l0b_l1, error));
+ tf.AssignAddScatter(W0, tf.Ref(l0b_l0, fv), tf.Ref(l0b_l1, error));
}
void DualEncoderFlow::Build(const Transformations &library) {