Skip to content
This repository has been archived by the owner on Jan 10, 2023. It is now read-only.

Commit

Permalink
Add back myelin tf extractor (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
ringgaard authored Nov 16, 2018
1 parent b8b1234 commit 14c5c07
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 14 deletions.
12 changes: 6 additions & 6 deletions doc/guide/myelin.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ file:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from sling.myelin import Flow
from sling.myelin import Builder
from sling.myelin.tf import Extractor

# Import data.
mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)
Expand All @@ -57,8 +57,8 @@ for _ in range(1000):

# Save model to flow file.
flow = Flow()
builder = Builder(sess, flow)
builder.add(flow.func("classifier"), [x], [y])
extractor = Extractor(sess, flow)
extractor.add(flow.func("classifier"), [x], [y])
flow.save("/tmp/mnist.flow")
```

Expand All @@ -75,7 +75,7 @@ a flow file:
```python
import tensorflow as tf
from sling.myelin import Flow
from sling.myelin import Builder
from sling.myelin.tf import Extractor

# Load Tensorflow checkpoint.
sess = tf.Session()
Expand All @@ -84,12 +84,12 @@ saver.restore(sess, '/tmp/mnist.ckpt')

# Create Myelin flow.
flow = Flow()
builder = Builder(sess, flow)
extractor = Extractor(sess, flow)

# Extract flow from graph.
inputs = [sess.graph.get_tensor_by_name("x:0")]
outputs = [sess.graph.get_tensor_by_name("y:0")]
builder.add(flow.func("classifier"), inputs, outputs)
extractor.add(flow.func("classifier"), inputs, outputs)

# Save flow.
flow.save("/tmp/mnist.flow")
Expand Down
3 changes: 0 additions & 3 deletions python/myelin/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@

"""Myelin function builder and expression evaluator."""

import math
from struct import pack

from flow import Variable
from flow import Function
from flow import Flow
Expand Down
10 changes: 5 additions & 5 deletions python/myelin/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, filename):
f.readinto(b)
self.view = memoryview(b)
self.next = 0 # index of next byte to read

def slice(self, size):
"""Returns a slice of given size from the current byte and skips ahead."""
current = self.next
Expand Down Expand Up @@ -267,7 +267,7 @@ def add_attr(self, name, value):
def get_attr(self, name):
"""Get blob attribute as a string or None."""
return self.attrs.get(name, None)

class Flow:
"""Flow with variables, operations, and functions."""

Expand All @@ -288,15 +288,15 @@ def func(self, name):
self.funcs[name] = f
return f

def var(self, name, type="float32", shape=[]):
def var(self, name, type="float32", shape=None):
"""Add variable to flow."""
if isinstance(name, Variable): return name
v = self.vars.get(name, None)
if v == None:
v = Variable(name)
self.vars[name] = v
v.type = type
v.shape = shape
if shape != None: v.shape = shape
return v

def op(self, name):
Expand Down Expand Up @@ -526,7 +526,7 @@ def load(self, filename):
if ref: var.ref = True
data_size = f.read_long()
var.data = f.slice(data_size) # avoid creating a copy

num_ops = f.read_int()
for _ in xrange(num_ops):
name = f.read_string()
Expand Down
120 changes: 120 additions & 0 deletions python/myelin/tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

def attr_str(value):
""" Convert attribute to string value."""
if isinstance(value, bool):
return "true" if value else "false"
elif isinstance(value, int):
return str(value)
elif isinstance(value, long):
return str(value)
elif isinstance(value, str):
return value
elif isinstance(value, list):
l = []
for v in value: l.append(attr_str(v))
return ",".join(l)
elif value.__class__.__name__ == "TensorShapeProto":
dims = []
for d in value.dim: dims.append(str(d.size))
return "x".join(dims)
elif value.__class__.__name__ == "TensorProto":
return str(value)
elif value.__class__.__name__ == "DType":
return value.name
else:
return str(type(value)) + ":" + str(value).replace('\n', ' ')

class Extractor:
"""Extract myelin flow from tensorflow graph."""

def __init__(self, sess, flow):
"""Initialize empty flow builder."""
self.sess = sess
self.feed = None
self.flow = flow
self.vars = []
self.ops = []

def add(self, func, inputs, outputs):
"""Add ops to flow."""
for var in outputs:
self.expand(func, var, inputs)

def expand(self, func, var, inputs):
"""Traverse graphs and add ops to flow."""
if var not in self.vars:
# Add new variable to flow.
self.vars.append(var)
v = self.flow.var(var.name, var.dtype.base_dtype.name, [])

# Get data for constants and variables.
if var.op.type in ["Const", "ConstV2"]:
v.data = tf.contrib.util.constant_value(var)
elif var.op.type in ["Variable", "VariableV2"]:
if self.feed is None:
v.data = var.eval(session=self.sess)
else:
v.data = self.sess.run(var, feed_dict=self.feed)

# Get shape.
if v.data is None:
shape = var.get_shape()
for d in shape.as_list():
if d != None:
v.shape.append(d)
else:
v.shape.append(-1)
else:
for d in v.data.shape:
v.shape.append(d)

if not var in inputs:
op = var.op
if op not in self.ops:
# Add new operation to flow function.
self.ops.append(op)
o = self.flow.op(op.name)
func.add(o)
o.type = op.type
for input in op.inputs:
o.add_input(self.flow.var(input.name))
for output in op.outputs:
o.add_output(self.flow.var(output.name))
for a in op.node_def.attr:
o.add_attr(a, attr_str(op.get_attr(a)))

# Traverse dependencies.
for dep in op.inputs:
self.expand(func, dep, inputs)


def compute_shapes(self):
"""Compute shapes for variables with missing shape information."""
# Find all variables with missing shape information.
missing = {}
for var in self.vars:
v = self.flow.var(var.name)
if not v.shape_defined():
missing[v] = var

if len(missing) > 0:
# Compute variables from feed.
results = self.sess.run(missing, feed_dict=self.feed)

# Use the shape of the computed variables for the flow.
for v in results:
v.shape = results[v].shape
7 changes: 7 additions & 0 deletions tools/build-wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def sha256_content_checksum(data):
'python/flags.py': '$DATA$/sling/flags.py',
'python/log.py': '$DATA$/sling/log.py',

'python/myelin/__init__.py': '$DATA$/sling/myelin/__init__.py',
'python/myelin/builder.py': '$DATA$/sling/myelin/builder.py',
'python/myelin/flow.py': '$DATA$/sling/myelin/flow.py',
'python/myelin/lexical_encoder.py': '$DATA$/sling/myelin/lexical_encoder.py',
'python/myelin/nn.py': '$DATA$/sling/myelin/nn.py',
'python/myelin/tf.py': '$DATA$/sling/myelin/tf.py',

'python/nlp/__init__.py': '$DATA$/sling/nlp/__init__.py',
'python/nlp/document.py': '$DATA$/sling/nlp/document.py',
'python/nlp/parser.py': '$DATA$/sling/nlp/parser.py',
Expand Down

0 comments on commit 14c5c07

Please sign in to comment.