diff --git a/doc/guide/myelin.md b/doc/guide/myelin.md index d90b84af..357212ae 100644 --- a/doc/guide/myelin.md +++ b/doc/guide/myelin.md @@ -21,10 +21,10 @@ Build system: Bazel
## Using Myelin in Python -Myelin represents a computation graph using a _flow_. The graph is divivded into +Myelin represents a computation graph using a _flow_. The graph is divided into _functions_ which can be computed independently. A function is a set of _operations_ with tensor inputs and outputs. The tensor inputs and outputs are -_variables_ in the flow. Variables can either be global constant tensor, e.g. +_variables_ in the flow. Variables can either be global constant tensors, e.g. learned weights in a neural network, or parameter tensors, which are local to the function. diff --git a/python/myelin/builder.py b/python/myelin/builder.py index 22854774..09cce185 100644 --- a/python/myelin/builder.py +++ b/python/myelin/builder.py @@ -35,7 +35,7 @@ "f": DT_FLOAT32, "d": DT_FLOAT64, "i": DT_INT32, - "l": DT_INT32, + "l": DT_INT64, "B": DT_INT8, "h": DT_INT16, "b": DT_INT8, @@ -592,8 +592,8 @@ def assign(self, x, y, name=None): op.add_input(x) op.add_input(y) - def scatter_add(self, m, f, v, ref=False, name=None): - op = self.rawop("ScatterAdd", name) + def assign_add_scatter(self, m, f, v, ref=False, name=None): + op = self.rawop("AssignAddScatter", name) op.add_input(m) op.add_input(f) op.add_input(v) diff --git a/python/myelin/simulator.py b/python/myelin/simulator.py index 002a4b2d..39d6694a 100644 --- a/python/myelin/simulator.py +++ b/python/myelin/simulator.py @@ -229,7 +229,13 @@ def compute(flow, f, data): for k in range(len(splits)): v[o[k]] = splits[k] elif op.type == "Gather": v[o[0]] = gather(v[i[0]], v[i[1]]) - elif op.type == "ScatterAdd": + elif op.type == "GatherSum": + v[o[0]] = np.sum(gather(v[i[0]], v[i[1]]), axis=1) + elif op.type == "GatherMax": + v[o[0]] = np.max(gather(v[i[0]], v[i[1]]), axis=1) + elif op.type == "GatherAvg": + v[o[0]] = np.sum(gather(v[i[0]], v[i[1]]), axis=1) / v[i[1]].shape[1] + elif op.type == "AssignAddScatter": m = v[i[0]] f = v[i[1]] x = v[i[2]] diff --git a/sling/myelin/builder.h b/sling/myelin/builder.h index 5683f8a6..51c9ee24 100644 --- a/sling/myelin/builder.h +++ b/sling/myelin/builder.h @@ -336,6 +336,10 @@ class FlowBuilder : public Scope { int n = f->rank() == 0 ? 1 : f->dim(1); return Op("Gather", {M, f}, M->type, {n, M->dim(1)}); } + Variable *Gather(Variable *M, Variable *f, Variable *oov) { + int n = f->rank() == 0 ? 1 : f->dim(1); + return Op("Gather", {M, f, oov}, M->type, {n, M->dim(1)}); + } Variable *GatherSum(Variable *M, Variable *f) { return Op("GatherSum", {M, f}, M->type, {1, M->dim(1)}); } @@ -350,6 +354,9 @@ class FlowBuilder : public Scope { Variable *Scatter(Variable *f, Variable *v, int size) { return Op("Scatter", {f, v}, v->type, {size, v->dim(1)}); } + Variable *Scatter(Variable *f, Variable *v, int size, Variable *oov) { + return Op("Scatter", {f, v, oov}, v->type, {size, v->dim(1)}); + } // Assignment. Operation *Assign(Variable *var, Variable *value) { @@ -364,8 +371,8 @@ class FlowBuilder : public Scope { return Op("Assign", {var, value})->set_ref(); } - Operation *ScatterAdd(Variable *M, Variable *f, Variable *v) { - return RawOp("ScatterAdd", {M, f, v}); + Operation *AssignAddScatter(Variable *M, Variable *f, Variable *v) { + return RawOp("AssignAddScatter", {M, f, v}); } // Concatenation. 
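For reference, the three pooled gathers added to python/myelin/simulator.py above share one reference semantics: gather rows of the embedding matrix by index, then reduce over the feature axis. A minimal NumPy sketch of those semantics (shapes and values are illustrative only; the `gather(M, f)` helper in simulator.py is equivalent to `M[f]` here):

```python
import numpy as np

M = np.arange(12, dtype=np.float32).reshape(6, 2)  # embedding matrix, 6 rows
f = np.array([[0, 2, 5]], dtype=np.int32)          # feature indices, shape (1, 3)

g = M[f]                                     # Gather: shape (1, 3, 2)
gather_sum = np.sum(g, axis=1)               # GatherSum: reduce over feature axis
gather_max = np.max(g, axis=1)               # GatherMax
gather_avg = np.sum(g, axis=1) / f.shape[1]  # GatherAvg: sum / feature count
```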
diff --git a/sling/myelin/flow.cc b/sling/myelin/flow.cc index e3bd6eb2..2e2c835b 100644 --- a/sling/myelin/flow.cc +++ b/sling/myelin/flow.cc @@ -1790,6 +1790,7 @@ Flow::Operation *Flow::AddOperation(Function *func, const string &name, const string &type) { Operation *op = AddOperation(name, type); func->AddOperation(op); + if (op->name.empty()) op->name = OpName(func->name + "/" + type); return op; } @@ -1798,8 +1799,7 @@ Flow::Operation *Flow::AddOperation(Function *func, const string &name, const string &type, const std::vector<Variable *> &inputs, const std::vector<Variable *> &outputs) { - Operation *op = AddOperation(name, type); - func->AddOperation(op); + Operation *op = AddOperation(func, name, type); for (auto *input : inputs) op->AddInput(input); for (auto *output : outputs) op->AddOutput(output); return op; } @@ -2148,7 +2148,10 @@ Flow::Blob *Flow::DataBlock(const string &name) { string Flow::VarName(const string &prefix) { for (int n = 0;; ++n) { string name = prefix; - if (n > 0) name.append(std::to_string(n)); + if (n > 0) { + name.push_back('_'); + name.append(std::to_string(n)); + } if (Var(name) == nullptr) return name; } } @@ -2156,7 +2159,10 @@ string Flow::OpName(const string &prefix) { for (int n = 0;; ++n) { string name = prefix; - if (n > 0) name.append(std::to_string(n)); + if (n > 0) { + name.push_back('_'); + name.append(std::to_string(n)); + } if (Op(name) == nullptr) return name; } } diff --git a/sling/myelin/kernel/array.cc b/sling/myelin/kernel/array.cc index 4f5298f8..bf28e05e 100644 --- a/sling/myelin/kernel/array.cc +++ b/sling/myelin/kernel/array.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <set> + #include "sling/myelin/compute.h" #include "sling/myelin/macro-assembler.h" #include "sling/myelin/simd-assembler.h" @@ -23,38 +25,6 @@ namespace myelin { using namespace jit; -// Allocate registers for unrolling. -static int SIMDUnrolls(int size, int vecsize, int max_unrolls) { - int unrolls = 0; - for (int i = 1; i <= max_unrolls; ++i) { - int batch_size = i * vecsize; - if (size >= batch_size && size % batch_size == 0) unrolls = i; - } - return unrolls; -} - -static int AllocateYMMUnrolls(MacroAssembler *masm, - int size, - int max_unrolls, - std::vector<YMMRegister> *regs) { - int unrolls = SIMDUnrolls(size, 8, max_unrolls); - for (int i = 0; i < std::max(unrolls, 1); ++i) { - regs->push_back(masm->mm().allocy()); - } - return unrolls; -} - -static int AllocateZMMUnrolls(MacroAssembler *masm, - int size, - int max_unrolls, - std::vector<ZMMRegister> *regs) { - int unrolls = SIMDUnrolls(size, 16, max_unrolls); - for (int i = 0; i < std::max(unrolls, 1); ++i) { - regs->push_back(masm->mm().allocz()); - } - return unrolls; -} - // Reshape tensor while preserving the underlying data. class Reshape : public Kernel { public: @@ -564,7 +534,7 @@ class GeneralConcat : public Kernel { } }; -// Split input tensors input tensor into chunks along a dimension. +// Split input tensors into chunks along a dimension. class Split : public Kernel { public: string Name() override { return "Split"; } @@ -654,10 +624,11 @@ class SingleGather : public Kernel { Tensor *f = step->input(1); Tensor *oov = step->indegree() == 3 ?
step->input(2) : nullptr; Tensor *v = step->output(0); + Type type = M->type(); if (f->type() != DT_INT32) return false; - if (M->type() != DT_FLOAT || M->rank() != 2) return false; - if (v->type() != DT_FLOAT) return false; - if (oov != nullptr && oov->type() != DT_FLOAT) return false; + if (M->rank() != 2) return false; + if (v->type() != type) return false; + if (oov != nullptr && oov->type() != type) return false; int n = f->elements(); int d = M->dim(1); int r = v->rank() - 1; @@ -755,10 +726,11 @@ class MultiGather : public Kernel { Tensor *f = step->input(1); Tensor *oov = step->indegree() == 3 ? step->input(2) : nullptr; Tensor *v = step->output(0); + Type type = M->type(); if (f->type() != DT_INT32) return false; - if (M->type() != DT_FLOAT || M->rank() != 2) return false; - if (v->type() != DT_FLOAT) return false; - if (oov != nullptr && oov->type() != DT_FLOAT) return false; + if (M->rank() != 2) return false; + if (v->type() != type) return false; + if (oov != nullptr && oov->type() != type) return false; int n = f->elements(); int d = M->dim(1); int r = v->rank() - 1; @@ -860,9 +832,6 @@ class PoolingGather : public Kernel { } bool Supports(Step *step) override { - // Requires SSE or AVX support. - if (!CPU::Enabled(AVX) && !CPU::Enabled(SSE)) return false; - // Check inputs and outputs. if (step->indegree() != 2 || step->outdegree() != 1) return false; @@ -870,9 +839,13 @@ class PoolingGather : public Kernel { Tensor *M = step->input(0); Tensor *f = step->input(1); Tensor *v = step->output(0); - if (M->type() != DT_FLOAT || M->rank() != 2) return false; + if (!SIMDAssembler::Supports(M->type()) || M->rank() != 2) return false; if (f->type() != DT_INT32 || f->rank() != 2) return false; - if (v->type() != DT_FLOAT || v->elements() != M->dim(1)) return false; + if (v->type() != M->type() || v->elements() != M->dim(1)) return false; + if (pooling_ == AVG) { + if (M->type() != DT_FLOAT && M->type() != DT_DOUBLE) return false; + if (!CPU::Enabled(SSE2)) return false; + } return true; } @@ -881,16 +854,18 @@ class PoolingGather : public Kernel { Tensor *M = step->input(0); Tensor *v = step->output(0); - // Align to one ymm/xmm register. - int align = 4; - if (CPU::Enabled(AVX)) align = 8; - if (CPU::Enabled(AVX512F)) align = 16; - M->SetMiniumAlignment(align * sizeof(float)); - v->SetMiniumAlignment(align * sizeof(float)); + // Align to one vector register. + Type type = M->type(); + int vecbytes = SIMDAssembler::VectorBytes(type); + M->SetMiniumAlignment(vecbytes); + v->SetMiniumAlignment(vecbytes); // Embedding matrix must be row-major. M->RequireOrder(ROW_MAJOR); - if (M->dim(1) >= align) M->MinAlign({1, align}); + + // Reserve registers. + int regs = SIMDAssembler::RegisterUsage(type) + 8; + step->SetRegisterUsage(regs); } void Generate(Step *step, MacroAssembler *masm) override { @@ -898,10 +873,19 @@ class PoolingGather : public Kernel { Tensor *M = step->input(0); Tensor *f = step->input(1); Tensor *v = step->output(0); - CHECK(f->IsLocal()) << f->name(); - CHECK(v->IsLocal()) << v->name(); int n = v->elements(); + // Create SIMD code generators. + Type type = M->type(); + int dsize = TypeTraits::of(type).size(); + int vecbytes = SIMDAssembler::VectorBytes(type); + bool aligned = M->stride(0) % vecbytes == 0; + SIMDAssembler sasm(masm, type, aligned); + + // Compute vector processing strategy. + SIMDStrategy strategy(&sasm, n); + strategy.PreloadMasks(); + // Allocate registers. 
Register acc = masm->rr().alloc_fixed(rax); Register src = masm->rr().alloc_fixed(rsi); @@ -913,6 +897,7 @@ class PoolingGather : public Kernel { Register embeddings = masm->rr().alloc(); Register input = masm->rr().alloc(); Register output = masm->rr().alloc(); + auto elem = sasm.alloc(strategy.MaxUnrolls()); // Load tensor locations. __ LoadTensorAddress(embeddings, M); @@ -925,12 +910,6 @@ __ xorq(fcnt, fcnt); } - // Set up mask. - OpmaskRegister mask = masm->kk().alloc(); - if (CPU::Enabled(AVX512F) && n % 16 != 0) { - __ LoadMask(n % 16, mask); - } - // Find first (non-negative) feature. Label l1, l2, done; __ bind(&l1); @@ -979,129 +958,45 @@ __ addq(src, acc); // Update output vector with embedding vector for feature. - if (masm->Enabled(AVX512F)) { - // Combine elements using AVX512 vectors. - std::vector<ZMMRegister> elem; - int main = (n / 16) * 16; - int unrolls = AllocateZMMUnrolls(masm, main, 4, &elem); - if (unrolls > 0) { - Label next; - __ xorq(ofs, ofs); - __ bind(&next); - for (int i = 0; i < unrolls; ++i) { - int disp = i * 16 * sizeof(float); - __ vmovaps(elem[i], Operand(src, ofs, times_1, disp)); - } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 16 * sizeof(float); - if (pooling_ == MAX) { - __ vmaxps(elem[i], elem[i], Operand(output, ofs, times_1, disp)); - } else { - __ vaddps(elem[i], elem[i], Operand(output, ofs, times_1, disp)); - } - } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 16 * sizeof(float); - __ vmovaps(Operand(output, ofs, times_1, disp), elem[i]); - } - - if (main > 16 * unrolls) { - __ addq(ofs, Immediate(16 * unrolls * sizeof(float))); - __ cmpq(ofs, Immediate(main * sizeof(float))); - __ j(less, &next); - } - } - - // Combine residual elements. - if (n % 16 > 0) { - int disp = main * sizeof(float); - __ vmovaps(elem[0], Operand(src, disp), Mask(mask, zeroing)); - if (pooling_ == MAX) { - __ vmaxps(elem[0], elem[0], Operand(output, disp), - Mask(mask, zeroing)); + Reduction op = pooling_ == MAX ? REDUCE_MAX : REDUCE_ADD; + for (auto &phase : strategy.phases()) { + auto *gen = phase.generator; + int vecsize = gen->VectorSize(); + int blkstart = phase.offset * dsize; + int blksize = phase.unrolls * vecsize * dsize; + + if (phase.repeat > 1) { + // Repeated phase. + Label lu; + if (blkstart == 0) { + __ xorq(ofs, ofs); } else { - __ vaddps(elem[0], elem[0], Operand(output, disp), - Mask(mask, zeroing)); - } - __ vmovaps(Operand(output, disp), elem[0], Mask(mask, merging)); - } - } else if (masm->Enabled(AVX)) { - // Combine elements using AVX vectors.
- std::vector<YMMRegister> elem; - int main = (n / 8) * 8; - int unrolls = AllocateYMMUnrolls(masm, main, 4, &elem); - if (unrolls > 0) { - Label next; - __ xorq(ofs, ofs); - __ bind(&next); - for (int i = 0; i < unrolls; ++i) { - int disp = i * 8 * sizeof(float); - __ vmovaps(elem[i], Operand(src, ofs, times_1, disp)); + __ movq(ofs, Immediate(blkstart)); } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 8 * sizeof(float); - if (pooling_ == MAX) { - __ vmaxps(elem[i], elem[i], Operand(output, ofs, times_1, disp)); - } else { - __ vaddps(elem[i], elem[i], Operand(output, ofs, times_1, disp)); - } - } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 8 * sizeof(float); - __ vmovaps(Operand(output, ofs, times_1, disp), elem[i]); + __ bind(&lu); + for (int i = 0; i < phase.unrolls; ++i) { + int disp = i * vecsize * dsize; + gen->Load(elem[i], Operand(src, ofs, times_1, disp)); + gen->Accumulate(op, elem[i], Operand(output, ofs, times_1, disp)); + gen->Store(Operand(output, ofs, times_1, disp), elem[i]); } - - if (main > 8 * unrolls) { - __ addq(ofs, Immediate(8 * unrolls * sizeof(float))); - __ cmpq(ofs, Immediate(main * sizeof(float))); - __ j(less, &next); + __ addq(ofs, Immediate(blksize)); + __ cmpq(ofs, Immediate(blkstart + phase.repeat * blksize)); + __ j(less, &lu); + } else if (phase.masked == 0) { + // Residual phase. + for (int i = 0; i < phase.unrolls; ++i) { + int disp = blkstart + i * vecsize * dsize; + gen->Load(elem[i], Operand(src, disp)); + gen->Accumulate(op, elem[i], Operand(output, disp)); + gen->Store(Operand(output, disp), elem[i]); } - } - - // Combine residual elements. - int disp = main * sizeof(float); - for (int i = 0; i < n % 8; ++i) { - int r = i % std::max(unrolls, 1); - __ vmovss(elem[r], Operand(src, disp)); - if (pooling_ == MAX) { - __ vmaxss(elem[r], elem[r], Operand(output, disp)); - } else { - __ vaddss(elem[r], elem[r], Operand(output, disp)); - } - __ vmovss(Operand(output, disp), elem[r]); - disp += sizeof(float); - } - } else { - // Combine elements using SSE vectors. - int main = (n / 4) * 4; - XMMRegister elem = masm->mm().allocx(); - if (n >= 4) { - Label next; - __ xorq(ofs, ofs); - __ bind(&next); - __ movaps(elem, Operand(src, ofs)); - if (pooling_ == MAX) { - __ maxps(elem, Operand(output, ofs)); - } else { - __ addps(elem, Operand(output, ofs)); - } - __ movaps(Operand(output, ofs), elem); - __ addq(ofs, Immediate(4 * sizeof(float))); - __ cmpq(ofs, Immediate(main * sizeof(float))); - __ j(less, &next); - } - - // Combine residual elements. - int disp = main * sizeof(float); - for (int i = 0; i < n % 4; ++i) { - __ movss(elem, Operand(src, disp)); - if (pooling_ == MAX) { - __ maxss(elem, Operand(output, disp)); - } else { - __ addss(elem, Operand(output, disp)); - } - __ movss(Operand(output, disp), elem); - disp += sizeof(float); + } else { + // Masked phase. + CHECK_EQ(phase.unrolls, 1); + gen->MaskedLoad(elem[0], Operand(src, blkstart)); + gen->MaskedAccumulate(op, elem[0], Operand(output, blkstart)); + gen->MaskedStore(Operand(output, blkstart), elem[0]); } } @@ -1111,95 +1006,62 @@ // Compute average. if (pooling_ == AVG) { - if (masm->Enabled(AVX512F)) { - // Compute 1/fcnt. - ZMMRegister scalar = masm->mm().allocz(); - __ vcvtqsi2ss(scalar.xmm(), scalar.xmm(), fcnt); - __ vrcpss(scalar.xmm(), scalar.xmm(), scalar.xmm()); - __ vbroadcastss(scalar, scalar); - - // Multiply all output elements with scalar to get the average.
- std::vector<ZMMRegister> elem; - int main = (n / 16) * 16; - int unrolls = AllocateZMMUnrolls(masm, main, 4, &elem); - if (unrolls > 0) { - Label next; - __ xorq(ofs, ofs); - __ bind(&next); - for (int i = 0; i < unrolls; ++i) { - int disp = i * 16 * sizeof(float); - __ vmulps(elem[i], scalar, Operand(output, ofs, times_1, disp)); - } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 16 * sizeof(float); - __ vmovaps(Operand(output, ofs, times_1, disp), elem[i]); - } - __ addq(ofs, Immediate(16 * unrolls * sizeof(float))); - __ cmpq(ofs, Immediate(main * sizeof(float))); - __ j(less, &next); + // Compute 1/fcnt. + int scalar = sasm.alloc(); + XMMRegister sr = jit::XMMRegister::from_code(scalar); + if (masm->Enabled(AVX)) { + __ vcvtqsi2ss(sr, sr, fcnt); + __ vrcpss(sr, sr, sr); + if (type == DT_DOUBLE) { + __ vcvtss2sd(sr, sr, sr); } - if (n % 16 > 0) { - int disp = main * sizeof(float); - __ vmulps(elem[0], scalar, Operand(output, disp), - Mask(mask, zeroing)); - __ vmovaps(Operand(output, disp), elem[0], Mask(mask, merging)); - } - } else if (masm->Enabled(AVX)) { - // Compute 1/fcnt. - YMMRegister scalar = masm->mm().allocy(); - __ vcvtqsi2ss(scalar.xmm(), scalar.xmm(), fcnt); - __ vrcpss(scalar.xmm(), scalar.xmm(), scalar.xmm()); - if (masm->Enabled(AVX2)) { - __ vbroadcastss(scalar, scalar); - } else { - __ vshufps(scalar, scalar, scalar, 0); - __ vperm2f128(scalar, scalar, scalar, 0); + } else { + __ cvtqsi2ss(sr, fcnt); + __ rcpss(sr, sr); + if (type == DT_DOUBLE) { + CHECK(masm->Enabled(SSE2)); + __ cvtss2sd(sr, sr); } + } + sasm.main()->Broadcast(scalar, scalar); - // Multiply all output elements with scalar to get the average. - std::vector<YMMRegister> elem; - int main = (n / 8) * 8; - int unrolls = AllocateYMMUnrolls(masm, main, 4, &elem); - if (unrolls > 0) { - Label next; - __ xorq(ofs, ofs); - __ bind(&next); - for (int i = 0; i < unrolls; ++i) { - int disp = i * 8 * sizeof(float); - __ vmulps(elem[i], scalar, Operand(output, ofs, times_1, disp)); + // Multiply all output elements with scalar to get the average. + for (auto &phase : strategy.phases()) { + auto *gen = phase.generator; + int vecsize = gen->VectorSize(); + int blkstart = phase.offset * dsize; + int blksize = phase.unrolls * vecsize * dsize; + + if (phase.repeat > 1) { + // Repeated phase. + Label lu; + if (blkstart == 0) { + __ xorq(ofs, ofs); + } else { + __ movq(ofs, Immediate(blkstart)); } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 8 * sizeof(float); - __ vmovaps(Operand(output, ofs, times_1, disp), elem[i]); + __ bind(&lu); + for (int i = 0; i < phase.unrolls; ++i) { + int disp = i * vecsize * dsize; + gen->Mul(elem[i], scalar, Operand(output, ofs, times_1, disp)); + gen->Store(Operand(output, ofs, times_1, disp), elem[i]); } - __ addq(ofs, Immediate(8 * unrolls * sizeof(float))); - __ cmpq(ofs, Immediate(main * sizeof(float))); - __ j(less, &next); - } - int disp = main * sizeof(float); - for (int i = 0; i < n % 8; ++i) { - int r = i % std::max(unrolls, 1); - __ vmulss(elem[r].xmm(), scalar.xmm(), Operand(output, disp)); - __ vmovss(Operand(output, disp), elem[r].xmm()); - disp += sizeof(float); + __ addq(ofs, Immediate(blksize)); + __ cmpq(ofs, Immediate(blkstart + phase.repeat * blksize)); + __ j(less, &lu); + } else if (phase.masked == 0) { + // Residual phase. + for (int i = 0; i < phase.unrolls; ++i) { + int disp = blkstart + i * vecsize * dsize; + gen->Mul(elem[i], scalar, Operand(output, disp)); + gen->Store(Operand(output, disp), elem[i]); + } + } else { + // Masked phase.
+ CHECK_EQ(phase.unrolls, 1); + gen->MaskedMul(elem[0], scalar, Operand(output, blkstart)); + gen->MaskedStore(Operand(output, blkstart), elem[0]); } - } else { - // Compute 1/fcnt. - XMMRegister scalar = masm->mm().allocx(); - __ cvtqsi2ss(scalar, fcnt); - __ rcpss(scalar, scalar); - - // Multiply all output elements with scalar to get the average. - XMMRegister elem = masm->mm().allocx(); - Label next; - __ xorq(ofs, ofs); - __ bind(&next); - __ movss(elem, Operand(output, ofs)); - __ mulss(elem, scalar); - __ movss(Operand(output, ofs), elem); - __ addq(ofs, Immediate(sizeof(float))); - __ cmpq(ofs, Immediate(v->size())); - __ j(less, &next); } } @@ -1216,90 +1078,92 @@ class PoolingGather : public Kernel { Pooling pooling_; // pooling operation for combining vectors }; -// Add sparse (scaled) input to variable. -class ScatterAdd : public Kernel { +// Accumulate sparse (scaled) input. +class AssignAddScatter : public Kernel { public: - ScatterAdd(bool scale) : scale_(scale) {} + AssignAddScatter(bool scale) : scale_(scale) {} string Name() override { return Operation(); } string Operation() override { - return scale_ ? "ScatterMulAdd" : "ScatterAdd"; + return scale_ ? "AssignAddMulScatter" : "AssignAddScatter"; } bool Supports(Step *step) override { - // Requires SSE or AVX support. - if (!CPU::Enabled(AVX) && !CPU::Enabled(SSE)) return false; - // Check inputs and outputs. - if (step->indegree() != (scale_ ? 4 : 3)) return false; - if (step->outdegree() > 1) return false; + Args args(step, scale_); + if (!args.valid) return false; - // Check types. - Tensor *var = step->input(0); - Tensor *indices = step->input(1); - Tensor *value = step->input(2); - Tensor *scaler = scale_ ? step->input(3) : nullptr; - Tensor *ref = step->outdegree() > 0 ? step->output(0) : nullptr; - if (var->type() != DT_FLOAT || var->rank() != 2) return false; - if (var->constant()) return false; - if (indices->type() != DT_INT32 || indices->rank() != 2) return false; - if (value->type() != DT_FLOAT) return false; - if (value->elements() != var->dim(1)) return false; + // Check arguments. + Type type = args.var->type(); + if (!SIMDAssembler::Supports(type)) return false; + if (args.var->rank() != 2) return false; + if (args.var->constant()) return false; + if (args.indices->type() != DT_INT32) return false; + if (args.indices->rank() != 2) return false; + if (args.value->type() != type || args.value->rank() != 2) return false; + if (args.value->dim(1) != args.var->dim(1)) return false; + if (args.value->dim(0) != 1 && + args.value->dim(0) != args.indices->dim(1)) { + return false; + } if (scale_) { - if (scaler->type() != DT_FLOAT) return false; - if (scaler->elements() != 1) return false; + if (args.scaler->type() != type) return false; + if (args.scaler->elements() != 1) return false; } - if (ref) { - if (ref->type() != var->type()) return false; - if (ref->shape() != var->shape()) return false; - if (!ref->ref()) return false; + if (args.ref) { + if (args.ref->type() != type) return false; + if (args.ref->shape() != args.var->shape()) return false; + if (!args.ref->ref()) return false; } return true; } void Adjust(Step *step, const Options &options) override { - Tensor *var = step->input(0); - Tensor *value = step->input(2); - Tensor *ref = step->outdegree() > 0 ? step->output(0) : nullptr; + Args args(step, scale_); // Add sparsity bitmap index. 
if (options.sparse_threshold > 0 && - var->dim(0) >= options.sparse_threshold && + args.var->dim(0) >= options.sparse_threshold && step->GetAttr("sparse", true)) { - Tensor *sparse = var->MakeSparse(); - if (ref) ref->set_sparse(sparse); + Tensor *sparse = args.var->MakeSparse(); + if (args.ref) args.ref->set_sparse(sparse); } // Link output reference to input variable. - if (ref) var->Link(ref); + if (args.ref) args.var->Link(args.ref); - // Align to one SIMD register. - int align = 4; - if (CPU::Enabled(AVX)) align = 8; - if (CPU::Enabled(AVX512F)) align = 16; - var->SetMiniumAlignment(align * sizeof(float)); - value->SetMiniumAlignment(align * sizeof(float)); + // Align to one vector register. + Type type = args.var->type(); + int vecbytes = SIMDAssembler::VectorBytes(type); + args.var->SetMiniumAlignment(vecbytes); + args.value->SetMiniumAlignment(vecbytes); // Embedding matrix must be row-major. - var->RequireOrder(ROW_MAJOR); - int minalign = 1; - if (var->dim(1) >= 4) minalign = 4; - if (CPU::Enabled(AVX) && var->dim(1) >= 8) minalign = 8; - if (CPU::Enabled(AVX512F) && var->dim(1) >= 16) minalign = 16; - var->MinAlign({1, minalign}); + args.var->RequireOrder(ROW_MAJOR); + + // Reserve registers. + int regs = SIMDAssembler::RegisterUsage(type) + 8; + step->SetRegisterUsage(regs); } void Generate(Step *step, MacroAssembler *masm) override { // Get inputs. - Tensor *var = step->input(0); - Tensor *indices = step->input(1); - Tensor *value = step->input(2); - Tensor *scaler = scale_ ? step->input(3) : nullptr; - Tensor *ref = step->outdegree() > 0 ? step->output(0) : nullptr; - Tensor *sparse = var->sparse(); - bool single = indices->elements() == 1; - int n = value->elements(); + Args args(step, scale_); + Tensor *sparse = args.var->sparse(); + bool single = args.indices->elements() == 1; + int n = args.value->dim(1); + + // Create SIMD code generators. + Type type = args.var->type(); + int dsize = TypeTraits::of(type).size(); + int vecbytes = SIMDAssembler::VectorBytes(type); + bool aligned = args.var->stride(0) % vecbytes == 0; + SIMDAssembler sasm(masm, type, aligned); + + // Compute vector processing strategy. + SIMDStrategy strategy(&sasm, n); + strategy.PreloadMasks(); // Allocate registers. Register bit = masm->rr().alloc_fixed(rcx); @@ -1312,41 +1176,28 @@ class ScatterAdd : public Kernel { Register ofs = masm->rr().alloc(); Register src = bit; Register aux = ofs; - - ZMMRegister factor = masm->mm().allocz(false); + auto elem = sasm.alloc(strategy.MaxUnrolls()); + int factor = args.scaler ? sasm.alloc() : -1; // Load tensor locations. - __ LoadTensorAddress(varaddr, var); - __ LoadTensorAddress(idxaddr, indices); - __ LoadTensorAddress(valaddr, value); + __ LoadTensorAddress(varaddr, args.var); + __ LoadTensorAddress(idxaddr, args.indices); + __ LoadTensorAddress(valaddr, args.value); if (sparse) { __ LoadTensorAddress(bmaddr, sparse); } // Optionally output reference to assigned variable. - if (ref != nullptr) { - CHECK(ref->IsLocal()); - CHECK(ref->ref()); - __ movq(Operand(masm->instance(), ref->offset()), varaddr); + if (args.ref != nullptr) { + CHECK(args.ref->IsLocal()); + CHECK(args.ref->ref()); + __ movq(Operand(masm->instance(), args.ref->offset()), varaddr); } // Load scaling value. 
- if (scaler) { - __ LoadTensorAddress(src, scaler); - if (masm->Enabled(AVX512F)) { - __ vbroadcastss(factor, Operand(src)); - } else if (masm->Enabled(AVX)) { - __ vbroadcastss(factor.ymm(), Operand(src)); - } else { - __ movss(factor.xmm(), Operand(src)); - __ shufps(factor.xmm(), factor.xmm(), 0); - } - } - - // Set up mask. - OpmaskRegister mask = masm->kk().alloc(); - if (CPU::Enabled(AVX512F) && n % 16 != 0) { - __ LoadMask(n % 16, mask); + if (args.scaler) { + __ LoadTensorAddress(src, args.scaler); + sasm.main()->Broadcast(factor, Operand(src)); } // Loop over features. @@ -1373,149 +1224,85 @@ } // Look up address of index in embedding. - __ Multiply(acc, var->stride(0)); + __ Multiply(acc, args.var->stride(0)); __ addq(acc, varaddr); + // Update OOV vector for missing features. + if (args.oov) { + Label l3; + __ jmp(&l3); + __ bind(&l2); + __ LoadTensorAddress(acc, args.oov); + __ bind(&l3); + } + // Add (scaled) input vector for feature to embedding vector. - if (masm->Enabled(AVX512F)) { - // Update elements using AVX-512 vectors. - std::vector<ZMMRegister> elem; - int main = (n / 16) * 16; - int unrolls = AllocateZMMUnrolls(masm, main, 4, &elem); - if (unrolls > 0) { - Label next; - __ xorq(ofs, ofs); - __ bind(&next); - for (int i = 0; i < unrolls; ++i) { - int disp = i * 16 * sizeof(float); + for (auto &phase : strategy.phases()) { + auto *gen = phase.generator; + int vecsize = gen->VectorSize(); + int blkstart = phase.offset * dsize; + int blksize = phase.unrolls * vecsize * dsize; + + if (phase.repeat > 1) { + // Repeated phase. + Label lu; + if (blkstart == 0) { + __ xorq(ofs, ofs); + } else { + __ movq(ofs, Immediate(blkstart)); + } + __ bind(&lu); + for (int i = 0; i < phase.unrolls; ++i) { + int disp = i * vecsize * dsize; + gen->Load(elem[i], Operand(acc, ofs, times_1, disp)); if (scale_) { - __ vmulps(elem[i], factor, Operand(valaddr, ofs, times_1, disp)); + gen->MulAdd(elem[i], factor, Operand(valaddr, ofs, times_1, disp), + true); } else { - __ vmovaps(elem[i], Operand(valaddr, ofs, times_1, disp)); + gen->Add(elem[i], elem[i], Operand(valaddr, ofs, times_1, disp)); } + gen->Store(Operand(acc, ofs, times_1, disp), elem[i]); } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 16 * sizeof(float); - __ vaddps(elem[i], elem[i], Operand(acc, ofs, times_1, disp)); - } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 16 * sizeof(float); - __ vmovaps(Operand(acc, ofs, times_1, disp), elem[i]); - } - if (main > 16 * unrolls) { - __ addq(ofs, Immediate(16 * unrolls * sizeof(float))); - __ cmpq(ofs, Immediate(main * sizeof(float))); - __ j(less, &next); - } - } - - // Update residual elements. - if (n % 16 != 0) { - int disp = main * sizeof(float); - if (scale_) { - __ vmulps(elem[0], factor, Operand(valaddr, disp), - Mask(mask, zeroing)); - } else { - __ vmovups(elem[0], Operand(valaddr, disp), Mask(mask, zeroing)); - } - __ vaddps(elem[0], elem[0], Operand(acc, disp), - Mask(mask, zeroing)); - __ vmovups(Operand(acc, disp), elem[0], Mask(mask, merging)); - } - } else if (masm->Enabled(AVX)) { - // Update elements using AVX vectors.
- std::vector<YMMRegister> elem; - int main = (n / 8) * 8; - int unrolls = AllocateYMMUnrolls(masm, main, 4, &elem); - if (unrolls > 0) { - Label next; - __ xorq(ofs, ofs); - __ bind(&next); - for (int i = 0; i < unrolls; ++i) { - int disp = i * 8 * sizeof(float); + __ addq(ofs, Immediate(blksize)); + __ cmpq(ofs, Immediate(blkstart + phase.repeat * blksize)); + __ j(less, &lu); + } else if (phase.masked == 0) { + // Residual phase. + for (int i = 0; i < phase.unrolls; ++i) { + int disp = blkstart + i * vecsize * dsize; + gen->Load(elem[i], Operand(acc, disp)); if (scale_) { - __ vmulps(elem[i], factor.ymm(), - Operand(valaddr, ofs, times_1, disp)); + gen->MulAdd(elem[i], factor, Operand(valaddr, disp), true); } else { - __ vmovaps(elem[i], Operand(valaddr, ofs, times_1, disp)); + gen->Add(elem[i], elem[i], Operand(valaddr, disp)); } + gen->Store(Operand(acc, disp), elem[i]); } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 8 * sizeof(float); - __ vaddps(elem[i], elem[i], Operand(acc, ofs, times_1, disp)); - } - for (int i = 0; i < unrolls; ++i) { - int disp = i * 8 * sizeof(float); - __ vmovaps(Operand(acc, ofs, times_1, disp), elem[i]); - } - if (main > 8 * unrolls) { - __ addq(ofs, Immediate(8 * unrolls * sizeof(float))); - __ cmpq(ofs, Immediate(main * sizeof(float))); - __ j(less, &next); - } - } - - // Update residual elements. - int disp = main * sizeof(float); - if (n % 8 >= 4) { - if (scale_) { - __ vmulps(elem[0].xmm(), factor.xmm(), Operand(valaddr, disp)); - } else { - __ vmovaps(elem[0].xmm(), Operand(valaddr, disp)); - } - __ vaddps(elem[0].xmm(), elem[0].xmm(), Operand(acc, disp)); - __ vmovaps(Operand(acc, disp), elem[0].xmm()); - disp += 4 * sizeof(float); - } - for (int i = 0; i < n % 4; ++i) { - int r = i % std::max(unrolls, 1); + } else { + // Masked phase. + CHECK_EQ(phase.unrolls, 1); + gen->MaskedLoad(elem[0], Operand(acc, blkstart)); if (scale_) { - __ vmulss(elem[r].xmm(), factor.xmm(), Operand(valaddr, disp)); + gen->MaskedMulAdd(elem[0], factor, Operand(valaddr, blkstart)); } else { - __ vmovss(elem[r].xmm(), Operand(valaddr, disp)); + gen->MaskedAdd(elem[0], elem[0], Operand(valaddr, blkstart)); } - __ vaddss(elem[r].xmm(), elem[r].xmm(), Operand(acc, disp)); - __ vmovss(Operand(acc, disp), elem[r].xmm()); - disp += sizeof(float); - } - } else { - // Update elements using SSE vectors. - XMMRegister elem = masm->mm().allocx(); - int main = (n / 4) * 4; - if (n >= 4) { - Label next; - __ xorq(ofs, ofs); - __ bind(&next); - __ movaps(elem, Operand(valaddr, ofs)); - if (scale_) { - __ mulps(elem, factor.xmm()); - } - __ addps(elem, Operand(acc, ofs)); - __ movaps(Operand(acc, ofs), elem); - __ addq(ofs, Immediate(4 * sizeof(float))); - __ cmpq(ofs, Immediate(main * sizeof(float))); - __ j(less, &next); + gen->MaskedStore(Operand(acc, blkstart), elem[0]); } + } - // Update residual elements.
- int disp = main * sizeof(float); - for (int i = 0; i < n % 4; ++i) { - __ movss(elem, Operand(valaddr, disp)); - if (scale_) { - __ mulss(elem, factor.xmm()); - } - __ addss(elem, Operand(acc, disp)); - __ movss(Operand(acc, disp), elem); - disp += sizeof(float); - } + if (args.value->dim(0) != 1) { + __ addq(valaddr, Immediate(args.value->stride(0))); } if (!single) { __ incq(fidx); - __ cmpq(fidx, Immediate(indices->elements())); + __ cmpq(fidx, Immediate(args.indices->elements())); __ j(less, &l1); } - __ bind(&l2); + if (args.oov == nullptr) { + __ bind(&l2); + } } int64 Complexity(const Step *step) override { @@ -1525,6 +1312,36 @@ } private: + // Arguments to scatter op. + struct Args { + Args(Step *step, bool scale) { + if (step->indegree() < 3) return; + if (step->outdegree() > 1) return; + var = step->input(0); + indices = step->input(1); + value = step->input(2); + if (step->outdegree() > 0) ref = step->output(0); + + if (scale) { + if (step->indegree() != 4 && step->indegree() != 5) return; + if (step->indegree() > 3) scaler = step->input(3); + if (step->indegree() > 4) oov = step->input(4); + } else { + if (step->indegree() != 3 && step->indegree() != 4) return; + if (step->indegree() > 3) oov = step->input(3); + } + valid = true; + } + + bool valid = false; + Tensor *var = nullptr; + Tensor *indices = nullptr; + Tensor *value = nullptr; + Tensor *scaler = nullptr; + Tensor *ref = nullptr; + Tensor *oov = nullptr; + }; + bool scale_; // scale input }; @@ -1986,9 +1803,33 @@ class UpdateTransformer : public Transformer { string Name() override { return "UpdateTransformer"; } bool Transform(Flow *flow) override { - int updates = 0; + bool updated = false; + bool again = true; + while (again) { + again = false; + if (TransformMatMul(flow)) { + again = true; + updated = true; + } + if (TransformDistributiveUpdate(flow)) { + again = true; + updated = true; + } + if (TransformSparseUpdate(flow)) { + again = true; + updated = true; + } + if (TransformScaledSparseUpdate(flow)) { + again = true; + updated = true; + } + } + return updated; + } - // Transform matrix multiplication update. + // Transform matrix multiplication updates. + bool TransformMatMul(Flow *flow) { + int updates = 0; for (Flow::Operation *op : flow->Find("MatMul|1:Add|1:Assign")) { Flow::Operation *assign = op; Flow::Operation *add = assign->inputs[1]->producer; @@ -2001,45 +1842,48 @@ flow->Fuse(assign, flow->Fuse(add, matmul, ""), "AssignAddMatMul", true); updates++; } + return updates > 0; + } - // Transform double sparse update. - for (Flow::Operation *op : flow->Find("Scatter|1:Add|1:Add|1:Assign")) { - Flow::Operation *assign = op; - Flow::Operation *add1 = assign->inputs[1]->producer; + // Transform distributive scatter updates. + bool TransformDistributiveUpdate(Flow *flow) { + // Find assignments for scatter operations. + std::set<Flow::Operation *> scatter_assigns; + for (Flow::Operation *op : flow->Find("Scatter")) { + while (op->outdegree() == 1 && op->outputs[0]->usages() == 1) { + op = op->outputs[0]->consumers[0]; + } + if (op->type == "Assign") scatter_assigns.insert(op); + } + + // Split additive updates.
+ int updates = 0; + for (Flow::Operation *op : flow->Find("Add|1:Add|1:Assign")) { + Flow::Operation *assign1 = op; + Flow::Operation *add1 = assign1->inputs[1]->producer; Flow::Operation *add2 = add1->inputs[1]->producer; - Flow::Operation *scatter1 = add2->inputs[1]->producer; - Flow::Operation *scatter2 = add2->inputs[0]->producer; + Flow::Variable *target = assign1->inputs[0]; - if (assign->inputs[0] != add1->inputs[0]) continue; if (add1->outputs[0]->usages() != 1) continue; if (add2->outputs[0]->usages() != 1) continue; - if (scatter2->type != "Scatter") continue; - - Flow::Variable *target = assign->inputs[0]; - if (target != add1->inputs[0]) continue; - - Flow::Variable *s1 = scatter1->outputs[0]; - Flow::Variable *s2 = scatter2->outputs[0]; - Flow::Variable *a2 = add2->outputs[0]; - if (s1->usages() != 1) continue; - if (s2->usages() != 1) continue; - - // Decompose scatter updates. - add1->RemoveInput(a2); - add2->RemoveInput(s1); - add2->RemoveInput(s2); - add1->AddInput(s1); - add2->AddInput(target); - add2->AddInput(s2); - string name = flow->OpName(assign->name); - auto *a = flow->AddOperation(add2->func, name, "Assign"); - a->AddInput(target); - a->AddInput(a2); - + if (add1->inputs[0] != target) continue; + if (scatter_assigns.count(assign1) == 0) continue; + + // Split into two accumulative updates. + Flow::Function *func = assign1->func; + Flow::Operation *assign2 = flow->AddOperation(func, "", "Assign"); + assign2->AddInput(target); + assign2->AddInput(add2->outputs[0]); + add1->ReplaceInput(add1->inputs[1], add2->inputs[0]); + add2->ReplaceInput(add2->inputs[0], target); updates++; } + return updates > 0; + } - // Transform sparse update. + // Transform sparse updates. + bool TransformSparseUpdate(Flow *flow) { + int updates = 0; for (Flow::Operation *op : flow->Find("Scatter|1:Add|1:Assign")) { Flow::Operation *assign = op; Flow::Operation *add = assign->inputs[1]->producer; @@ -2048,20 +1892,24 @@ class UpdateTransformer : public Transformer { if (add->outputs[0]->usages() != 1) continue; if (scatter->outputs[0]->usages() != 1) continue; - flow->Fuse(assign, flow->Fuse(add, scatter, ""), "ScatterAdd", true); + Flow::Operation *add_scatter = flow->Fuse(add, scatter, ""); + flow->Fuse(assign, add_scatter, "AssignAddScatter", true); updates++; } + return updates > 0; + } - // Transform sparse update scaling. - for (Flow::Operation *op : flow->Find("Mul|2:ScatterAdd")) { + // Transform sparse update scalings. 
+ bool TransformScaledSparseUpdate(Flow *flow) { + int updates = 0; + for (Flow::Operation *op : flow->Find("Mul|2:AssignAddScatter")) { Flow::Operation *scatter = op; Flow::Operation *mul = scatter->inputs[2]->producer; if (scatter->indegree() != 3) continue; if (mul->outputs[0]->usages() != 1) continue; - flow->Fuse(scatter, mul, "ScatterMulAdd"); + flow->Fuse(scatter, mul, "AssignAddMulScatter"); updates++; } - return updates > 0; } }; @@ -2109,8 +1957,8 @@ void RegisterArrayKernels(Library *library) { library->Register(new PoolingGather(PoolingGather::SUM)); library->Register(new PoolingGather(PoolingGather::AVG)); library->Register(new PoolingGather(PoolingGather::MAX)); - library->Register(new ScatterAdd(false)); - library->Register(new ScatterAdd(true)); + library->Register(new AssignAddScatter(false)); + library->Register(new AssignAddScatter(true)); library->Register(new Reduce("Sum", REDUCE_ADD)); library->Register(new Reduce("Product", REDUCE_MUL)); diff --git a/sling/myelin/kernel/gradients.cc b/sling/myelin/kernel/gradients.cc index e7c2417d..5e859570 100644 --- a/sling/myelin/kernel/gradients.cc +++ b/sling/myelin/kernel/gradients.cc @@ -472,13 +472,26 @@ void identity_grad(Flow::Operation *op, Gradients *g) { g->add(x, g->d(y)); } +// y = reshape(x, shape) +// dx = reshape(dy, shape(x)) +void reshape_grad(Flow::Operation *op, Gradients *g) { + auto x = op->inputs[0]; + auto y = op->outputs[0]; + g->add(x, g->Reshape(g->d(y), g->v(x)->shape)); +} + // v = gather(M, f) // dM = scatter(dv, f) void gather_grad(Flow::Operation *op, Gradients *g) { auto M = op->inputs[0]; auto f = op->inputs[1]; auto v = op->outputs[0]; - g->add(M, g->Scatter(g->v(f), g->d(v), M->dim(0))); + if (op->indegree() == 3) { + auto oov = op->inputs[2]; + g->add(M, g->Scatter(g->v(f), g->d(v), M->dim(0), g->d(oov))); + } else { + g->add(M, g->Scatter(g->v(f), g->d(v), M->dim(0))); + } } // v = gather_sum(M, f) @@ -607,6 +620,7 @@ void RegisterStandardGradients(Transformations *library) { library->RegisterGradient("Relu", relu_grad); library->RegisterGradient("Norm", norm_grad); library->RegisterGradient("Identity", identity_grad); + library->RegisterGradient("Reshape", reshape_grad); library->RegisterGradient("Gather", gather_grad); library->RegisterGradient("GatherSum", gathersum_grad); library->RegisterGradient("ConcatV2", concat_grad); diff --git a/sling/myelin/simd-assembler.cc b/sling/myelin/simd-assembler.cc index b4d44bfb..3a50f76c 100644 --- a/sling/myelin/simd-assembler.cc +++ b/sling/myelin/simd-assembler.cc @@ -84,6 +84,12 @@ bool SIMDGenerator::SupportsUnroll() { return true; } +void SIMDGenerator::Broadcast(int dst, int src) { + // Broadcast is just a move for scalars. + CHECK_EQ(VectorSize(), 1); + if (dst != src) Move(dst, src); +} + void SIMDGenerator::Broadcast(int dst, const Operand &src) { // Broadcast is just a load for scalars. 
CHECK_EQ(VectorSize(), 1); @@ -160,6 +166,10 @@ class AVX512FloatGenerator : public SIMDGenerator { int VectorSize() override { return 16; } int Alloc() override { return masm_->mm().alloc(true); } + void Move(int dst, int src) override { + masm_->vmovaps(zmm(dst), zmm(src)); + } + void Load(int dst, const Operand &src) override { if (aligned_) { masm_->vmovaps(zmm(dst), src); @@ -176,6 +186,10 @@ class AVX512FloatGenerator : public SIMDGenerator { } } + void Broadcast(int dst, int src) override { + masm_->vbroadcastss(zmm(dst), zmm(src)); + } + void Broadcast(int dst, const Operand &src) override { masm_->vbroadcastss(zmm(dst), src); } @@ -212,7 +226,7 @@ class AVX512FloatGenerator : public SIMDGenerator { if (neutral == nullptr) { Zero(r); } else { - masm_->vbroadcastss(zmm(r), neutral->address()); + Broadcast(r, neutral->address()); } } @@ -286,6 +300,10 @@ class AVX256FloatGenerator : public SIMDGenerator { int VectorSize() override { return 8; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->vmovaps(ymm(dst), ymm(src)); + } + void Load(int dst, const Operand &src) override { if (aligned_) { masm_->vmovaps(ymm(dst), src); @@ -302,6 +320,10 @@ class AVX256FloatGenerator : public SIMDGenerator { } } + void Broadcast(int dst, int src) override { + masm_->vbroadcastss(ymm(dst), ymm(src)); + } + void Broadcast(int dst, const Operand &src) override { masm_->vbroadcastss(ymm(dst), src); } @@ -348,7 +370,7 @@ class AVX256FloatGenerator : public SIMDGenerator { if (neutral == nullptr) { Zero(r); } else { - masm_->vbroadcastss(ymm(r), neutral->address()); + Broadcast(r, neutral->address()); } } @@ -378,6 +400,10 @@ class AVX128FloatGenerator : public SIMDGenerator { int VectorSize() override { return 4; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->vmovaps(xmm(dst), xmm(src)); + } + void Load(int dst, const Operand &src) override { if (aligned_) { masm_->vmovaps(xmm(dst), src); @@ -394,6 +420,10 @@ class AVX128FloatGenerator : public SIMDGenerator { } } + void Broadcast(int dst, int src) override { + masm_->vbroadcastss(xmm(dst), xmm(src)); + } + void Broadcast(int dst, const Operand &src) override { masm_->vbroadcastss(xmm(dst), src); } @@ -440,7 +470,7 @@ class AVX128FloatGenerator : public SIMDGenerator { if (neutral == nullptr) { Zero(r); } else { - masm_->vbroadcastss(xmm(r), neutral->address()); + Broadcast(r, neutral->address()); } } @@ -470,6 +500,10 @@ class SSE128FloatGenerator : public SIMDGenerator { int VectorSize() override { return 4; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->movaps(xmm(dst), xmm(src)); + } + void Load(int dst, const Operand &src) override { if (aligned_) { masm_->movaps(xmm(dst), src); @@ -486,6 +520,11 @@ class SSE128FloatGenerator : public SIMDGenerator { } } + void Broadcast(int dst, int src) override { + if (dst != src) masm_->movss(xmm(dst), xmm(src)); + masm_->shufps(xmm(dst), xmm(dst), 0); + } + void Broadcast(int dst, const Operand &src) override { masm_->movss(xmm(dst), src); masm_->shufps(xmm(dst), xmm(dst), 0); @@ -610,6 +649,10 @@ class AVX512ScalarFloatGenerator : public SIMDGenerator { int VectorSize() override { return 1; } int Alloc() override { return masm_->mm().alloc(true); } + void Move(int dst, int src) override { + masm_->vmovaps(zmm(dst), zmm(src)); + } + void Load(int dst, const Operand &src) override { masm_->vmovss(zmm(dst), src); } @@ -684,6 
+727,10 @@ class AVXScalarFloatGenerator : public SIMDGenerator { int VectorSize() override { return 1; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->vmovaps(xmm(dst), xmm(src)); + } + void Load(int dst, const Operand &src) override { masm_->vmovss(xmm(dst), src); } @@ -778,6 +825,10 @@ class SSEScalarFloatGenerator : public SIMDGenerator { int VectorSize() override { return 1; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->movss(xmm(dst), xmm(src)); + } + void Load(int dst, const Operand &src) override { masm_->movss(xmm(dst), src); } @@ -879,6 +930,10 @@ class AVX512DoubleGenerator : public SIMDGenerator { int VectorSize() override { return 8; } int Alloc() override { return masm_->mm().alloc(true); } + void Move(int dst, int src) override { + masm_->vmovapd(zmm(dst), zmm(src)); + } + void Load(int dst, const Operand &src) override { if (aligned_) { masm_->vmovapd(zmm(dst), src); @@ -895,6 +950,10 @@ class AVX512DoubleGenerator : public SIMDGenerator { } } + void Broadcast(int dst, int src) override { + masm_->vbroadcastsd(zmm(dst), zmm(src)); + } + void Broadcast(int dst, const Operand &src) override { masm_->vbroadcastsd(zmm(dst), src); } @@ -931,7 +990,7 @@ class AVX512DoubleGenerator : public SIMDGenerator { if (neutral == nullptr) { Zero(r); } else { - masm_->vbroadcastsd(zmm(r), neutral->address()); + Broadcast(r, neutral->address()); } } @@ -1005,6 +1064,10 @@ class AVX256DoubleGenerator : public SIMDGenerator { int VectorSize() override { return 4; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->vmovapd(ymm(dst), ymm(src)); + } + void Load(int dst, const Operand &src) override { if (aligned_) { masm_->vmovapd(ymm(dst), src); @@ -1021,6 +1084,10 @@ class AVX256DoubleGenerator : public SIMDGenerator { } } + void Broadcast(int dst, int src) override { + masm_->vbroadcastsd(ymm(dst), ymm(src)); + } + void Broadcast(int dst, const Operand &src) override { masm_->vbroadcastsd(ymm(dst), src); } @@ -1067,7 +1134,7 @@ class AVX256DoubleGenerator : public SIMDGenerator { if (neutral == nullptr) { Zero(r); } else { - masm_->vbroadcastsd(ymm(r), neutral->address()); + Broadcast(r, neutral->address()); } } @@ -1097,6 +1164,10 @@ class AVX128DoubleGenerator : public SIMDGenerator { int VectorSize() override { return 2; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->vmovapd(xmm(dst), xmm(src)); + } + void Load(int dst, const Operand &src) override { if (aligned_) { masm_->vmovapd(xmm(dst), src); @@ -1113,6 +1184,11 @@ class AVX128DoubleGenerator : public SIMDGenerator { } } + void Broadcast(int dst, int src) override { + if (dst != src) masm_->vmovapd(xmm(dst), xmm(src)); + masm_->vshufpd(xmm(dst), xmm(dst), xmm(dst), 0); + } + void Broadcast(int dst, const Operand &src) override { masm_->vmovsd(xmm(dst), src); masm_->vshufpd(xmm(dst), xmm(dst), xmm(dst), 0); @@ -1190,6 +1266,10 @@ class SSE128DoubleGenerator : public SIMDGenerator { int VectorSize() override { return 2; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->movapd(xmm(dst), xmm(src)); + } + void Load(int dst, const Operand &src) override { if (aligned_) { masm_->movapd(xmm(dst), src); @@ -1206,6 +1286,11 @@ class SSE128DoubleGenerator : public SIMDGenerator { } } + void Broadcast(int dst, int src) override { + if (dst != src) 
masm_->movapd(xmm(dst), xmm(src)); + masm_->shufpd(xmm(dst), xmm(dst), 0); + } + void Broadcast(int dst, const Operand &src) override { masm_->movsd(xmm(dst), src); masm_->shufpd(xmm(dst), xmm(dst), 0); @@ -1330,6 +1415,10 @@ class AVX512ScalarDoubleGenerator : public SIMDGenerator { int VectorSize() override { return 1; } int Alloc() override { return masm_->mm().alloc(true); } + void Move(int dst, int src) override { + masm_->vmovapd(zmm(dst), zmm(src)); + } + void Load(int dst, const Operand &src) override { masm_->vmovsd(zmm(dst), src); } @@ -1404,6 +1493,10 @@ class AVXScalarDoubleGenerator : public SIMDGenerator { int VectorSize() override { return 1; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->vmovapd(xmm(dst), xmm(src)); + } + void Load(int dst, const Operand &src) override { masm_->vmovsd(xmm(dst), src); } @@ -1498,6 +1591,10 @@ class SSEScalarDoubleGenerator : public SIMDGenerator { int VectorSize() override { return 1; } int Alloc() override { return masm_->mm().alloc(false); } + void Move(int dst, int src) override { + masm_->movsd(xmm(dst), xmm(src)); + } + void Load(int dst, const Operand &src) override { masm_->movsd(xmm(dst), src); } @@ -1595,6 +1692,10 @@ class ScalarIntSIMDGenerator : public SIMDGenerator { int Alloc() override { return masm_->rr().alloc().code(); } bool SupportsUnroll() override { return false; } + void Move(int dst, int src) override { + masm_->movq(reg(dst), reg(src)); + } + void Load(int dst, const Operand &src) override { switch (type_) { case DT_INT8: masm_->movsxbq(reg(dst), src); break; @@ -1800,9 +1901,12 @@ SIMDAssembler::SIMDAssembler(MacroAssembler *masm, Type type, bool aligned) { add(new AVX256DoubleGenerator(masm, aligned)); add(new AVX128DoubleGenerator(masm, aligned)); add(new AVXScalarDoubleGenerator(masm, aligned)); + } else if (masm->Enabled(SSE2)) { + name_ = "SSE2Dbl"; + add(new SSE128DoubleGenerator(masm, aligned)); + add(new SSEScalarDoubleGenerator(masm, aligned)); } else if (masm->Enabled(SSE)) { name_ = "SSEDbl"; - add(new SSE128DoubleGenerator(masm, aligned)); add(new SSEScalarDoubleGenerator(masm, aligned)); } break; diff --git a/sling/myelin/simd-assembler.h b/sling/myelin/simd-assembler.h index 4afa4657..f007a781 100644 --- a/sling/myelin/simd-assembler.h +++ b/sling/myelin/simd-assembler.h @@ -39,13 +39,17 @@ class SIMDGenerator { // Allocate SIMD register. virtual int Alloc() = 0; + // Move value from one register to another. + virtual void Move(int dst, int src) = 0; + // Load memory into register. virtual void Load(int dst, const jit::Operand &src) = 0; // Store register into memory. virtual void Store(const jit::Operand &dst, int src) = 0; - // Broadcast memory into register. + // Broadcast value to all elements of register. + virtual void Broadcast(int dst, int src); virtual void Broadcast(int dst, const jit::Operand &src); // Clear register. @@ -61,7 +65,7 @@ // Multiply src1 and src2 and store it in dst. virtual void Mul(int dst, int src1, const jit::Operand &src2) = 0; - // Multiply src1 and src2 and add it to dst. If the keep flag is false the + // Multiply src1 and src2 and add it to dst. If the retain flag is false the // contents of src1 can possibly be destroyed.
virtual void MulAdd(int dst, int src1, const jit::Operand &src2, bool retain) = 0; diff --git a/sling/myelin/tests/opcheck.py b/sling/myelin/tests/opcheck.py index 071b805f..9b5e3c4e 100644 --- a/sling/myelin/tests/opcheck.py +++ b/sling/myelin/tests/opcheck.py @@ -763,13 +763,37 @@ def gather_test(n, d, s): v = f.gather(emb, ind) check(flow, (n, d, s), 0, n) +def gather_sum_test(n, d, s): + flow = myelin.Flow() + f = flow.define("gather_sum") + emb = f.array("emb", np.random.ranf((n, d)).astype(simulator.nptypes[dt])) + ind = f.var("ind", myelin.DT_INT32, [1, s]) + v = f.gather_sum(emb, ind) + check(flow, (n, d, s), 0, n) + +def gather_max_test(n, d, s): + flow = myelin.Flow() + f = flow.define("gather_max") + emb = f.array("emb", np.random.ranf((n, d)).astype(simulator.nptypes[dt])) + ind = f.var("ind", myelin.DT_INT32, [1, s]) + v = f.gather_max(emb, ind) + check(flow, (n, d, s), 0, n) + +def gather_avg_test(n, d, s): + flow = myelin.Flow() + f = flow.define("gather_avg") + emb = f.array("emb", np.random.ranf((n, d)).astype(simulator.nptypes[dt])) + ind = f.var("ind", myelin.DT_INT32, [1, s]) + v = f.gather_avg(emb, ind) + check(flow, (n, d, s), 0, n, rtol=1e-3) + def scatter_add_test(n, d, s): flow = myelin.Flow() f = flow.define("scatter_add") m = f.var("m", dt, [n, d]) ind = f.var("ind", myelin.DT_INT32, [1, s]) v = f.var("v", dt, [1, d]) - f.scatter_add(m, ind, v) + f.assign_add_scatter(m, ind, v) check(flow, (n, d, s), 0, n, check=[m]) def negfold_test(n): @@ -835,6 +859,15 @@ def acc_matmul_test(m, k, n): size_test(i) if i < 32: rank_test(i) + embsize = 32 + for f in [1, 2, 5]: + gather_test(embsize, i, f) + gather_sum_test(embsize, i, f) + gather_max_test(embsize, i, f) + if dt == myelin.DT_FLOAT or dt == myelin.DT_DOUBLE: + gather_avg_test(embsize, i, f) + scatter_add_test(embsize, i, f) + for c in [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8]: add_const_test(i, c) sub_const_test(i, c) diff --git a/sling/nlp/embedding/embedding-model.cc b/sling/nlp/embedding/embedding-model.cc index 4c19d9d4..7a6e7618 100644 --- a/sling/nlp/embedding/embedding-model.cc +++ b/sling/nlp/embedding/embedding-model.cc @@ -68,7 +68,7 @@ void MikolovFlow::BuildLayer1() { // Backprop layer 1. tf.AssignAdd(error, tf.Mul(embed, eta)); - tf.ScatterAdd(W1, target, tf.Mul(h, eta)); + tf.AssignAddScatter(W1, target, tf.Mul(h, eta)); } void MikolovFlow::BuildLayer0Back() { @@ -77,7 +77,7 @@ void MikolovFlow::BuildLayer0Back() { l0b_l0 = tf.Instance(layer0); l0b_l1 = tf.Instance(layer1); - tf.ScatterAdd(W0, tf.Ref(l0b_l0, fv), tf.Ref(l0b_l1, error)); + tf.AssignAddScatter(W0, tf.Ref(l0b_l0, fv), tf.Ref(l0b_l1, error)); } void DualEncoderFlow::Build(const Transformations &library) {
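With the Python builder, simulator, C++ builder, kernels, and transformers all renamed in step, a sparse update is now expressed as AssignAddScatter end to end. A minimal sketch of driving the renamed op from Python, modeled on scatter_add_test in sling/myelin/tests/opcheck.py (shapes are illustrative, and the myelin import path is assumed to be set up as in that test):

```python
import myelin  # assumes python/myelin is importable, as in opcheck.py

flow = myelin.Flow()
f = flow.define("sparse_update")
m = f.var("m", myelin.DT_FLOAT, [8, 4])      # embedding matrix updated in place
ind = f.var("ind", myelin.DT_INT32, [1, 2])  # feature indices
v = f.var("v", myelin.DT_FLOAT, [1, 4])      # delta added to each indexed row
f.assign_add_scatter(m, ind, v)              # was f.scatter_add(m, ind, v)

# Reference semantics, per python/myelin/simulator.py: for each index j in
# ind, m[j, :] += v, i.e. np.add.at(m_data, ind_data[0], v_data) in NumPy.
```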