Skip to content
This repository has been archived by the owner on Jan 10, 2023. It is now read-only.

Add back myelin tf extractor #293

Merged
merged 1 commit into from
Nov 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions doc/guide/myelin.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ file:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from sling.myelin import Flow
from sling.myelin import Builder
from sling.myelin.tf import Extractor

# Import data.
mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)
Expand All @@ -57,8 +57,8 @@ for _ in range(1000):

# Save model to flow file.
flow = Flow()
builder = Builder(sess, flow)
builder.add(flow.func("classifier"), [x], [y])
extractor = Extractor(sess, flow)
extractor.add(flow.func("classifier"), [x], [y])
flow.save("/tmp/mnist.flow")
```

Expand All @@ -75,7 +75,7 @@ a flow file:
```python
import tensorflow as tf
from sling.myelin import Flow
from sling.myelin import Builder
from sling.myelin.tf import Extractor

# Load Tensorflow checkpoint.
sess = tf.Session()
Expand All @@ -84,12 +84,12 @@ saver.restore(sess, '/tmp/mnist.ckpt')

# Create Myelin flow.
flow = Flow()
builder = Builder(sess, flow)
extractor = Extractor(sess, flow)

# Extract flow from graph.
inputs = [sess.graph.get_tensor_by_name("x:0")]
outputs = [sess.graph.get_tensor_by_name("y:0")]
builder.add(flow.func("classifier"), inputs, outputs)
extractor.add(flow.func("classifier"), inputs, outputs)

# Save flow.
flow.save("/tmp/mnist.flow")
Expand Down
3 changes: 0 additions & 3 deletions python/myelin/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@

"""Myelin function builder and expression evaluator."""

import math
from struct import pack

from flow import Variable
from flow import Function
from flow import Flow
Expand Down
10 changes: 5 additions & 5 deletions python/myelin/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, filename):
f.readinto(b)
self.view = memoryview(b)
self.next = 0 # index of next byte to read

def slice(self, size):
"""Returns a slice of given size from the current byte and skips ahead."""
current = self.next
Expand Down Expand Up @@ -267,7 +267,7 @@ def add_attr(self, name, value):
def get_attr(self, name):
"""Get blob attribute as a string or None."""
return self.attrs.get(name, None)

class Flow:
"""Flow with variables, operations, and functions."""

Expand All @@ -288,15 +288,15 @@ def func(self, name):
self.funcs[name] = f
return f

def var(self, name, type="float32", shape=[]):
def var(self, name, type="float32", shape=None):
"""Add variable to flow."""
if isinstance(name, Variable): return name
v = self.vars.get(name, None)
if v == None:
v = Variable(name)
self.vars[name] = v
v.type = type
v.shape = shape
if shape != None: v.shape = shape
return v

def op(self, name):
Expand Down Expand Up @@ -526,7 +526,7 @@ def load(self, filename):
if ref: var.ref = True
data_size = f.read_long()
var.data = f.slice(data_size) # avoid creating a copy

num_ops = f.read_int()
for _ in xrange(num_ops):
name = f.read_string()
Expand Down
120 changes: 120 additions & 0 deletions python/myelin/tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

def attr_str(value):
""" Convert attribute to string value."""
if isinstance(value, bool):
return "true" if value else "false"
elif isinstance(value, int):
return str(value)
elif isinstance(value, long):
return str(value)
elif isinstance(value, str):
return value
elif isinstance(value, list):
l = []
for v in value: l.append(attr_str(v))
return ",".join(l)
elif value.__class__.__name__ == "TensorShapeProto":
dims = []
for d in value.dim: dims.append(str(d.size))
return "x".join(dims)
elif value.__class__.__name__ == "TensorProto":
return str(value)
elif value.__class__.__name__ == "DType":
return value.name
else:
return str(type(value)) + ":" + str(value).replace('\n', ' ')

class Extractor:
"""Extract myelin flow from tensorflow graph."""

def __init__(self, sess, flow):
"""Initialize empty flow builder."""
self.sess = sess
self.feed = None
self.flow = flow
self.vars = []
self.ops = []

def add(self, func, inputs, outputs):
"""Add ops to flow."""
for var in outputs:
self.expand(func, var, inputs)

def expand(self, func, var, inputs):
"""Traverse graphs and add ops to flow."""
if var not in self.vars:
# Add new variable to flow.
self.vars.append(var)
v = self.flow.var(var.name, var.dtype.base_dtype.name, [])

# Get data for constants and variables.
if var.op.type in ["Const", "ConstV2"]:
v.data = tf.contrib.util.constant_value(var)
elif var.op.type in ["Variable", "VariableV2"]:
if self.feed is None:
v.data = var.eval(session=self.sess)
else:
v.data = self.sess.run(var, feed_dict=self.feed)

# Get shape.
if v.data is None:
shape = var.get_shape()
for d in shape.as_list():
if d != None:
v.shape.append(d)
else:
v.shape.append(-1)
else:
for d in v.data.shape:
v.shape.append(d)

if not var in inputs:
op = var.op
if op not in self.ops:
# Add new operation to flow function.
self.ops.append(op)
o = self.flow.op(op.name)
func.add(o)
o.type = op.type
for input in op.inputs:
o.add_input(self.flow.var(input.name))
for output in op.outputs:
o.add_output(self.flow.var(output.name))
for a in op.node_def.attr:
o.add_attr(a, attr_str(op.get_attr(a)))

# Traverse dependencies.
for dep in op.inputs:
self.expand(func, dep, inputs)


def compute_shapes(self):
"""Compute shapes for variables with missing shape information."""
# Find all variables with missing shape information.
missing = {}
for var in self.vars:
v = self.flow.var(var.name)
if not v.shape_defined():
missing[v] = var

if len(missing) > 0:
# Compute variables from feed.
results = self.sess.run(missing, feed_dict=self.feed)

# Use the shape of the computed variables for the flow.
for v in results:
v.shape = results[v].shape
7 changes: 7 additions & 0 deletions tools/build-wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def sha256_content_checksum(data):
'python/flags.py': '$DATA$/sling/flags.py',
'python/log.py': '$DATA$/sling/log.py',

'python/myelin/__init__.py': '$DATA$/sling/myelin/__init__.py',
'python/myelin/builder.py': '$DATA$/sling/myelin/builder.py',
'python/myelin/flow.py': '$DATA$/sling/myelin/flow.py',
'python/myelin/lexical_encoder.py': '$DATA$/sling/myelin/lexical_encoder.py',
'python/myelin/nn.py': '$DATA$/sling/myelin/nn.py',
'python/myelin/tf.py': '$DATA$/sling/myelin/tf.py',

'python/nlp/__init__.py': '$DATA$/sling/nlp/__init__.py',
'python/nlp/document.py': '$DATA$/sling/nlp/document.py',
'python/nlp/parser.py': '$DATA$/sling/nlp/parser.py',
Expand Down