diff --git a/source/python.js b/source/python.js index 042af0c494..15d03a9dfc 100644 --- a/source/python.js +++ b/source/python.js @@ -6755,10 +6755,10 @@ python.Execution = class { return this.kind() === rhs.kind(); } isSubtypeOf(rhs) { - if (rhs.kind() === 'OptionalType') { + if (rhs.kind() === 'OptionalType' && this.kind() !== 'OptionalType') { return rhs.getElementType().equals(this); } - return false; + return this.equals(rhs); } str() { if (this._kind === 'VarType' && this._annotation_str) { @@ -6814,6 +6814,9 @@ python.Execution = class { findAttribute(name) { return this._attributes.get(name); } + getAttribute(name) { + return this._attributes.get(name); + } hasConstant(/* name */) { } methods() { @@ -6848,7 +6851,7 @@ python.Execution = class { super('ListType'); this._elem = elem; } - static get(elem) { + static create(elem) { return new torch.ListType(elem); } getElementType() { @@ -6981,6 +6984,9 @@ python.Execution = class { equals(rhs) { return this.kind() === rhs.kind(); } + isSubtypeOf(/* rhs */) { + return true; + } str() { return 'NoneType'; } @@ -7144,7 +7150,7 @@ python.Execution = class { this._key = key; this._value = value; } - static get(key, value) { + static create(key, value) { return new torch.DictType(key, value); } getKeyType() { @@ -7415,7 +7421,7 @@ python.Execution = class { const value_type = this.parseType().first; L.expect(')'); alias_info = this.parseAliasAnnotation(); - real_value = torch.DictType.get(key_type, value_type); + real_value = torch.DictType.create(key_type, value_type); fake_value = real_value; } else if (L.eat('Union')) { L.next(); @@ -7454,8 +7460,8 @@ python.Execution = class { while (true) { if (L.kind === '[]') { L.expect('[]'); - fake_value = torch.ListType.get(fake_value); - real_value = torch.ListType.get(real_value); + fake_value = torch.ListType.create(fake_value); + real_value = torch.ListType.create(real_value); let container = this.parseAliasAnnotation(); if (alias_info) { if (!container) { @@ -7524,8 +7530,8 @@ python.Execution = class { L.whitespace(0); let N = null; if (L.eat('[')) { - fake_type = torch.ListType.get(fake_type); - real_type = torch.ListType.get(real_type); + fake_type = torch.ListType.create(fake_type); + real_type = torch.ListType.create(real_type); if (L.kind === '#') { N = Number(L.value); L.next(); @@ -7932,8 +7938,28 @@ python.Execution = class { this._block = new torch.Block(this); this._insert_before = this.return_node(); } - create(kind) { - return new torch.Node(this, kind); + create(kind, ...args) { + let inputs = null; + let num_outputs = 1; + if (args.length === 2 && Array.isArray(args[0]) && typeof args[1] === 'number') { + [inputs, num_outputs] = args; + } else if (args.length === 1) { + if (typeof args[0] === 'number') { + [num_outputs] = args; + } else if (Array.isArray(args[0])) { + [inputs] = args; + } + } + const n = new torch.Node(this, kind); + if (inputs) { + for (const i of inputs) { + n.addInput(i); + } + } + for (let i = 0; i < num_outputs; i++) { + n.addOutput(); + } + return n; } inputs() { return this._block.inputs(); @@ -7954,7 +7980,123 @@ python.Execution = class { return this._block.addInput(name); } insertNode(node) { - node.insertBefore(this._insert_before); + return node.insertBefore(this._insert_before); + } + insertConstant(val) { + const n = this.create('prim::Constant'); + this.insertNode(n); + let type = null; + if (val === null) { + n.ival_('value', val); + type = torch.NoneType.get(); + } else if (typeof val === 'string') { + n.s_('value', val); + type = torch.StringType.get(); + } else if (Array.isArray(val) && val.every((item) => typeof item === 'string')) { + n.ss_('value', val); + type = torch.ListType.create(torch.StringType.get()); + } else if (typeof val === 'boolean') { + // return value; + n.i_('value', val === true ? 1 : 0); + type = torch.BoolType.get(); + } else if (Number.isInteger(val)) { + n.i_('value', val); + type = torch.IntType.get(); + } else if (typeof val === 'number') { + // return value; + n.f_('value', val); + type = torch.FloatType.get(); + } else { + throw new python.Error(`Unsupported value type '${typeof value}'.`); + } + if (type) { + n.output().setType(type); + } + return n.output(); + } + createList(contained_type, values) { + const n = this.create('prim::ListConstruct', values); + for (const v of values) { + if (!v.type().isSubtypeOf(contained_type)) { + throw new python.Error('Invalid list item.'); + } + } + n.output().setType(torch.ListType.create(contained_type)); + return n; + } + createDict(key_type, value_type, keys, values) { + if (keys.length !== values.length) { + throw new python.Error('Invalid dictionary size.'); + } + const n = this.create('prim::DictConstruct'); + const length = keys.length; + for (let i = 0; i < length; i++) { + if (!keys[i].type().isSubtypeOf(key_type)) { + throw new python.Error('Invalid key.'); + } + if (!values[i].type().isSubtypeOf(value_type)) { + throw new python.Error('Invalid value.'); + } + n.addInput(keys[i]); + n.addInput(values[i]); + } + n.output().setType(torch.DictType.create(key_type, value_type)); + return n; + } + createObject(type) { + const node = this.create('prim::CreateObject'); + node.output().setType(type); + return node; + } + createIsInstance(v, types) { + const n = this.create('prim::isinstance', [v], 1); + n.tys_('types', types); + n.output().setType(torch.BoolType.get()); + return n; + } + createSetAttr(obj, field, newValue) { + const n = this.create('prim::SetAttr', [obj, newValue], 0); + n.s_('name', field); + return n; + } + createGetAttr(obj, field) { + const n = this.create('prim::GetAttr', [obj]); + n.s_('name', field); + const classType = obj.type(); + const outputType = classType.getAttribute(field); + n.output().setType(outputType); + n.output().setDebugName(/^[0-9]+$/.test(field) ? `_${field}` : field); + return n; + } + insertUncheckedCast(v, type) { + const n = this.insertNode(this.create('prim::unchecked_cast', [v])); + n.output().setType(type); + return n.output(); + } + insertToList(v, type) { + let dim = 0; + let ptr = type; + while (ptr instanceof torch.ListType) { + ptr = ptr.getElementType(); + dim += 1; + } + let elem_ty = 0; + if (ptr instanceof torch.IntType) { + elem_ty = 0; + } else if (ptr instanceof torch.FloatType) { + elem_ty = 1; + } else if (ptr instanceof torch.BoolType) { + elem_ty = 2; + } else if (ptr instanceof torch.ComplexType) { + elem_ty = 3; + } else { + throw new python.Error(`Unsupported list type '${type.kind()}'.`); + } + const dim_val = this.insertConstant(dim); + const elem_ty_val = this.insertConstant(elem_ty); + const n = this.insertNode(this.create('prim::tolist', [v, dim_val, elem_ty_val])); + n.output().setType(type); + return n.output(); } insertPoint() { return this._insert_before; @@ -7994,8 +8136,8 @@ python.Execution = class { this.registerType('torch.Block', class { constructor(graph) { this._graph = graph; - this._input = graph.create('prim::Param'); - this._output = graph.create('prim::Return'); + this._input = graph.create('prim::Param', 0); + this._output = graph.create('prim::Return', 0); this._input.next = this._output; this._input.prev = this._output; this._output.next = this._input; @@ -8091,6 +8233,12 @@ python.Execution = class { outputs() { return this._outputs; } + output() { + if (this._outputs.length !== 1) { + throw new python.Error('Node has multiple outputs.'); + } + return this._outputs[0]; + } blocks() { return this._blocks; } @@ -8214,6 +8362,12 @@ python.Execution = class { f(name) { return this._values.get(name)[0]; } + tys_(name, value) { + this._values.set(name, [value, 'tys']); + } + tys(name) { + return this._values.get(name)[0]; + } ival_(name, value) { this._values.set(name, [value, 'ival']); } diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json index 3c18cc047f..de0d60d439 100755 --- a/source/pytorch-metadata.json +++ b/source/pytorch-metadata.json @@ -6228,6 +6228,9 @@ { "name": "prim::shape(Tensor self) -> int[]" }, + { + "name": "prim::tolist(...) -> ..." + }, { "name": "prim::type(Device self) -> str" }, diff --git a/source/pytorch.js b/source/pytorch.js index 8eab62fe9b..9d29d0223a 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -405,6 +405,7 @@ pytorch.Node = class { case 'i': value = node.i(name); type = 'int64'; break; case 'f': value = node.f(name); type = 'float32'; break; case 'ss': value = node.ss(name); type = 'string[]'; break; + case 'tys': value = node.tys(name).map((ty) => pytorch.Utility.toType(ty)); type = 'type[]'; break; case 'ival': value = node.ival(name); break; default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`); } @@ -1570,7 +1571,7 @@ pytorch.Execution = class extends python.Execution { execution.variable(this.serialized_model_tensor); this.serialized_model_tensor.__count__ = (this.serialized_model_tensor.__count__ || 0) + 1; const type = new pytorch.nnapi.Graph(this.serialized_model); - const node = execution.graph.create(type); + const node = execution.graph.create(type, 0); execution.graph.insertNode(node); for (const tensor of inputs) { const value = execution.variable(tensor); @@ -1674,41 +1675,6 @@ pytorch.Execution = class extends python.Execution { return this._graph; } - constant(value) { - const torch = this.torch; - const node = this.graph.create('prim::Constant'); - this.graph.insertNode(node); - let type = null; - if (value === null) { - node.ival_('value', value); - type = torch.NoneType.get(); - } else if (typeof value === 'string') { - node.s_('value', value); - type = torch.StringType.get(); - } else if (Array.isArray(value) && value.every((item) => typeof item === 'string')) { - node.ss_('value', value); - type = torch.ListType.get(torch.StringType.get()); - } else if (typeof value === 'boolean') { - // return value; - node.i_('value', value === true ? 1 : 0); - type = torch.BoolType.get(); - } else if (Number.isInteger(value)) { - node.i_('value', value); - type = torch.IntType.get(); - } else if (typeof value === 'number') { - // return value; - node.f_('value', value); - type = torch.FloatType.get(); - } else { - throw new pytorch.Error(`Unsupported value type '${typeof value}'.`); - } - if (type) { - value = node.addOutput(); - value.setType(type); - } - return value; - } - variable(obj, node) { const torch = this.torch; if (this._values.has(obj)) { @@ -1767,7 +1733,7 @@ pytorch.Execution = class extends python.Execution { const value = this.builtins[expr.id]; const entries = Object.entries(value).map(([name, value]) => { if (Array.isArray(value) && value.length > 0 && value.every((item) => typeof item === 'string')) { - value = this.constant(value); + value = this._graph.insertConstant(value); return [name, value]; } return [name, value]; @@ -1815,7 +1781,7 @@ pytorch.Execution = class extends python.Execution { return super.target(expr, context); } - expression(expr, context) { + expression(expr, context, typehint) { if (!this.trace) { return super.expression(expr, context); } @@ -1824,7 +1790,7 @@ pytorch.Execution = class extends python.Execution { switch (expr.type) { case 'Constant': { if (expr.value === true || expr.value === false) { - return this.constant(expr.value); + return this._graph.insertConstant(expr.value); } break; } @@ -1833,7 +1799,7 @@ pytorch.Execution = class extends python.Execution { if (target instanceof ast.Name) { let value = this.expression(expr.value, context); if (typeof value === 'string' || typeof value === 'boolean' || typeof value === 'number') { - value = this.constant(value); + value = this._graph.insertConstant(value); } else if (typeof value !== 'object' && value !== undefined) { throw new pytorch.Error(`Unsupported assignment value type '${typeof value}'.`); } @@ -1848,7 +1814,7 @@ pytorch.Execution = class extends python.Execution { context.target.pop(); if (target.elts.every((item) => item instanceof ast.Name)) { if (value instanceof torch.Value) { - const node = this._graph.create('prim::TupleUnpack'); + const node = this._graph.create('prim::TupleUnpack', 0); node.setSourceRange(expr.location); this.graph.insertNode(node); node.addInput(value); @@ -1894,8 +1860,10 @@ pytorch.Execution = class extends python.Execution { const func = expr.func; if (func instanceof ast.Name && func.id === 'annotate') { const type = this.type(expr.args[0]); - let value = this.expression(expr.args[1], context); - if (value instanceof torch.Tensor) { + const [, obj] = expr.args; + let value = this.expression(obj, context, type); + if (value instanceof torch.Tensor || + (value instanceof torch.Value && value.type() instanceof torch.TensorType)) { let name = null; if (type instanceof torch.IntType) { name = 'IntImplicit'; @@ -1914,11 +1882,11 @@ pytorch.Execution = class extends python.Execution { const target = new ast.Name('torch'); return this.call(target, name, expr.args.slice(1), context); } - if (value instanceof torch.Value) { - value.setType(type); + if (value instanceof torch.Value && !type.equals(value.type())) { + throw new pytorch.Error('Invalid annotation type hint.'); } if (value === null) { - value = this.constant(value); + value = this._graph.insertConstant(value); value.setType(type); } return value; @@ -1928,29 +1896,33 @@ pytorch.Execution = class extends python.Execution { const node = this._graph.create('prim::Uninitialized'); node.setSourceRange(expr.location); this.graph.insertNode(node); - const value = node.addOutput(); - value.setType(type); - return value; + node.output().setType(type); + return node.output(); } if (func instanceof ast.Name && func.id === 'unchecked_cast') { let value = this.expression(expr.args[1], context); + if (value instanceof torch.Value === false) { // remove + value = this.variable(value); + } const type = this.type(expr.args[0]); - const node = this._graph.create('prim::unchecked_cast'); - this.graph.insertNode(node); - node.addInput(this.variable(value)); - value = node.addOutput(); - value.setType(type); - return value; + return this.graph.insertUncheckedCast(value, type); } if (func instanceof ast.Name && func.id === 'isinstance') { - let value = this.expression(expr.args[1], context); - // const type = this.type(expression.args[0]); - const node = this._graph.create('prim::isinstance'); + const value = this.expression(expr.args[0], context); + let [, types] = expr.args; + if (types instanceof ast.Tuple) { + types = types.elts.map((expr) => this.type(expr)); + } else { + types = [this.type(types)]; + } + const v = this.variable(value); // remove + const node = this._graph.createIsInstance(v, types); this.graph.insertNode(node); - node.addInput(this.variable(value)); - value = node.addOutput(); - value.setType(torch.BoolType.get()); - return value; + return node.output(); + } + if (func.attr === 'tolist' && expr.args.length === 0) { + const target = this.target(func.value, context); + return this.graph.insertToList(target, typehint); } return super.expression(expr, context); } @@ -1965,22 +1937,18 @@ pytorch.Execution = class extends python.Execution { } if (type instanceof torch.ListType) { let index = this.expression(elt, context); - const node = this._graph.create('aten::__getitem__.t'); - this.graph.insertNode(node); - node.addInput(value); if (Number.isInteger(index)) { - index = this.constant(index); + index = this._graph.insertConstant(index); } - node.addInput(index); - const output = node.addOutput(); - output.setType(type.getElementType()); - return output; + const node = this._graph.create('aten::__getitem__.t', [value, index]); + this.graph.insertNode(node); + node.output().setType(type.getElementType()); + return node.output(); } if (type instanceof torch.DictType) { let key = this.expression(elt, context); - const node = this._graph.create('aten::__getitem__.t'); + const node = this._graph.create('aten::__getitem__.t', [value]); this.graph.insertNode(node); - node.addInput(value); if (type.getKeyType() instanceof torch.StringType && typeof key === 'string') { const value = new torch.Value(node); value.value = key; @@ -1991,9 +1959,8 @@ pytorch.Execution = class extends python.Execution { throw new pytorch.Error(`Unsupported dictionary key type.`); } node.addInput(key); - const output = node.addOutput(); - output.setType(type.getValueType()); - return output; + node.output().setType(type.getValueType()); + return node.output(); } if (type instanceof torch.TupleType) { const index = this.expression(elt, context); @@ -2003,14 +1970,13 @@ pytorch.Execution = class extends python.Execution { if (index instanceof torch.Value) { node.addInput(index); } else if (Number.isInteger(index)) { - const value = this.constant(index); + const value = this._graph.insertConstant(index); node.addInput(value); } else { throw new pytorch.Error(`Unsupported tuple index type.`); } - const output = node.addOutput(); - output.setType(type.elements()[index]); - return output; + node.output().setType(type.elements()[index]); + return node.output(); } } } @@ -2020,46 +1986,41 @@ pytorch.Execution = class extends python.Execution { const target = this.target(expr.value, context); const attr = expr.attr; if (target instanceof torch.Value && target.type() instanceof torch.ClassType) { - const type = target.type().findAttribute(attr); - const node = this.graph.create('prim::GetAttr'); + const node = this._graph.createGetAttr(target, attr); this.graph.insertNode(node); - node.s_(attr); - node.addInput(target); - const value = node.addOutput(); - value.setType(type); - return value; + return node.output(); } return target[attr]; } case 'List': { const list = expr.elts.map((item) => this.expression(item, context)); if (/* list.length > 0 && */ list.every((item) => item instanceof torch.Value || pytorch.Utility.isTensor(item) || Number.isInteger(item) || typeof item === 'string' || item === null)) { - const node = this._graph.create('prim::ListConstruct'); - this.graph.insertNode(node); - const output = node.addOutput(); + const values = []; + let item_type = null; for (const item of list) { + let type = null; if (item instanceof torch.Value) { - node.addInput(item); - output.setType(torch.ListType.get(item.type())); - } else if (Number.isInteger(item)) { - const value = this.constant(item); - node.addInput(value); - output.setType(torch.ListType.get(torch.IntType.get())); - } else if (typeof item === 'string') { - const value = this.constant(item); - node.addInput(value); - output.setType(torch.ListType.get(torch.StringType.get())); + values.push(item); + type = item.type(); + } else if (Number.isInteger(item) || typeof item === 'string' || item === null) { + const value = this._graph.insertConstant(item); + values.push(value); + type = value.type(); } else if (pytorch.Utility.isTensor(item)) { const value = this.variable(item, null); - node.addInput(value); - output.setType(torch.ListType.get(torch.TensorType.get())); + values.push(value); + type = torch.TensorType.get(); } else { - const value = new torch.Value(node); - value.value = item; - node.addInput(value); + throw new pytorch.Error('Unsupported list item type.'); + } + if (!item_type || item_type.isSubtypeOf(type)) { + item_type = type; } } - return output; + const contained_type = typehint ? typehint.getElementType() : item_type; + const node = this._graph.createList(contained_type, values); + this.graph.insertNode(node); + return node.output(); } break; } @@ -2079,7 +2040,7 @@ pytorch.Execution = class extends python.Execution { node.addInput(value); types.push(value.type()); } else if (item === null || Number.isInteger(item) || typeof item === 'number' || typeof item === 'boolean' || typeof item === 'string') { - const value = this.constant(item); + const value = this._graph.insertConstant(item); node.addInput(value); types.push(value.type()); } else { @@ -2090,30 +2051,33 @@ pytorch.Execution = class extends python.Execution { } elements.push(item); } - const value = node.addOutput(); - value.setType(torch.TupleType.get(types)); - return value; + node.output().setType(torch.TupleType.get(types)); + return node.output(); } case 'Dict': { - const node = this._graph.create('prim::DictConstruct'); - this.graph.insertNode(node); + const keys = []; + const values = []; let keyType = null; let valueType = null; for (let i = 0; i < expr.keys.length; i++) { const key = this.expression(expr.keys[i], context); const keyValue = this.variable(key, null); - keyType = keyValue.type(); - node.addInput(keyValue); + if (!keyType || keyType.isSubtypeOf(keyValue.type())) { + keyType = keyValue.type(); + } + keys.push(keyValue); const value = this.expression(expr.values[i], context); const valueValue = this.variable(value, null); - valueType = valueValue.type(); - node.addInput(valueValue); - } - const output = node.addOutput(); - if (keyType && valueType) { - output.setType(torch.DictType.get(keyType, valueType)); + if (!valueType || valueType.isSubtypeOf(valueValue.type())) { + valueType = valueValue.type(); + } + values.push(valueValue); } - return output; + const key_type = typehint ? typehint.getKeyType() : keyType; + const value_type = typehint ? typehint.getValueType() : valueType; + const node = this._graph.createDict(key_type, value_type, keys, values); + this.graph.insertNode(node); + return node.output(); } default: { break; @@ -2367,11 +2331,14 @@ pytorch.Execution = class extends python.Execution { count.set(node, 1); } } - if (count.size > 0 && Array.from(count).every(([node, count]) => node.outputs().length === 1 && node.outputs()[0].uses().length <= count)) { + /* + if (count.size > 0 && + Array.from(count).every(([node, count]) => node.outputs().length === 1 && node.outputs()[0].uses().length <= count)) { for (const node of state) { node.destroy(); } } + */ if (test === true || test === false) { continue; } @@ -2429,7 +2396,7 @@ pytorch.Execution = class extends python.Execution { return value.type(); }; this.variables(condition, condition); - const node = this._graph.create('prim::If'); + const node = this._graph.create('prim::If', 0); node.setSourceRange(stmt.location); this.graph.insertNode(node); node.addInput(test); @@ -2490,7 +2457,7 @@ pytorch.Execution = class extends python.Execution { throw new pytorch.Error("Unsupported condition."); } if (stmt instanceof ast.For) { - const node = this._graph.create('prim::Loop'); + const node = this._graph.create('prim::Loop', 0); node.setSourceRange(stmt.location); this.graph.insertNode(node); const loop = stmt; @@ -2498,7 +2465,9 @@ pytorch.Execution = class extends python.Execution { const range = this.expression(loop.iter, context); const variable = loop.target; for (const current of range) { - this.statement({ type: '=', target: variable, expression: { type: 'number', value: current } }, context); + const constant = new ast.Constant(current); + const stmt = new ast.Assign(variable, constant); + this.statement(stmt, context); const value = this.block(loop.body.statements, context); if (value !== undefined) { return value; @@ -2509,7 +2478,7 @@ pytorch.Execution = class extends python.Execution { } } if (stmt instanceof ast.While) { - const node = this._graph.create('prim::Loop'); + const node = this._graph.create('prim::Loop', 0); node.setSourceRange(stmt.location); this.graph.insertNode(node); const test = this.expression(stmt.test, context); @@ -2571,7 +2540,7 @@ pytorch.Execution = class extends python.Execution { switch (expr.value.id) { case 'List': { const type = this.type(elts[0]); - return torch.ListType.get(type); + return torch.ListType.create(type); } case 'Optional': { const type = this.type(elts[0]); @@ -2584,7 +2553,7 @@ pytorch.Execution = class extends python.Execution { case 'Dict': { const key = this.type(elts[0]); const value = this.type(elts[1]); - return torch.DictType.get(key, value); + return torch.DictType.create(key, value); } case 'Final': { return this.type(elts[0]); @@ -2602,6 +2571,8 @@ pytorch.Execution = class extends python.Execution { case 'float': return torch.FloatType.get(); case 'number': return torch.NumberType.get(); case 'bool': return torch.BoolType.get(); + case 'list': return torch.Type.get('AnyListType'); + case 'tuple': return torch.Type.get('AnyTupleType'); case 'None': return torch.NoneType.get(); case 'NoneType': return torch.NoneType.get(); default: throw new pytorch.Error(`Unsupported type expression '${expr.value}'.`); @@ -2635,12 +2606,10 @@ pytorch.Execution = class extends python.Execution { if (identifier) { const type = this._resolver.resolveType(identifier); if (type) { - const node = this.graph.create('prim::CreateObject'); + const node = this.graph.createObject(type); node.setSourceRange(location); this.graph.insertNode(node); - const value = node.addOutput(); - value.setType(type); - return value; + return node.output(); } } } @@ -2650,7 +2619,7 @@ pytorch.Execution = class extends python.Execution { if (args.length === 0) { return obj; } - const node = this.graph.create('prim::CallMethod'); + const node = this.graph.create('prim::CallMethod', 0); node.setSourceRange(location); this.graph.insertNode(node); node.s_('name', name); @@ -2676,7 +2645,7 @@ pytorch.Execution = class extends python.Execution { const value = this.variable(arg); node.addInput(value); } - return node.addOutput(); + return node.output(); } const prefix = this.identifier(target); if (prefix && prefix !== 'self' && !prefix.startsWith('self.') && prefix.indexOf('.') !== -1) { @@ -2691,9 +2660,8 @@ pytorch.Execution = class extends python.Execution { const value = this.variable(arg); node.addInput(value); } - const output = node.addOutput(); - output.setType(type); - return output; + node.output().setType(type); + return node.output(); } if (type instanceof torch.ClassType) { const node = this.graph.create('prim::CallMethod'); @@ -2704,14 +2672,14 @@ pytorch.Execution = class extends python.Execution { const value = this.variable(arg); node.addInput(value); } - return node.addOutput(); + return node.output(); } } return super.call(target, name, args, context); } const [schema, evalArgs] = overload; const op = schema.overload_name ? `${schema.name}.${schema.overload_name}` : schema.name; - const node = this._graph.create(op); + const node = this._graph.create(op, 0); node.setSourceRange(location); this.graph.insertNode(node); const referencedParameters = []; @@ -2776,9 +2744,8 @@ pytorch.Execution = class extends python.Execution { value.setType(torch.TensorType.get()); list.addInput(value); } - const output = list.addOutput(); - output.setType(torch.ListType.get(torch.TensorType.get())); - input = output; + list.output().setType(torch.ListType.create(torch.TensorType.get())); + input = list.output(); match = true; } } else { @@ -2925,20 +2892,20 @@ pytorch.Execution = class extends python.Execution { if (!type) { throw new pytorch.Error(); } - type = torch.ListType.get(type); + type = torch.ListType.create(type); break; } default: { if (type instanceof torch.DictType) { const keyType = varTypes.map(type.getKeyType()); const valueType = varTypes.map(type.getValueType()); - type = torch.DictType.get(keyType, valueType); + type = torch.DictType.create(keyType, valueType); } else if (type instanceof torch.TupleType && type.elements().length === 2) { const elements = type.elements().map((type) => varTypes.map(type)); - type = torch.ListType.get(torch.TupleType.get(elements)); + type = torch.ListType.create(torch.TupleType.get(elements)); } else if (type instanceof torch.ListType && type.getElementType() instanceof torch.TupleType) { const elements = type.getElementType().elements().map((type) => varTypes.map(type)); - type = torch.ListType.get(torch.TupleType.get(elements)); + type = torch.ListType.create(torch.TupleType.get(elements)); } else { throw new pytorch.Error(`Unsupported return type '${type.str()}'.`); } @@ -3014,7 +2981,7 @@ pytorch.Execution = class extends python.Execution { (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.IntType)) || (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.ListType && obj.type().getElementType().getElementType() instanceof torch.IntType); case 'SymInt[1]': - return this.isType(obj, torch.IntType.get()) || this.isType(obj, torch.ListType.get(torch.IntType.get())); + return this.isType(obj, torch.IntType.get()) || this.isType(obj, torch.ListType.create(torch.IntType.get())); case 'float': { return obj !== null && (typeof obj === 'number' || obj instanceof Number) || (obj instanceof torch.Value && (obj.type() instanceof torch.FloatType || obj.type() instanceof torch.IntType)); } @@ -3129,9 +3096,9 @@ pytorch.Execution = class extends python.Execution { } else if (Number(value) === value) { return torch.FloatType.get(); } else if (Array.isArray(value) && value.every((item) => Number(item) === item && item % 1 === 0)) { - return torch.ListType.get(torch.IntType.get()); + return torch.ListType.create(torch.IntType.get()); } else if (Array.isArray(value) && value.every((item) => Number(item) === item)) { - return torch.ListType.get(torch.FloatType.get()); + return torch.ListType.create(torch.FloatType.get()); } else if (value instanceof torch.Value) { return value.type(); } @@ -3428,6 +3395,8 @@ pytorch.Utility = class { case 'Layout': return 'Layout'; case 'VarType': return type.annotation_str; case 'NoneType': return 'None'; + case 'AnyListType': return 'list'; + case 'AnyTupleType': return 'tuple'; default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`); } } diff --git a/test/models.json b/test/models.json index 1fd7133912..c6ecd7434a 100644 --- a/test/models.json +++ b/test/models.json @@ -5698,6 +5698,13 @@ "format": "PyTorch Package v1.9", "link": "https://github.com/lutzroeder/netron/issues/928" }, + { + "type": "pytorch", + "target": "m4-sWE-0.1B.script.pt", + "source": "https://github.com/user-attachments/files/17967188/m4-sWE-0.1B.script.pt.zip[m4-sWE-0.1B.script.pt]", + "format": "TorchScript v1.6", + "link": "https://github.com/lutzroeder/netron/issues/1061" + }, { "type": "pytorch", "target": "mask_depthwise_conv.pt", @@ -5720,6 +5727,13 @@ "assert": "model.graphs[0].nodes[0].inputs.length == 1", "link": "https://github.com/facebookresearch/kill-the-bits/tree/master/src/models/compressed" }, + { + "type": "pytorch", + "target": "mask_rcnn.pt", + "source": "https://github.com/user-attachments/files/17966950/mask_rcnn.pt.zip[mask_rcnn.pt]", + "format": "TorchScript v1.7", + "link": "https://github.com/lutzroeder/netron/issues/1061" + }, { "type": "pytorch", "target": "mcunet-5fps.pkl",