From c0e14fff2a0ba8d3ccb76a7b66c2fe1a39f62b6b Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 9 Jul 2022 10:35:33 -0700 Subject: [PATCH] Update PyTorch Package experiment (#928) --- source/python.js | 162 +++++++++++++++++++++++++++++++++++++--------- source/pytorch.js | 42 ++++++------ test/models.json | 2 +- 3 files changed, 149 insertions(+), 57 deletions(-) diff --git a/source/python.js b/source/python.js index 570797cbc8..fa71a493c7 100644 --- a/source/python.js +++ b/source/python.js @@ -138,32 +138,32 @@ python.Parser = class { node = this._eat('id', 'global'); if (node) { - node.variable = []; + node.names = []; do { - node.variable.push(this._parseName()); + node.names.push(this._parseName(true).value); } while (this._tokenizer.eat(',')); return node; } node = this._eat('id', 'nonlocal'); if (node) { - node.variable = []; + node.names = []; do { - node.variable.push(this._parseName()); + node.names.push(this._parseName(true).value); } while (this._tokenizer.eat(',')); return node; } node = this._eat('id', 'import'); if (node) { - node.modules = []; + node.names = []; do { - const module = this._node('module'); - module.name = this._parseExpression(-1, [], false); + const alias = this._node('alias'); + alias.name = this._parseDottedName(); if (this._tokenizer.eat('id', 'as')) { - module.as = this._parseExpression(-1, [], false); + alias.asname = this._parseName(true).value; } - node.modules.push(module); + node.names.push(alias); } while (this._tokenizer.eat(',')); return node; @@ -171,24 +171,21 @@ python.Parser = class { node = this._eat('id', 'from'); if (node) { node.type = 'import_from'; + node.level = 0; const dots = this._tokenizer.peek(); if (dots && Array.from(dots.type).every((c) => c == '.')) { this._eat(dots.type); node.level = Array.from(dots.type).length; - node.module = this._parseExpression(); - } - else { - node.level = 0; - node.module = this._parseExpression(); } + node.module = this._parseDottedName(); this._tokenizer.expect('id', 'import'); node.names = []; const close = this._tokenizer.eat('('); do { const alias = this._node('alias'); - alias.name = this._parseName(); + alias.name = this._parseName(true).value; if (this._tokenizer.eat('id', 'as')) { - alias.asname = this._parseName(); + alias.asname = this._parseName(true).value; } node.names.push(alias); } @@ -203,13 +200,13 @@ python.Parser = class { node = this._eat('id', 'class'); if (node) { - node.name = this._parseName().value; + node.name = this._parseName(true).value; if (decorator_list) { node.decorator_list = Array.from(decorator_list); decorator_list = null; } if (this._tokenizer.peek().value === '(') { - node.base = this._parseArguments(); + node.bases = this._parseArguments(); } this._tokenizer.expect(':'); node.body = this._parseSuite(); @@ -229,7 +226,7 @@ python.Parser = class { if (async) { node.async = async; } - node.name = this._parseName().value; + node.name = this._parseName(true).value; if (decorator_list) { node.decorator_list = Array.from(decorator_list); decorator_list = null; @@ -821,15 +818,27 @@ python.Parser = class { return node; } - _parseName() { + _parseName(required) { const token = this._tokenizer.peek(); if (token.type == 'id' && !token.keyword) { this._tokenizer.read(); return token; } + if (required) { + throw new python.Error("Invalid syntax" + this._tokenizer.location()); + } return null; } + _parseDottedName() { + const list = []; + do { + list.push(this._parseName(true).value); + } + while (this._tokenizer.eat('.')); + return list.join('.'); + } + _parseLiteral() { const token = this._tokenizer.peek(); if (token.type == 'string' || token.type == 'number' || token.type == 'boolean') { @@ -1942,6 +1951,10 @@ python.Execution = class { } }); this.registerType('keras.engine.sequential.Sequential', class {}); + this.registerType('lasagne.layers.conv.Conv2DLayer', class {}); + this.registerType('lasagne.layers.dense.DenseLayer', class {}); + this.registerType('lasagne.layers.input.InputLayer', class {}); + this.registerType('lasagne.layers.pool.MaxPool2DLayer', class {}); this.registerType('lightgbm.sklearn.LGBMRegressor', class {}); this.registerType('lightgbm.sklearn.LGBMClassifier', class {}); this.registerType('lightgbm.basic.Booster', class { @@ -2355,6 +2368,90 @@ python.Execution = class { Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null)); } }); + this.registerType('theano.compile.function_module._constructor_Function', class {}); + this.registerType('theano.compile.function_module._constructor_FunctionMaker', class {}); + this.registerType('theano.compile.function_module.Supervisor', class {}); + this.registerType('theano.compile.io.In', class {}); + this.registerType('theano.compile.io.SymbolicOutput', class {}); + this.registerType('theano.compile.mode.Mode', class {}); + this.registerType('theano.compile.ops.OutputGuard', class {}); + this.registerType('theano.compile.ops.Shape', class {}); + this.registerType('theano.compile.ops.Shape_i', class {}); + this.registerType('theano.gof.destroyhandler.DestroyHandler', class {}); + this.registerType('theano.gof.fg.FunctionGraph', class {}); + this.registerType('theano.gof.graph.Apply', class {}); + this.registerType('theano.gof.link.Container', class {}); + this.registerType('theano.gof.opt._metadict', class {}); + this.registerType('theano.gof.opt.ChangeTracker', class {}); + this.registerType('theano.gof.opt.MergeFeature', class {}); + this.registerType('theano.gof.optdb.Query', class {}); + this.registerType('theano.gof.toolbox.PreserveVariableAttributes', class {}); + this.registerType('theano.gof.toolbox.ReplaceValidate', class {}); + this.registerType('theano.gof.utils.scratchpad', class {}); + this.registerType('theano.misc.ordered_set.Link', class {}); + this.registerType('theano.misc.ordered_set.OrderedSet', class {}); + this.registerType('theano.sandbox.cuda.basic_ops.HostFromGpu', class {}); + this.registerType('theano.sandbox.cuda.type.CudaNdarray_unpickler', class {}); + this.registerType('theano.sandbox.cuda.type.CudaNdarrayType', class {}); + this.registerType('theano.sandbox.cuda.var.CudaNdarraySharedVariable', class {}); + this.registerType('theano.scalar.basic.Abs', class {}); + this.registerType('theano.scalar.basic.Add', class {}); + this.registerType('theano.scalar.basic.Cast', class {}); + this.registerType('theano.scalar.basic.Composite', class {}); + this.registerType('theano.scalar.basic.EQ', class {}); + this.registerType('theano.scalar.basic.GE', class {}); + this.registerType('theano.scalar.basic.Identity', class {}); + this.registerType('theano.scalar.basic.IntDiv', class {}); + this.registerType('theano.scalar.basic.Inv', class {}); + this.registerType('theano.scalar.basic.LE', class {}); + this.registerType('theano.scalar.basic.LT', class {}); + this.registerType('theano.scalar.basic.Mul', class {}); + this.registerType('theano.scalar.basic.Neg', class {}); + this.registerType('theano.scalar.basic.Scalar', class {}); + this.registerType('theano.scalar.basic.ScalarConstant', class {}); + this.registerType('theano.scalar.basic.ScalarVariable', class {}); + this.registerType('theano.scalar.basic.Second', class {}); + this.registerType('theano.scalar.basic.Sgn', class {}); + this.registerType('theano.scalar.basic.specific_out', class {}); + this.registerType('theano.scalar.basic.Sub', class {}); + this.registerType('theano.scalar.basic.Switch', class {}); + this.registerType('theano.scalar.basic.Tanh', class {}); + this.registerType('theano.scalar.basic.transfer_type', class {}); + this.registerType('theano.scalar.basic.TrueDiv', class {}); + this.registerType('theano.tensor.basic.Alloc', class {}); + this.registerType('theano.tensor.basic.Dot', class {}); + this.registerType('theano.tensor.basic.MaxAndArgmax', class {}); + this.registerType('theano.tensor.basic.Reshape', class {}); + this.registerType('theano.tensor.basic.ScalarFromTensor', class {}); + this.registerType('theano.tensor.blas.Dot22', class {}); + this.registerType('theano.tensor.blas.Dot22Scalar', class {}); + this.registerType('theano.tensor.blas.Gemm', class {}); + this.registerType('theano.tensor.elemwise.DimShuffle', class {}); + this.registerType('theano.tensor.elemwise.Elemwise', class {}); + this.registerType('theano.tensor.elemwise.Sum', class {}); + this.registerType('theano.tensor.nnet.abstract_conv.AbstractConv2d', class {}); + this.registerType('theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs', class {}); + this.registerType('theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights', class {}); + this.registerType('theano.tensor.nnet.corr.CorrMM', class {}); + this.registerType('theano.tensor.nnet.corr.CorrMM_gradInputs', class {}); + this.registerType('theano.tensor.nnet.corr.CorrMM_gradWeights', class {}); + this.registerType('theano.tensor.nnet.nnet.CrossentropyCategorical1Hot', class {}); + this.registerType('theano.tensor.nnet.nnet.CrossentropyCategorical1HotGrad', class {}); + this.registerType('theano.tensor.nnet.nnet.CrossentropySoftmax1HotWithBiasDx', class {}); + this.registerType('theano.tensor.nnet.nnet.CrossentropySoftmaxArgmax1HotWithBias', class {}); + this.registerType('theano.tensor.nnet.nnet.Softmax', class {}); + this.registerType('theano.tensor.nnet.nnet.SoftmaxGrad', class {}); + this.registerType('theano.tensor.nnet.nnet.SoftmaxWithBias', class {}); + this.registerType('theano.tensor.opt.MakeVector', class {}); + this.registerType('theano.tensor.opt.ShapeFeature', class {}); + this.registerType('theano.tensor.sharedvar.TensorSharedVariable', class {}); + this.registerType('theano.tensor.signal.pool.MaxPoolGrad', class {}); + this.registerType('theano.tensor.signal.pool.Pool', class {}); + this.registerType('theano.tensor.subtensor.Subtensor', class {}); + this.registerType('theano.tensor.type.TensorType', class {}); + this.registerType('theano.tensor.var.TensorConstant', class {}); + this.registerType('theano.tensor.var.TensorConstantSignature', class {}); + this.registerType('theano.tensor.var.TensorVariable', class {}); this.registerType('thinc.describe.Biases', class { __setstate__(state) { Object.assign(this, state); @@ -3159,31 +3256,32 @@ python.Execution = class { break; } case 'import': { - for (const module of statement.modules) { - const moduleName = python.Utility.target(module.name); - const globals = this.package(moduleName); - if (module.as) { - context.set(module.as, globals); + for (const alias of statement.names) { + const module = this.package(alias.name); + if (alias.asname) { + context.set(alias.asname, module); + } + else { + context.setx(alias.name, module); } } break; } case 'import_from': { let module = null; - let moduleName = python.Utility.target(statement.module); if (statement.level > 0) { let paths = context.getx('__file__').split('/'); paths = paths.slice(0, paths.length - statement.level); - paths.push(moduleName.replace('.', '/')); - moduleName = paths.join('/'); - module = this.package(moduleName); + paths.push(statement.module.replace('.', '/')); + const name = paths.join('/'); + module = this.package(name); } else { - module = this._package(moduleName, context); + module = this._package(statement.module, context); } for (const entry of statement.names) { - const name = entry.name.value; - const asname = entry.asname ? entry.asname.value : null; + const name = entry.name; + const asname = entry.asname ? entry.asname : null; context.set(asname ? asname : name, module[name]); } break; diff --git a/source/pytorch.js b/source/pytorch.js index 12b4c28f37..05677ae076 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -156,12 +156,12 @@ pytorch.Graph = class { break; } case 'module': { - this._type = (graph.obj.__module__ && graph.obj.__name__) ? (graph.obj.__module__ + '.' + graph.obj.__name__) : ''; - this._loadModule(metadata, graph.obj, [], []); + this._type = (graph.data.__module__ && graph.data.__name__) ? (graph.data.__module__ + '.' + graph.data.__name__) : ''; + this._loadModule(metadata, graph.data, [], []); break; } case 'weights': { - for (const state_group of graph.layers) { + for (const state_group of graph.data) { const attributes = state_group.attributes || []; const inputs = state_group.states.map((parameter) => { return new pytorch.Parameter(parameter.name, true, @@ -2985,13 +2985,12 @@ pytorch.Container.Zip.Package = class extends pytorch.Container.Zip { } } execution.registerFunction('torch.jit._script.unpackage_script_module', function(script_module_id) { - // torch.jit._script.RecursiveScriptModule - return script_module_id; + return "torch.jit._script.RecursiveScriptModule('" + script_module_id + "')"; }); const unpickler = python.Unpickler.open(stream); const root = unpickler.load((name, args) => execution.invoke(name, args), persistent_load); - if (root.model) { - const location = { + /* if (root.model) { + const location = {6 model: '.data/ts_code/' + root.model + '/data.pkl', code: '.data/ts_code/code/', data: '.data/', @@ -2999,17 +2998,12 @@ pytorch.Container.Zip.Package = class extends pytorch.Container.Zip { const graph = new pytorch.Container.Zip.Pickle.Script(this._entries, execution, location, name); this._graphs.push(graph); } - else { - const obj = pytorch.Utility.findModule(root); - if (Array.isArray(obj) && obj.length === 1) { - obj[0].type = 'module'; - obj[0].name = obj[0].name || name; - this._graphs.push(obj[0]); - } - else { - throw new pytorch.Error('Unsupported packaged model.'); - } - } + else { */ + this._graphs.push({ + name: name, + type: 'module', + data: root + }); } } return this._graphs; @@ -3915,11 +3909,11 @@ pytorch.Utility = class { } if (obj) { if (obj._modules) { - return [ { name: '', obj: obj } ]; + return [ { name: '', data: obj } ]; } const objKeys = Object.keys(obj).filter((key) => obj[key] && obj[key]._modules); if (objKeys.length > 1) { - return objKeys.map((key) => { return { name: key, obj: obj[key] }; }); + return objKeys.map((key) => { return { name: key, data: obj[key] }; }); } } } @@ -3967,7 +3961,7 @@ pytorch.Utility = class { const argument = { id: '', value: obj }; const parameter = { name: 'value', arguments: [ argument ] }; layers.push({ states: [ parameter ] }); - return [ { layers: layers } ]; + return [ { data: layers } ]; } return null; } @@ -3989,7 +3983,7 @@ pytorch.Utility = class { } } layers.push(layer); - return [ { layers: layers } ]; + return [ { data: layers } ]; } if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) { const layers = []; @@ -4011,7 +4005,7 @@ pytorch.Utility = class { } layers.push(layer); } - return [ { layers: layers } ]; + return [ { data: layers } ]; } } return null; @@ -4200,7 +4194,7 @@ pytorch.Utility = class { } graphs.push({ name: graph_key, - layers: layers.values() + data: layers.values() }); } return graphs; diff --git a/test/models.json b/test/models.json index 5d26b5f531..50cbcb8862 100644 --- a/test/models.json +++ b/test/models.json @@ -4934,7 +4934,7 @@ "type": "pytorch", "target": "v3_1_ru.pt", "source": "https://github.com/lutzroeder/netron/files/9075630/v3_1_ru.pt.zip[v3_1_ru.pt]", - "error": "Found non-callable @@iterator in 'v3_1_ru.pt'.", + "format": "PyTorch Package v1.9", "link": "https://github.com/lutzroeder/netron/issues/928" }, {