From eceaf0c211bb26282527be6a7ce7a320ac9a3b8f Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Wed, 8 Dec 2021 12:38:02 -0500 Subject: [PATCH] Add TorchScript test file (#842) (#851) --- source/python.js | 2 +- source/pytorch-metadata.json | 36 +++++++++++++++++----- source/pytorch.js | 60 +++++++++++++++++++++++++++++++++--- test/models.json | 11 +++++-- 4 files changed, 94 insertions(+), 15 deletions(-) diff --git a/source/python.js b/source/python.js index 2a3d3648541..4d0c5f7c5b1 100644 --- a/source/python.js +++ b/source/python.js @@ -2800,7 +2800,7 @@ python.Execution = class { break; } case 'var': { - context.set(statement.name, undefined); + context.set(statement.name, statement.initializer ? this.expression(statement.initializer, context) : undefined); break; } case '=': { diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json index 097bcb4648d..de0cd861709 100755 --- a/source/pytorch-metadata.json +++ b/source/pytorch-metadata.json @@ -2720,13 +2720,11 @@ ] }, { - "name": "torch.stack", + "name": "torch.stack:", "category": "Tensor", - "attributes": [ - { "name": "dim", "type": "int64", "default": 0 } - ], "inputs": [ - { "name": "inputs", "type": "Tensor[]" } + { "name": "inputs", "type": "Tensor[]" }, + { "name": "dim", "type": "int64", "default": 0 } ], "outputs": [ { "name": "output", "type": "Tensor" } @@ -3301,14 +3299,22 @@ }, { "name": "torch.mul:Scalar", - "attributes": [ + "inputs": [ + { "name": "input", "type": "Tensor" }, { "name": "other", "type": "Scalar" } ], + "outputs": [ + { "name": "output", "type": "Tensor" } + ] + }, + { + "name": "torch.mul:ScalarT", "inputs": [ - { "name": "input", "type": "Tensor" } + { "name": "input", "type": "Tensor[]" }, + { "name": "other", "type": "Scalar" } ], "outputs": [ - { "name": "output", "type": "Tensor" } + { "name": "output", "type": "Tensor[]" } ] }, { @@ -6061,6 +6067,20 @@ { "name": "output3", "type": "Tensor" } ] }, + { + "name": "torch._unique2:", + "inputs": [ + { "name": "self", "type": "Tensor" }, + { "name": "sorted", "type": "boolean", "default": false }, + { "name": "return_inverse", "type": "boolean", "default": false }, + { "name": "return_counts", "type": "boolean", "default": false } + ], + "outputs": [ + { "name": "output", "type": "Tensor" }, + { "name": "output", "type": "Tensor" }, + { "name": "output", "type": "Tensor" } + ] + }, { "name": "torch._weight_norm", "inputs": [ diff --git a/source/pytorch.js b/source/pytorch.js index c6b5e06e2e4..c75cced92fd 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1399,8 +1399,13 @@ pytorch.Execution = class extends python.Execution { throw new pytorch.Error(message); }); this.registerFunction('range', function(start, stop, step) { - if (start !== undefined && Number.isInteger(start) && stop === undefined && step === undefined) { - return Array(start).keys(); + if (stop === undefined && step === undefined) { + if (Number.isInteger(start)) { + return Array(start).keys(); + } + if (isNaN(start)) { + return []; + } } throw new pytorch.Error('Unsupported function range(' + JSON.stringify(start) + ', ' + JSON.stringify(stop) + ', ' + JSON.stringify(step) + ')'); }); @@ -1604,6 +1609,10 @@ pytorch.Execution = class extends python.Execution { } throw new pytorch.Error("Unknown 'torch.ge' expression type."); }); + this.registerFunction('torch.is_floating_point', function(tensor) { + const type = tensor.dtype.scalar_type(); + return (type === 5 || type === 6 || type === 7); + }); this.registerFunction('torch.jit._pickle.build_boollist', function(data) { return data; }); @@ -1743,12 +1752,15 @@ pytorch.Execution = class extends python.Execution { return []; }); this.registerFunction('torch.slice', function(l, start, end, step) { + if (!Array.isArray(l)) { + throw new pytorch.Error('Slicing expected array'); + } step = step || 1; if (step !== 1) { throw new pytorch.Error('Slicing only supports step=1'); } start = Math.max(0, start >= 0 ? start : l.length + start); - end = Math.min(l.length, end); + end = Math.min(l.length, end || Number.MAX_SAFE_INTEGER); return l.slice(start, end); }); this.registerFunction('torch.sub', function(left, right) { @@ -1801,6 +1813,10 @@ pytorch.Execution = class extends python.Execution { constructor(size, dtype) { this._size = size; this._dtype = dtype; + this._device = null; + } + get device() { + return null; } get dtype() { return this._dtype; @@ -1902,6 +1918,9 @@ pytorch.Execution = class extends python.Execution { this.registerType('torch.Tensor', class { constructor() { } + get device() { + return this.storage().device; + } get dtype() { return this.storage().dtype; } @@ -2946,6 +2965,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution { case 'torch.quantize_per_tensor': case 'torch.relu_': case 'torch.hardtanh_': + case 'torch.upsample_bilinear2d': case 'torch.unsqueeze': case 'ops.prepacked.conv2d_clamp_run': { parameter.resize_([ NaN, NaN, NaN, NaN ]); @@ -2973,6 +2993,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution { } case 'torch.mean': case 'torch.mul': + case 'torch.div': case 'torch.batch_norm': case 'torch.gelu': case 'torch.relu': @@ -2983,7 +3004,8 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution { } break; } - case 'torch.add': { + case 'torch.add': + case 'torch.sub': { const input = this.expression(args[0], context); if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { parameter.resize_(input.size()); @@ -2996,6 +3018,13 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution { } break; } + case 'torch.select': { + const input = this.expression(args[0], context); + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + parameter.resize_(Array(input.size().length - 1).fill(NaN)); + } + break; + } case 'torch.layer_norm': { const input = this.expression(args[0], context); const normalized_shape = this.expression(args[1], context); @@ -3176,6 +3205,29 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution { tensor.resize_(Array(number).fill(NaN)); } } + // val = torch.slice(torch.size(img), -2) + // if torch.eq(torch.len(val), 2): + // pass + // else: + // ops.prim.RaiseException("AssertionError: ") + if (assign.type === '=' && + condition.type === 'if' && + pytorch.Utility.isCall(assign.expression, 'torch.slice', 2) && + pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.size', 1) && + pytorch.Utility.isCall(condition.condition, 'torch.eq', 2) && + pytorch.Utility.isCall(condition.condition.arguments[0], 'torch.len', 1) && + pytorch.Utility.isEqual(condition.condition.arguments[0].arguments[0], assign.target) && + condition.else.statements.length == 1 && + pytorch.Utility.isCall(condition.else.statements[0], 'ops.prim.RaiseException', 1)) { + const tensor = this.expression(assign.expression.arguments[0].arguments[0], context); + if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) { + const start = this.expression(assign.expression.arguments[1], context); + const value = this.expression(condition.condition.arguments[1], context); + if (Number.isInteger(start) && start < 0 && Number.isInteger(value) && value > 0) { + tensor.resize_(Array(value - start).fill(NaN)); + } + } + } } if (statements.length > 1) { const size = statements[0]; diff --git a/test/models.json b/test/models.json index 7b2cd363a18..3918fba6396 100644 --- a/test/models.json +++ b/test/models.json @@ -4326,8 +4326,8 @@ { "type": "pytorch", "target": "fasterrcnn_resnet50_fpn.pt", - "source": "https://github.com/lutzroeder/netron/files/6040364/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]", - "error": "Unsupported function 'torch.full' in 'fasterrcnn_resnet50_fpn.pt'.", + "source": "https://github.com/lutzroeder/netron/files/7677467/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]", + "error": "Unknown torch.add expression type in 'fasterrcnn_resnet50_fpn.pt'.", "link": "https://github.com/lutzroeder/netron/issues/689" }, { @@ -4859,6 +4859,13 @@ "source": "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth", "format": "PyTorch v0.1.1" }, + { + "type": "pytorch", + "target": "ssdlite320_mobilenet_v3_large.pt", + "source": "https://github.com/lutzroeder/netron/files/7677468/ssdlite320_mobilenet_v3_large.pt.zip[ssdlite320_mobilenet_v3_large.pt]", + "format": "TorchScript v1.6", + "link": "https://github.com/lutzroeder/netron/issues/842" + }, { "type": "pytorch", "target": "superpoint_v1.pth",