Skip to content

Commit

Permalink
Update PyTorch Package experiment (#928)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jul 9, 2022
1 parent c564f46 commit c0e14ff
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 57 deletions.
162 changes: 130 additions & 32 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -138,57 +138,54 @@ python.Parser = class {

node = this._eat('id', 'global');
if (node) {
node.variable = [];
node.names = [];
do {
node.variable.push(this._parseName());
node.names.push(this._parseName(true).value);
}
while (this._tokenizer.eat(','));
return node;
}
node = this._eat('id', 'nonlocal');
if (node) {
node.variable = [];
node.names = [];
do {
node.variable.push(this._parseName());
node.names.push(this._parseName(true).value);
}
while (this._tokenizer.eat(','));
return node;
}
node = this._eat('id', 'import');
if (node) {
node.modules = [];
node.names = [];
do {
const module = this._node('module');
module.name = this._parseExpression(-1, [], false);
const alias = this._node('alias');
alias.name = this._parseDottedName();
if (this._tokenizer.eat('id', 'as')) {
module.as = this._parseExpression(-1, [], false);
alias.asname = this._parseName(true).value;
}
node.modules.push(module);
node.names.push(alias);
}
while (this._tokenizer.eat(','));
return node;
}
node = this._eat('id', 'from');
if (node) {
node.type = 'import_from';
node.level = 0;
const dots = this._tokenizer.peek();
if (dots && Array.from(dots.type).every((c) => c == '.')) {
this._eat(dots.type);
node.level = Array.from(dots.type).length;
node.module = this._parseExpression();
}
else {
node.level = 0;
node.module = this._parseExpression();
}
node.module = this._parseDottedName();
this._tokenizer.expect('id', 'import');
node.names = [];
const close = this._tokenizer.eat('(');
do {
const alias = this._node('alias');
alias.name = this._parseName();
alias.name = this._parseName(true).value;
if (this._tokenizer.eat('id', 'as')) {
alias.asname = this._parseName();
alias.asname = this._parseName(true).value;
}
node.names.push(alias);
}
Expand All @@ -203,13 +200,13 @@ python.Parser = class {

node = this._eat('id', 'class');
if (node) {
node.name = this._parseName().value;
node.name = this._parseName(true).value;
if (decorator_list) {
node.decorator_list = Array.from(decorator_list);
decorator_list = null;
}
if (this._tokenizer.peek().value === '(') {
node.base = this._parseArguments();
node.bases = this._parseArguments();
}
this._tokenizer.expect(':');
node.body = this._parseSuite();
Expand All @@ -229,7 +226,7 @@ python.Parser = class {
if (async) {
node.async = async;
}
node.name = this._parseName().value;
node.name = this._parseName(true).value;
if (decorator_list) {
node.decorator_list = Array.from(decorator_list);
decorator_list = null;
Expand Down Expand Up @@ -821,15 +818,27 @@ python.Parser = class {
return node;
}

_parseName() {
_parseName(required) {
const token = this._tokenizer.peek();
if (token.type == 'id' && !token.keyword) {
this._tokenizer.read();
return token;
}
if (required) {
throw new python.Error("Invalid syntax" + this._tokenizer.location());
}
return null;
}

_parseDottedName() {
const list = [];
do {
list.push(this._parseName(true).value);
}
while (this._tokenizer.eat('.'));
return list.join('.');
}

_parseLiteral() {
const token = this._tokenizer.peek();
if (token.type == 'string' || token.type == 'number' || token.type == 'boolean') {
Expand Down Expand Up @@ -1942,6 +1951,10 @@ python.Execution = class {
}
});
this.registerType('keras.engine.sequential.Sequential', class {});
this.registerType('lasagne.layers.conv.Conv2DLayer', class {});
this.registerType('lasagne.layers.dense.DenseLayer', class {});
this.registerType('lasagne.layers.input.InputLayer', class {});
this.registerType('lasagne.layers.pool.MaxPool2DLayer', class {});
this.registerType('lightgbm.sklearn.LGBMRegressor', class {});
this.registerType('lightgbm.sklearn.LGBMClassifier', class {});
this.registerType('lightgbm.basic.Booster', class {
Expand Down Expand Up @@ -2355,6 +2368,90 @@ python.Execution = class {
Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
}
});
this.registerType('theano.compile.function_module._constructor_Function', class {});
this.registerType('theano.compile.function_module._constructor_FunctionMaker', class {});
this.registerType('theano.compile.function_module.Supervisor', class {});
this.registerType('theano.compile.io.In', class {});
this.registerType('theano.compile.io.SymbolicOutput', class {});
this.registerType('theano.compile.mode.Mode', class {});
this.registerType('theano.compile.ops.OutputGuard', class {});
this.registerType('theano.compile.ops.Shape', class {});
this.registerType('theano.compile.ops.Shape_i', class {});
this.registerType('theano.gof.destroyhandler.DestroyHandler', class {});
this.registerType('theano.gof.fg.FunctionGraph', class {});
this.registerType('theano.gof.graph.Apply', class {});
this.registerType('theano.gof.link.Container', class {});
this.registerType('theano.gof.opt._metadict', class {});
this.registerType('theano.gof.opt.ChangeTracker', class {});
this.registerType('theano.gof.opt.MergeFeature', class {});
this.registerType('theano.gof.optdb.Query', class {});
this.registerType('theano.gof.toolbox.PreserveVariableAttributes', class {});
this.registerType('theano.gof.toolbox.ReplaceValidate', class {});
this.registerType('theano.gof.utils.scratchpad', class {});
this.registerType('theano.misc.ordered_set.Link', class {});
this.registerType('theano.misc.ordered_set.OrderedSet', class {});
this.registerType('theano.sandbox.cuda.basic_ops.HostFromGpu', class {});
this.registerType('theano.sandbox.cuda.type.CudaNdarray_unpickler', class {});
this.registerType('theano.sandbox.cuda.type.CudaNdarrayType', class {});
this.registerType('theano.sandbox.cuda.var.CudaNdarraySharedVariable', class {});
this.registerType('theano.scalar.basic.Abs', class {});
this.registerType('theano.scalar.basic.Add', class {});
this.registerType('theano.scalar.basic.Cast', class {});
this.registerType('theano.scalar.basic.Composite', class {});
this.registerType('theano.scalar.basic.EQ', class {});
this.registerType('theano.scalar.basic.GE', class {});
this.registerType('theano.scalar.basic.Identity', class {});
this.registerType('theano.scalar.basic.IntDiv', class {});
this.registerType('theano.scalar.basic.Inv', class {});
this.registerType('theano.scalar.basic.LE', class {});
this.registerType('theano.scalar.basic.LT', class {});
this.registerType('theano.scalar.basic.Mul', class {});
this.registerType('theano.scalar.basic.Neg', class {});
this.registerType('theano.scalar.basic.Scalar', class {});
this.registerType('theano.scalar.basic.ScalarConstant', class {});
this.registerType('theano.scalar.basic.ScalarVariable', class {});
this.registerType('theano.scalar.basic.Second', class {});
this.registerType('theano.scalar.basic.Sgn', class {});
this.registerType('theano.scalar.basic.specific_out', class {});
this.registerType('theano.scalar.basic.Sub', class {});
this.registerType('theano.scalar.basic.Switch', class {});
this.registerType('theano.scalar.basic.Tanh', class {});
this.registerType('theano.scalar.basic.transfer_type', class {});
this.registerType('theano.scalar.basic.TrueDiv', class {});
this.registerType('theano.tensor.basic.Alloc', class {});
this.registerType('theano.tensor.basic.Dot', class {});
this.registerType('theano.tensor.basic.MaxAndArgmax', class {});
this.registerType('theano.tensor.basic.Reshape', class {});
this.registerType('theano.tensor.basic.ScalarFromTensor', class {});
this.registerType('theano.tensor.blas.Dot22', class {});
this.registerType('theano.tensor.blas.Dot22Scalar', class {});
this.registerType('theano.tensor.blas.Gemm', class {});
this.registerType('theano.tensor.elemwise.DimShuffle', class {});
this.registerType('theano.tensor.elemwise.Elemwise', class {});
this.registerType('theano.tensor.elemwise.Sum', class {});
this.registerType('theano.tensor.nnet.abstract_conv.AbstractConv2d', class {});
this.registerType('theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs', class {});
this.registerType('theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights', class {});
this.registerType('theano.tensor.nnet.corr.CorrMM', class {});
this.registerType('theano.tensor.nnet.corr.CorrMM_gradInputs', class {});
this.registerType('theano.tensor.nnet.corr.CorrMM_gradWeights', class {});
this.registerType('theano.tensor.nnet.nnet.CrossentropyCategorical1Hot', class {});
this.registerType('theano.tensor.nnet.nnet.CrossentropyCategorical1HotGrad', class {});
this.registerType('theano.tensor.nnet.nnet.CrossentropySoftmax1HotWithBiasDx', class {});
this.registerType('theano.tensor.nnet.nnet.CrossentropySoftmaxArgmax1HotWithBias', class {});
this.registerType('theano.tensor.nnet.nnet.Softmax', class {});
this.registerType('theano.tensor.nnet.nnet.SoftmaxGrad', class {});
this.registerType('theano.tensor.nnet.nnet.SoftmaxWithBias', class {});
this.registerType('theano.tensor.opt.MakeVector', class {});
this.registerType('theano.tensor.opt.ShapeFeature', class {});
this.registerType('theano.tensor.sharedvar.TensorSharedVariable', class {});
this.registerType('theano.tensor.signal.pool.MaxPoolGrad', class {});
this.registerType('theano.tensor.signal.pool.Pool', class {});
this.registerType('theano.tensor.subtensor.Subtensor', class {});
this.registerType('theano.tensor.type.TensorType', class {});
this.registerType('theano.tensor.var.TensorConstant', class {});
this.registerType('theano.tensor.var.TensorConstantSignature', class {});
this.registerType('theano.tensor.var.TensorVariable', class {});
this.registerType('thinc.describe.Biases', class {
__setstate__(state) {
Object.assign(this, state);
Expand Down Expand Up @@ -3159,31 +3256,32 @@ python.Execution = class {
break;
}
case 'import': {
for (const module of statement.modules) {
const moduleName = python.Utility.target(module.name);
const globals = this.package(moduleName);
if (module.as) {
context.set(module.as, globals);
for (const alias of statement.names) {
const module = this.package(alias.name);
if (alias.asname) {
context.set(alias.asname, module);
}
else {
context.setx(alias.name, module);
}
}
break;
}
case 'import_from': {
let module = null;
let moduleName = python.Utility.target(statement.module);
if (statement.level > 0) {
let paths = context.getx('__file__').split('/');
paths = paths.slice(0, paths.length - statement.level);
paths.push(moduleName.replace('.', '/'));
moduleName = paths.join('/');
module = this.package(moduleName);
paths.push(statement.module.replace('.', '/'));
const name = paths.join('/');
module = this.package(name);
}
else {
module = this._package(moduleName, context);
module = this._package(statement.module, context);
}
for (const entry of statement.names) {
const name = entry.name.value;
const asname = entry.asname ? entry.asname.value : null;
const name = entry.name;
const asname = entry.asname ? entry.asname : null;
context.set(asname ? asname : name, module[name]);
}
break;
Expand Down
42 changes: 18 additions & 24 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ pytorch.Graph = class {
break;
}
case 'module': {
this._type = (graph.obj.__module__ && graph.obj.__name__) ? (graph.obj.__module__ + '.' + graph.obj.__name__) : '';
this._loadModule(metadata, graph.obj, [], []);
this._type = (graph.data.__module__ && graph.data.__name__) ? (graph.data.__module__ + '.' + graph.data.__name__) : '';
this._loadModule(metadata, graph.data, [], []);
break;
}
case 'weights': {
for (const state_group of graph.layers) {
for (const state_group of graph.data) {
const attributes = state_group.attributes || [];
const inputs = state_group.states.map((parameter) => {
return new pytorch.Parameter(parameter.name, true,
Expand Down Expand Up @@ -2985,31 +2985,25 @@ pytorch.Container.Zip.Package = class extends pytorch.Container.Zip {
}
}
execution.registerFunction('torch.jit._script.unpackage_script_module', function(script_module_id) {
// torch.jit._script.RecursiveScriptModule
return script_module_id;
return "torch.jit._script.RecursiveScriptModule('" + script_module_id + "')";
});
const unpickler = python.Unpickler.open(stream);
const root = unpickler.load((name, args) => execution.invoke(name, args), persistent_load);
if (root.model) {
const location = {
/* if (root.model) {
const location = {6
model: '.data/ts_code/' + root.model + '/data.pkl',
code: '.data/ts_code/code/',
data: '.data/',
};
const graph = new pytorch.Container.Zip.Pickle.Script(this._entries, execution, location, name);
this._graphs.push(graph);
}
else {
const obj = pytorch.Utility.findModule(root);
if (Array.isArray(obj) && obj.length === 1) {
obj[0].type = 'module';
obj[0].name = obj[0].name || name;
this._graphs.push(obj[0]);
}
else {
throw new pytorch.Error('Unsupported packaged model.');
}
}
else { */
this._graphs.push({
name: name,
type: 'module',
data: root
});
}
}
return this._graphs;
Expand Down Expand Up @@ -3915,11 +3909,11 @@ pytorch.Utility = class {
}
if (obj) {
if (obj._modules) {
return [ { name: '', obj: obj } ];
return [ { name: '', data: obj } ];
}
const objKeys = Object.keys(obj).filter((key) => obj[key] && obj[key]._modules);
if (objKeys.length > 1) {
return objKeys.map((key) => { return { name: key, obj: obj[key] }; });
return objKeys.map((key) => { return { name: key, data: obj[key] }; });
}
}
}
Expand Down Expand Up @@ -3967,7 +3961,7 @@ pytorch.Utility = class {
const argument = { id: '', value: obj };
const parameter = { name: 'value', arguments: [ argument ] };
layers.push({ states: [ parameter ] });
return [ { layers: layers } ];
return [ { data: layers } ];
}
return null;
}
Expand All @@ -3989,7 +3983,7 @@ pytorch.Utility = class {
}
}
layers.push(layer);
return [ { layers: layers } ];
return [ { data: layers } ];
}
if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) {
const layers = [];
Expand All @@ -4011,7 +4005,7 @@ pytorch.Utility = class {
}
layers.push(layer);
}
return [ { layers: layers } ];
return [ { data: layers } ];
}
}
return null;
Expand Down Expand Up @@ -4200,7 +4194,7 @@ pytorch.Utility = class {
}
graphs.push({
name: graph_key,
layers: layers.values()
data: layers.values()
});
}
return graphs;
Expand Down
2 changes: 1 addition & 1 deletion test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4934,7 +4934,7 @@
"type": "pytorch",
"target": "v3_1_ru.pt",
"source": "https://github.com/lutzroeder/netron/files/9075630/v3_1_ru.pt.zip[v3_1_ru.pt]",
"error": "Found non-callable @@iterator in 'v3_1_ru.pt'.",
"format": "PyTorch Package v1.9",
"link": "https://github.com/lutzroeder/netron/issues/928"
},
{
Expand Down

0 comments on commit c0e14ff

Please sign in to comment.