Skip to content
This repository has been archived by the owner on Jan 10, 2023. It is now read-only.

Consolidate reduction ops #306

Merged
merged 1 commit into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/guide/myelin.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ Alternatively, a numeric tensor parameter id can be used as the index key.
The `cell.index(name)` method can be used for looking up tensor parameter ids in
advance, and looking up tensors by parameter ids is faster than looking up
tensors by name.
If the index key is neither a string not an integer, the repr() function of the
If the index key is neither a string nor an integer, the repr() function of the
index key is used for determining the tensor name.

The tensor is a view into the data in the instance for the variable. The tensor
Expand Down
29 changes: 26 additions & 3 deletions sling/myelin/compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,29 @@ Tensor *Step::GetPrototype() const {
return prototype;
}

// Build a textual type signature for the step. The result has the form
// "<out types>=<type>(<in types>)", where the type lists are comma-separated
// TypeString() values; the "<out types>=" prefix is left out when the step has
// no outputs.
string Step::Signature() const {
  string signature;

  // Comma-separated output types followed by "=", only if there are outputs.
  const char *separator = "";
  for (Tensor *output : outputs_) {
    signature += separator;
    signature += output->TypeString();
    separator = ",";
  }
  if (!outputs_.empty()) signature += "=";

  // Step type followed by the parenthesized, comma-separated input types.
  signature += type_;
  signature += "(";
  separator = "";
  for (Tensor *input : inputs_) {
    signature += separator;
    signature += input->TypeString();
    separator = ",";
  }
  signature += ")";

  return signature;
}

