Skip to content

Commit

Permalink
PyTorch LSTM support (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 7, 2018
1 parent c066d62 commit c4ce60d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 35 deletions.
48 changes: 46 additions & 2 deletions src/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
"schema": {
"attributes": [
{
"default": true,
"default": false,
"name": "inplace"
},
{
Expand Down Expand Up @@ -223,6 +223,27 @@
"package": "torch.nn.modules.batchnorm"
}
},
{
"name": "Dropout2d",
"schema": {
"attributes": [
{
"default": false,
"name": "inplace"
},
{
"default": 0.5,
"name": "p"
},
{
"name": "training",
"visible": false
}
],
"category": "Dropout",
"package": "torch.nn.modules.dropout"
}
},
{
"name": "Dropout",
"schema": {
Expand All @@ -248,7 +269,30 @@
"name": "LSTM",
"schema": {
"category": "Layer",
"package": "torch.nn"
"package": "torch.nn.modules.rnn",
"attributes": [
{ "name": "training", "visible": false },
{ "name": "dropout", "default": 0 },
{ "name": "dropout_state", "default": {} },
{ "name": "num_layers", "default": 1 },
{ "name": "batch_first", "visible": false },
{ "name": "bidirectional", "visible": false }
]
}
},
{
"name": "Embedding",
"schema": {
"category": "Transform",
"package": "torch.nn.modules.sparse",
"attributes": [
{ "name": "training", "visible": false },
{ "name": "norm_type", "default": 2 },
{ "name": "scale_grad_by_freq", "default": false },
{ "name": "sparse", "default": false },
{ "name": "max_norm", "default": null },
{ "name": "padding_idx", "default": null }
]
}
}
]
78 changes: 45 additions & 33 deletions src/pytorch-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class PyTorchModelFactory {
callback(err, null);
return;
}
this._openModel(context, host, callback);
PyTorchOperatorMetadata.open(host, (err, metadata) => {
this._openModel(context, host, callback);
});
});
}

