Skip to content
This repository has been archived by the owner on Jan 10, 2023. It is now read-only.

Commit

Permalink
AVX512 support for Myelin (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
ringgaard authored May 17, 2018
1 parent 8d3bb6e commit 941939a
Show file tree
Hide file tree
Showing 19 changed files with 1,739 additions and 479 deletions.
12 changes: 6 additions & 6 deletions doc/guide/myelin.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ file:
```python
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from flow import Flow
from flow import FlowBuilder
from sling.myelin import Flow
from sling.myelin import Builder

# Import data.
mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)
Expand All @@ -57,7 +57,7 @@ for _ in range(1000):

# Save model to flow file.
flow = Flow()
builder = FlowBuilder(sess, flow)
builder = Builder(sess, flow)
builder.add(flow.func("classifier"), [x], [y])
flow.save("/tmp/mnist.flow")
```
Expand All @@ -74,8 +74,8 @@ a flow file:

```python
import tensorflow as tf
from flow import Flow
from flow import FlowBuilder
from sling.myelin import Flow
from sling.myelin import Builder

# Load Tensorflow checkpoint.
sess = tf.Session()
Expand All @@ -84,7 +84,7 @@ saver.restore(sess, '/tmp/mnist.ckpt')

# Create Myelin flow.
flow = Flow()
builder = FlowBuilder(sess, flow)
builder = Builder(sess, flow)

# Extract flow from graph.
inputs = [sess.graph.get_tensor_by_name("x:0")]
Expand Down
4 changes: 4 additions & 0 deletions python/myelin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Public API for the sling.myelin Python package: re-export everything from
# the flow, builder, and nn submodules so callers can write e.g.
# `from sling.myelin import Flow` / `Builder` (see doc/guide/myelin.md).
# NOTE(review): these are Python 2 implicit relative imports; under Python 3
# they would need to be `from .flow import *` etc. — confirm target runtime.
from flow import *
from builder import *
from nn import *

41 changes: 26 additions & 15 deletions sling/myelin/express.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,13 @@ class RecipeParser {
}
next();

// Parse single qualifier (only used for testing).
bool single = false;
if (is('\'')) {
single = true;
next();
}

// Parse variable id.
int id = 0;
int digits = 0;
Expand All @@ -295,8 +302,9 @@ class RecipeParser {
if (digits == 0) Error("Variable id expected in expression");

// Return variable.
// NOTE: type may still be uninitialized at this point.
return expr_->Variable(type, id);
Express::Var *var = expr_->Variable(type, id);
var->single = single;
return var;
}

// Output error.
Expand Down Expand Up @@ -674,16 +682,10 @@ bool Express::TryToEliminateOps() {
}

void Express::Hoist(int limit) {
// Collect all existing cached variables.
std::set<Var *> cached;
// Collect all existing hoisted variables.
std::set<Var *> hoisted;
for (int i = 0; i < body_; ++i) {
cached.insert(ops_[i]->result);
}

// Single-element inputs and constants are also considered cached since
// these are by definition loop invariant.
for (Var *var : vars_) {
if (var->type == NUMBER || var->single) cached.insert(var);
hoisted.insert(ops_[i]->result);
}

// Hoist const loads outside the body until limit reached.
Expand All @@ -693,7 +695,7 @@ void Express::Hoist(int limit) {
Var *candidate = nullptr;
for (Var *v : vars_) {
if (v->type == CONST || v->type == NUMBER) {
if (cached.count(v) == 0) {
if (hoisted.count(v) == 0) {
if (candidate == nullptr || v->usages() > candidate->usages()) {
candidate = v;
}
Expand All @@ -720,11 +722,20 @@ void Express::Hoist(int limit) {
assign->Assign(temp);
assign->AddArgument(candidate);
body_++;
cached.insert(candidate);
hoisted.insert(candidate);
hoisted.insert(temp);
new_temps++;
}
if (new_temps > 0) CompactTempVars();

// Single element inputs and constants are also considered hoisted since
// these are by definition loop invariant.
for (Var *var : vars_) {
if (var->type == NUMBER || var->type == CONST || var->single) {
hoisted.insert(var);
}
}

// Hoist loop-invariant operations.
bool again = true;
while (again) {
Expand All @@ -735,7 +746,7 @@ void Express::Hoist(int limit) {
// Check if all arguments are cached.
bool invariant = true;
for (Var *arg : op->args) {
if (cached.count(arg) == 0) {
if (hoisted.count(arg) == 0) {
invariant = false;
break;
}
Expand All @@ -745,7 +756,7 @@ void Express::Hoist(int limit) {
if (invariant) {
for (int j = i; j > body_; --j) ops_[j] = ops_[j - 1];
ops_[body_++] = op;
cached.insert(op->result);
hoisted.insert(op->result);
again = true;
break;
}
Expand Down
1 change: 1 addition & 0 deletions sling/myelin/generator/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cc_library(
"vector-flt-sse.cc",
"vector-flt-avx128.cc",
"vector-flt-avx256.cc",
"vector-flt-avx512.cc",
"scalar-int.cc",
"vector-int-sse.cc",
"vector-int-avx128.cc",
Expand Down
7 changes: 7 additions & 0 deletions sling/myelin/generator/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ ElementwiseIndexGenerator::Locator *ElementwiseIndexGenerator::GetLocator(
<< " input: " << var->shape().ToString()
<< " output: " << shape_.ToString();
}
} else if (var->shape().outer(shape_.rank()) == 1) {
// The variable shape prefix consists entirely of ones, so use a simple iterator.
loc->iterator = GetIterator(SIMPLE, var->elements());
} else {
LOG(FATAL) << "Unsupported iterator: " << var->name() << " with shape "
<< var->shape().ToString()
<< " to output shape " << shape_.ToString();
}

return loc;
Expand Down
Loading

0 comments on commit 941939a

Please sign in to comment.