Network::Network() {
runtime_ = &default_runtime;
linker_ = &jit_linker;
Expand Down Expand Up @@ -1034,7 +1057,7 @@ bool Network::Compile(const Flow &flow, const Library &library) {

if (var->size != stride) {
LOG(ERROR) << "Invalid data size for variable " << var->name << ", "
<< var->size << "bytes, " << stride << "expected";
<< var->size << " bytes, " << stride << " expected";
return false;
}

Expand Down Expand Up @@ -1166,8 +1189,8 @@ bool Network::Compile(const Flow &flow, const Library &library) {
}
}
if (step->kernel_ == nullptr) {
LOG(ERROR) << "No kernel supports " << step->name()
<< " of type " << step->type();
LOG(ERROR) << "No kernel supports step " << step->name() << ": "
<< step->Signature();
return false;
}
VLOG(3) << "Step " << step->name() << " implemented by "
Expand Down
3 changes: 3 additions & 0 deletions sling/myelin/compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,9 @@ class Step : public Attributes {
// case, the biggest input is returned.
Tensor *GetPrototype() const;

// Get type signature for step.
string Signature() const;

private:
// Step name from flow operation.
string name_;
Expand Down
59 changes: 11 additions & 48 deletions sling/myelin/generator/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1339,54 +1339,6 @@ void ExpressionGenerator::GenerateZMMFltAccOp(
}
}

// Emit a reduction sequence over a ZMM accumulator using a three-register
// operation `op` (invoked with an additional `nomask` argument).
//
// Up to four stages are emitted. A stage is emitted only when `elements` is at
// least its threshold (2, 4, 8, 16). Each stage first writes a shuffled copy of
// `acc` into the scratch register `aux` and then invokes `op` with the operands
// (acc, acc, aux).
//
// NOTE(review): presumably each stage folds one half of the remaining data onto
// the other half so that the reduced value ends up in the lowest element of
// `acc` -- confirm against the shuffle immediates below.
void ExpressionGenerator::GenerateZMMReduction(
OpZMMRegRegReg op,
ZMMRegister acc,
ZMMRegister aux,
int elements,
MacroAssembler *masm) {
// Stage 1: vshuff32x4 with immediate 0x0E, then combine.
if (elements >= 2) {
__ vshuff32x4(aux, acc, acc, 0x0E);
(masm->*op)(acc, acc, aux, nomask);
}
// Stage 2: vshuff32x4 with immediate 0xB1, then combine.
if (elements >= 4) {
__ vshuff32x4(aux, acc, acc, 0xB1);
(masm->*op)(acc, acc, aux, nomask);
}
// Stage 3: vpermilps with immediate 0x0E, then combine.
if (elements >= 8) {
__ vpermilps(aux, acc, 0x0E);
(masm->*op)(acc, acc, aux, nomask);
}
// Stage 4: vpermilps with immediate 0x01, then combine.
if (elements >= 16) {
__ vpermilps(aux, acc, 0x01);
(masm->*op)(acc, acc, aux, nomask);
}
}

// Emit a reduction sequence over a ZMM accumulator using a three-register
// operation `op` that additionally takes a rounding argument (invoked here
// with `nomask` and `noround`).
//
// The stage structure is the same as in the overload taking OpZMMRegRegReg: a
// stage is emitted only when `elements` is at least its threshold (2, 4, 8,
// 16), and each stage writes a shuffled copy of `acc` into the scratch register
// `aux` before invoking `op` with the operands (acc, acc, aux).
//
// NOTE(review): presumably each stage folds one half of the remaining data onto
// the other half so that the reduced value ends up in the lowest element of
// `acc` -- confirm against the shuffle immediates below.
void ExpressionGenerator::GenerateZMMReduction(
OpZMMRegRegRegR op,
ZMMRegister acc,
ZMMRegister aux,
int elements,
MacroAssembler *masm) {
// Stage 1: vshuff32x4 with immediate 0x0E, then combine.
if (elements >= 2) {
__ vshuff32x4(aux, acc, acc, 0x0E);
(masm->*op)(acc, acc, aux, nomask, noround);
}
// Stage 2: vshuff32x4 with immediate 0xB1, then combine.
if (elements >= 4) {
__ vshuff32x4(aux, acc, acc, 0xB1);
(masm->*op)(acc, acc, aux, nomask, noround);
}
// Stage 3: vpermilps with immediate 0x0E, then combine.
if (elements >= 8) {
__ vpermilps(aux, acc, 0x0E);
(masm->*op)(acc, acc, aux, nomask, noround);
}
// Stage 4: vpermilps with immediate 0x01, then combine.
if (elements >= 16) {
__ vpermilps(aux, acc, 0x01);
(masm->*op)(acc, acc, aux, nomask, noround);
}
}

void ExpressionGenerator::GenerateIntUnaryOp(
Express::Op *instr,
OpReg opregb, OpMem opmemb,
Expand Down Expand Up @@ -1625,6 +1577,17 @@ void ExpressionGenerator::GenerateYMMIntOp(
}
}

// Map the type of a reduction instruction to the matching Reduction
// operator: SUM -> REDUCE_ADD, PRODUCT -> REDUCE_MUL, MIN -> REDUCE_MIN and
// MAX -> REDUCE_MAX. Any other instruction type is reported through
// UNSUPPORTED; REDUCE_ADD is the value returned if that ever returns.
Reduction ReduceOp(Express::Op *instr) {
  Reduction reduction = REDUCE_ADD;
  switch (instr->type) {
    case Express::SUM:
      reduction = REDUCE_ADD;
      break;
    case Express::PRODUCT:
      reduction = REDUCE_MUL;
      break;
    case Express::MIN:
      reduction = REDUCE_MIN;
      break;
    case Express::MAX:
      reduction = REDUCE_MAX;
      break;
    default:
      UNSUPPORTED;
  }
  return reduction;
}

// Error handler for unsupported operations. Logs a FATAL message containing
// the source location passed in by the caller.
//
// file: name of the source file where the unsupported operation was hit.
// line: line number within that file.
void UnsupportedOperation(const char *file, int line) {
LOG(FATAL) << "Unsupported operation (" << file << " line " << line << ")";
}
Expand Down
15 changes: 3 additions & 12 deletions sling/myelin/generator/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,18 +417,6 @@ class ExpressionGenerator {
OpZMMRegRegMem fltopmem, OpZMMRegRegMem dblopmem,
MacroAssembler *masm);

// Generate ZMM reduction with op.
void GenerateZMMReduction(
OpZMMRegRegReg op,
ZMMRegister acc,
ZMMRegister aux,
int elements, MacroAssembler *masm);
void GenerateZMMReduction(
OpZMMRegRegRegR op,
ZMMRegister acc,
ZMMRegister aux,
int elements, MacroAssembler *masm);

// Generate one-operand x64 int op.
void GenerateIntUnaryOp(
Express::Op *instr,
Expand Down Expand Up @@ -506,6 +494,9 @@ class ExpressionGenerator {
Express instructions_;
};

// Return reduction operator for reduction instruction.
Reduction ReduceOp(Express::Op *instr);

// Error handler for unsupported operations.
void UnsupportedOperation(const char *file, int line);

Expand Down
45 changes: 2 additions & 43 deletions sling/myelin/generator/vector-flt-avx128.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,58 +450,17 @@ class VectorFltAVX128Generator : public ExpressionGenerator {
void GenerateReduce(Express::Op *instr, MacroAssembler *masm) override {
auto acc = xmm(instr->acc);
auto aux = xmmaux(0);
__ Reduce(ReduceOp(instr), type_, acc, aux);

switch (type_) {
case DT_FLOAT:
switch (instr->type) {
case Express::SUM:
__ vhaddps(acc, acc, acc);
__ vhaddps(acc, acc, acc);
break;
case Express::PRODUCT:
__ vpermilps(aux, acc, 0x0E);
__ vmulps(acc, acc, aux);
__ vpermilps(aux, acc, 0x01);
__ vmulps(acc, acc, aux);
break;
case Express::MIN:
__ vpermilps(aux, acc, 0x0E);
__ vminps(acc, acc, aux);
__ vpermilps(aux, acc, 0x01);
__ vminps(acc, acc, aux);
break;
case Express::MAX:
__ vpermilps(aux, acc, 0x0E);
__ vmaxps(acc, acc, aux);
__ vpermilps(aux, acc, 0x01);
__ vmaxps(acc, acc, aux);
break;
default: UNSUPPORTED;
}
if (instr->dst != -1) {
__ vmovss(xmm(instr->dst), xmm(instr->dst), xmm(instr->acc));
} else {
__ vmovss(addr(instr->result), xmm(instr->acc));
}
break;
case DT_DOUBLE:
switch (instr->type) {
case Express::SUM:
__ vhaddpd(acc, acc, acc);
break;
case Express::PRODUCT:
__ vpermilpd(aux, acc, 1);
__ vmulpd(acc, acc, aux);
break;
case Express::MIN:
__ vpermilpd(aux, acc, 1);
__ vminpd(acc, acc, aux);
break;
case Express::MAX:
__ vpermilpd(aux, acc, 1);
__ vmaxpd(acc, acc, aux);
break;
default: UNSUPPORTED;
}
if (instr->dst != -1) {
__ vmovsd(xmm(instr->dst), xmm(instr->dst), xmm(instr->acc));
} else {
Expand Down
62 changes: 2 additions & 60 deletions sling/myelin/generator/vector-flt-avx256.cc
Original file line number Diff line number Diff line change
Expand Up @@ -518,75 +518,17 @@ class VectorFltAVX256Generator : public ExpressionGenerator {
void GenerateReduce(Express::Op *instr, MacroAssembler *masm) override {
auto acc = ymm(instr->acc);
auto aux = ymmaux(0);
__ Reduce(ReduceOp(instr), type_, acc, aux);

switch (type_) {
case DT_FLOAT:
switch (instr->type) {
case Express::SUM:
__ vperm2f128(aux, acc, acc, 1);
__ vaddps(acc, acc, aux);
__ vhaddps(acc, acc, acc);
__ vhaddps(acc, acc, acc);
break;
case Express::PRODUCT:
__ vperm2f128(aux, acc, acc, 1);
__ vmulps(acc, acc, aux);
__ vpermilps(aux, acc, 0x0E);
__ vmulps(acc, acc, aux);
__ vpermilps(aux, acc, 0x01);
__ vmulps(acc, acc, aux);
break;
case Express::MIN:
__ vperm2f128(aux, acc, acc, 1);
__ vminps(acc, acc, aux);
__ vpermilps(aux, acc, 0x0E);
__ vminps(acc, acc, aux);
__ vpermilps(aux, acc, 0x01);
__ vminps(acc, acc, aux);
break;
case Express::MAX:
__ vperm2f128(aux, acc, acc, 1);
__ vmaxps(acc, acc, aux);
__ vpermilps(aux, acc, 0x0E);
__ vmaxps(acc, acc, aux);
__ vpermilps(aux, acc, 0x01);
__ vmaxps(acc, acc, aux);
break;
default: UNSUPPORTED;
}
if (instr->dst != -1) {
__ vmovss(xmm(instr->dst), xmm(instr->dst), xmm(instr->acc));
} else {
__ vmovss(addr(instr->result), xmm(instr->acc));
}
break;
case DT_DOUBLE:
switch (instr->type) {
case Express::SUM:
__ vperm2f128(aux, acc, acc, 1);
__ vaddpd(acc, acc, aux);
__ vhaddpd(acc, acc, acc);
break;
case Express::PRODUCT:
__ vperm2f128(aux, acc, acc, 1);
__ vmulpd(acc, acc, aux);
__ vpermilpd(aux, acc, 1);
__ vmulpd(acc, acc, aux);
break;
case Express::MIN:
__ vperm2f128(aux, acc, acc, 1);
__ vminpd(acc, acc, aux);
__ vpermilpd(aux, acc, 1);
__ vminpd(acc, acc, aux);
break;
case Express::MAX:
__ vperm2f128(aux, acc, acc, 1);
__ vmaxpd(acc, acc, aux);
__ vpermilpd(aux, acc, 1);
__ vmaxpd(acc, acc, aux);
break;
break;
default: UNSUPPORTED;
}
if (instr->dst != -1) {
__ vmovsd(xmm(instr->dst), xmm(instr->dst), xmm(instr->acc));
} else {
Expand Down
36 changes: 4 additions & 32 deletions sling/myelin/generator/vector-flt-avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,48 +489,20 @@ class VectorFltAVX512Generator : public ExpressionGenerator {
void GenerateReduce(Express::Op *instr, MacroAssembler *masm) override {
auto acc = zmm(instr->acc);
auto aux = zmmaux(0);
__ Reduce(ReduceOp(instr), type_, acc, aux);

switch (type_) {
case DT_FLOAT:
switch (instr->type) {
case Express::SUM:
GenerateZMMReduction(&Assembler::vaddps, acc, aux, 16, masm);
break;
case Express::PRODUCT:
GenerateZMMReduction(&Assembler::vmulps, acc, aux, 16, masm);
break;
case Express::MIN:
GenerateZMMReduction(&Assembler::vminps, acc, aux, 16, masm);
break;
case Express::MAX:
GenerateZMMReduction(&Assembler::vmaxps, acc, aux, 16, masm);
break;
default: UNSUPPORTED;
}
if (instr->dst != -1) {
__ vmovss(zmm(instr->dst).x(), zmm(instr->dst).x(),
__ vmovss(zmm(instr->dst).x(), zmm(instr->dst).x(),
zmm(instr->acc).x());
} else {
__ vmovss(addr(instr->result), zmm(instr->acc).x());
}
break;
case DT_DOUBLE:
switch (instr->type) {
case Express::SUM:
GenerateZMMReduction(&Assembler::vaddpd, acc, aux, 8, masm);
break;
case Express::PRODUCT:
GenerateZMMReduction(&Assembler::vmulpd, acc, aux, 8, masm);
break;
case Express::MIN:
GenerateZMMReduction(&Assembler::vminpd, acc, aux, 8, masm);
break;
case Express::MAX:
GenerateZMMReduction(&Assembler::vmaxpd, acc, aux, 8, masm);
break;
default: UNSUPPORTED;
}
if (instr->dst != -1) {
__ vmovsd(zmm(instr->dst).x(), zmm(instr->dst).x(),
__ vmovsd(zmm(instr->dst).x(), zmm(instr->dst).x(),
zmm(instr->acc).x());
} else {
__ vmovsd(addr(instr->result), zmm(instr->acc).x());
Expand Down
Loading