diff --git a/src/pytorch-metadata.json b/src/pytorch-metadata.json index 2e4a2b18c8..9e551be44e 100755 --- a/src/pytorch-metadata.json +++ b/src/pytorch-metadata.json @@ -120,7 +120,7 @@ "schema": { "attributes": [ { - "default": true, + "default": false, "name": "inplace" }, { @@ -223,6 +223,27 @@ "package": "torch.nn.modules.batchnorm" } }, + { + "name": "Dropout2d", + "schema": { + "attributes": [ + { + "default": false, + "name": "inplace" + }, + { + "default": 0.5, + "name": "p" + }, + { + "name": "training", + "visible": false + } + ], + "category": "Dropout", + "package": "torch.nn.modules.dropout" + } + }, { "name": "Dropout", "schema": { @@ -248,7 +269,30 @@ "name": "LSTM", "schema": { "category": "Layer", - "package": "torch.nn" + "package": "torch.nn.modules.rnn", + "attributes": [ + { "name": "training", "visible": false }, + { "name": "dropout", "default": 0 }, + { "name": "dropout_state", "default": {} }, + { "name": "num_layers", "default": 1 }, + { "name": "batch_first", "visible": false }, + { "name": "bidirectional", "visible": false } + ] + } + }, + { + "name": "Embedding", + "schema": { + "category": "Transform", + "package": "torch.nn.modules.sparse", + "attributes": [ + { "name": "training", "visible": false }, + { "name": "norm_type", "default": 2 }, + { "name": "scale_grad_by_freq", "default": false }, + { "name": "sparse", "default": false }, + { "name": "max_norm", "default": null }, + { "name": "padding_idx", "default": null } + ] } } ] diff --git a/src/pytorch-model.js b/src/pytorch-model.js index f33cb0b987..430de5f3ef 100755 --- a/src/pytorch-model.js +++ b/src/pytorch-model.js @@ -15,7 +15,9 @@ class PyTorchModelFactory { callback(err, null); return; } - this._openModel(context, host, callback); + PyTorchOperatorMetadata.open(host, (err, metadata) => { + this._openModel(context, host, callback); + }); }); } @@ -39,7 +41,7 @@ class PyTorchModelFactory { } var sysInfo = unpickler.load(); if (!sysInfo.little_endian) { - callback(new PyTorchError('Unsupported system information.')); + callback(new PyTorchError('Unsupported endian format.')); return; } if (sysInfo.protocol_version != 1001) { @@ -72,9 +74,12 @@ class PyTorchModelFactory { constructorTable['torch.nn.modules.conv.ConvTranspose2d'] = function () {}; constructorTable['torch.nn.modules.conv.ConvTranspose3d'] = function () {}; constructorTable['torch.nn.modules.dropout.Dropout'] = function () {}; + constructorTable['torch.nn.modules.dropout.Dropout2d'] = function () {}; constructorTable['torch.nn.modules.linear.Linear'] = function () {}; constructorTable['torch.nn.modules.pooling.AvgPool2d'] = function () {}; constructorTable['torch.nn.modules.pooling.MaxPool2d'] = function () {}; + constructorTable['torch.nn.modules.rnn.LSTM'] = function () {}; + constructorTable['torch.nn.modules.sparse.Embedding'] = function () {}; constructorTable['torchvision.models.alexnet.AlexNet'] = function () {}; constructorTable['torchvision.models.densenet.DenseNet'] = function () {}; constructorTable['torchvision.models.densenet._DenseBlock'] = function () {}; @@ -195,10 +200,7 @@ class PyTorchModelFactory { } var model = new PyTorchModel(sysInfo, root); - - PyTorchOperatorMetadata.open(host, (err, metadata) => { - callback(null, model); - }); + callback(null, model); } catch (error) { host.exception(error, false); @@ -233,16 +235,24 @@ class PyTorchGraph { this._groups = true; var input = 'data'; - this._inputs.push(new PyTorchArgument(input, [ new PyTorchConnection(input, null, null) ])); + this._inputs.push(new PyTorchArgument(input, true, [ new PyTorchConnection(input, null, null) ])); var outputs = this._loadModule(root, [], [ input ]); outputs.forEach((output) => { - this._outputs.push(new PyTorchArgument(output, [ new PyTorchConnection(output, null, null) ])); + this._outputs.push(new PyTorchArgument(output, true, [ new PyTorchConnection(output, null, null) ])); }); } _loadModule(parent, groups, inputs) { + if (parent.__type__ && + parent.__type__.startsWith('torch.nn.modules.') && + !parent.__type__.startsWith('torch.nn.modules.container.')) { + var node = new PyTorchNode(parent, groups, inputs); + this._nodes.push(node); + return []; + } + if (!parent._modules) { throw new PyTorchError('Module does not contain modules.'); } @@ -315,8 +325,9 @@ class PyTorchGraph { } class PyTorchArgument { - constructor(name, connections) { + constructor(name, visible, connections) { this._name = name; + this._visible = visible; this._connections = connections; } @@ -325,7 +336,7 @@ class PyTorchArgument { } get visible() { - return true; + return this._visible; } get connections() { @@ -366,7 +377,7 @@ class PyTorchNode { this._operator = module.__type__.split('.').pop(); this._inputs = []; - this._inputs.push(new PyTorchArgument('input', connections.map((connection) => { + this._inputs.push(new PyTorchArgument('input', true, connections.map((connection) => { return new PyTorchConnection(connection, null, null); }))); @@ -391,12 +402,13 @@ class PyTorchNode { else if (parameter.storage) { initializer = new PyTorchTensor(parameter); } - this._inputs.push(new PyTorchArgument(parameter.__id__, [ new PyTorchConnection(null, null, initializer) ])); + var visible = (this._operator != 'LSTM' || initializer == null); + this._inputs.push(new PyTorchArgument(parameter.__id__, visible, [ new PyTorchConnection(null, null, initializer) ])); } }); this._outputs = []; - this._outputs.push(new PyTorchArgument('output', [ new PyTorchConnection(this._name, null, null) ])); + this._outputs.push(new PyTorchArgument('output', true, [ new PyTorchConnection(this._name, null, null) ])); this._attributes = []; Object.keys(module).forEach((key) => { @@ -419,7 +431,8 @@ class PyTorchNode { } get category() { - return PyTorchOperatorMetadata.operatorMetadata.getOperatorCategory(this._operator); + var schema = PyTorchOperatorMetadata.operatorMetadata.getSchema(this._operator); + return (schema && schema.category) ? schema.category : null; } get attributes() { @@ -441,6 +454,18 @@ class PyTorchAttribute { this._node = node; this._name = name; this._value = value; + + var schema = PyTorchOperatorMetadata.operatorMetadata.getAttributeSchema(this._node.operator, this._name); + if (schema) { + if (schema.hasOwnProperty('visible') && !schema.visible) { + this._visible = false; + } + else if (schema.hasOwnProperty('default')) { + if (JSON.stringify(schema.default) == JSON.stringify(value)) { + this._visible = false; + } + } + } } get name() { @@ -452,7 +477,7 @@ class PyTorchAttribute { } get visible() { - return PyTorchOperatorMetadata.operatorMetadata.getAttributeVisible(this._node.operator, this._name); + return this._visible == false ? false : true; } } @@ -629,15 +654,11 @@ class PyTorchOperatorMetadata { } } - getOperatorCategory(operator) { - var schema = this._map[operator]; - if (schema && schema.category) { - return schema.category; - } - return null; + getSchema(operator) { + return this._map[operator] || null; } - getAttributeVisible(operator, name, value) { + getAttributeSchema(operator, name) { var schema = this._map[operator]; if (schema && schema.attributes && schema.attributes.length > 0) { if (!schema.attributesMap) { @@ -646,18 +667,9 @@ class PyTorchOperatorMetadata { schema.attributesMap[attribute.name] = attribute; }); } - var attribute = schema.attributesMap[name]; - - if (attribute) { - if (attribute.hasOwnProperty('visible')) { - return attribute.visible; - } - if (attribute.hasOwnProperty('default')) { - return JSON.stringify(attribute.default) == JSON.stringify(value); - } - } + return schema.attributesMap[name] || null; } - return true; + return null; } }