Expand All @@ -39,7 +41,7 @@ class PyTorchModelFactory {
}
var sysInfo = unpickler.load();
if (!sysInfo.little_endian) {
callback(new PyTorchError('Unsupported system information.'));
callback(new PyTorchError('Unsupported endian format.'));
return;
}
if (sysInfo.protocol_version != 1001) {
Expand Down Expand Up @@ -72,9 +74,12 @@ class PyTorchModelFactory {
constructorTable['torch.nn.modules.conv.ConvTranspose2d'] = function () {};
constructorTable['torch.nn.modules.conv.ConvTranspose3d'] = function () {};
constructorTable['torch.nn.modules.dropout.Dropout'] = function () {};
constructorTable['torch.nn.modules.dropout.Dropout2d'] = function () {};
constructorTable['torch.nn.modules.linear.Linear'] = function () {};
constructorTable['torch.nn.modules.pooling.AvgPool2d'] = function () {};
constructorTable['torch.nn.modules.pooling.MaxPool2d'] = function () {};
constructorTable['torch.nn.modules.rnn.LSTM'] = function () {};
constructorTable['torch.nn.modules.sparse.Embedding'] = function () {};
constructorTable['torchvision.models.alexnet.AlexNet'] = function () {};
constructorTable['torchvision.models.densenet.DenseNet'] = function () {};
constructorTable['torchvision.models.densenet._DenseBlock'] = function () {};
Expand Down Expand Up @@ -195,10 +200,7 @@ class PyTorchModelFactory {
}

var model = new PyTorchModel(sysInfo, root);

PyTorchOperatorMetadata.open(host, (err, metadata) => {
callback(null, model);
});
callback(null, model);
}
catch (error) {
host.exception(error, false);
Expand Down Expand Up @@ -233,16 +235,24 @@ class PyTorchGraph {
this._groups = true;

var input = 'data';
this._inputs.push(new PyTorchArgument(input, [ new PyTorchConnection(input, null, null) ]));
this._inputs.push(new PyTorchArgument(input, true, [ new PyTorchConnection(input, null, null) ]));

var outputs = this._loadModule(root, [], [ input ]);
outputs.forEach((output) => {
this._outputs.push(new PyTorchArgument(output, [ new PyTorchConnection(output, null, null) ]));
this._outputs.push(new PyTorchArgument(output, true, [ new PyTorchConnection(output, null, null) ]));
});
}

_loadModule(parent, groups, inputs) {

if (parent.__type__ &&
parent.__type__.startsWith('torch.nn.modules.') &&
!parent.__type__.startsWith('torch.nn.modules.container.')) {
var node = new PyTorchNode(parent, groups, inputs);
this._nodes.push(node);
return [];
}

if (!parent._modules) {
throw new PyTorchError('Module does not contain modules.');
}
Expand Down Expand Up @@ -315,8 +325,9 @@ class PyTorchGraph {
}

class PyTorchArgument {
constructor(name, connections) {
constructor(name, visible, connections) {
this._name = name;
this._visible = visible;
this._connections = connections;
}

Expand All @@ -325,7 +336,7 @@ class PyTorchArgument {
}

get visible() {
return true;
return this._visible;
}

get connections() {
Expand Down Expand Up @@ -366,7 +377,7 @@ class PyTorchNode {
this._operator = module.__type__.split('.').pop();

this._inputs = [];
this._inputs.push(new PyTorchArgument('input', connections.map((connection) => {
this._inputs.push(new PyTorchArgument('input', true, connections.map((connection) => {
return new PyTorchConnection(connection, null, null);
})));

Expand All @@ -391,12 +402,13 @@ class PyTorchNode {
else if (parameter.storage) {
initializer = new PyTorchTensor(parameter);
}
this._inputs.push(new PyTorchArgument(parameter.__id__, [ new PyTorchConnection(null, null, initializer) ]));
var visible = (this._operator != 'LSTM' || initializer == null);
this._inputs.push(new PyTorchArgument(parameter.__id__, visible, [ new PyTorchConnection(null, null, initializer) ]));
}
});

this._outputs = [];
this._outputs.push(new PyTorchArgument('output', [ new PyTorchConnection(this._name, null, null) ]));
this._outputs.push(new PyTorchArgument('output', true, [ new PyTorchConnection(this._name, null, null) ]));

this._attributes = [];
Object.keys(module).forEach((key) => {
Expand All @@ -419,7 +431,8 @@ class PyTorchNode {
}

get category() {
return PyTorchOperatorMetadata.operatorMetadata.getOperatorCategory(this._operator);
var schema = PyTorchOperatorMetadata.operatorMetadata.getSchema(this._operator);
return (schema && schema.category) ? schema.category : null;
}

get attributes() {
Expand All @@ -441,6 +454,18 @@ class PyTorchAttribute {
this._node = node;
this._name = name;
this._value = value;

var schema = PyTorchOperatorMetadata.operatorMetadata.getAttributeSchema(this._node.operator, this._name);
if (schema) {
if (schema.hasOwnProperty('visible') && !schema.visible) {
this._visible = false;
}
else if (schema.hasOwnProperty('default')) {
if (JSON.stringify(schema.default) == JSON.stringify(value)) {
this._visible = false;
}
}
}
}

get name() {
Expand All @@ -452,7 +477,7 @@ class PyTorchAttribute {
}

get visible() {
return PyTorchOperatorMetadata.operatorMetadata.getAttributeVisible(this._node.operator, this._name);
return this._visible == false ? false : true;
}
}

Expand Down Expand Up @@ -629,15 +654,11 @@ class PyTorchOperatorMetadata {
}
}

getOperatorCategory(operator) {
var schema = this._map[operator];
if (schema && schema.category) {
return schema.category;
}
return null;
getSchema(operator) {
return this._map[operator] || null;
}

getAttributeVisible(operator, name, value) {
getAttributeSchema(operator, name) {
var schema = this._map[operator];
if (schema && schema.attributes && schema.attributes.length > 0) {
if (!schema.attributesMap) {
Expand All @@ -646,18 +667,9 @@ class PyTorchOperatorMetadata {
schema.attributesMap[attribute.name] = attribute;
});
}
var attribute = schema.attributesMap[name];

if (attribute) {
if (attribute.hasOwnProperty('visible')) {
return attribute.visible;
}
if (attribute.hasOwnProperty('default')) {
return JSON.stringify(attribute.default) == JSON.stringify(value);
}
}
return schema.attributesMap[name] || null;
}
return true;
return null;
}
}

Expand Down

0 comments on commit c4ce60d

Please sign in to comment.