From 8fc84002b2b710dd558c0df17053bedd5c7e9b21 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Mon, 4 Nov 2024 20:02:57 -0800 Subject: [PATCH] Update pytorch.js (#1061) --- source/pytorch.js | 91 ++++++++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 41 deletions(-) diff --git a/source/pytorch.js b/source/pytorch.js index a44c2d29cf..40933bbad1 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1899,6 +1899,7 @@ pytorch.Execution = class extends python.Execution { if (expression.target.type === 'id' && expression.target.value === 'uninitialized') { const type = this.type(expression.args[0]); const node = this._graph.create('prim::Uninitialized'); + node.setSourceRange(expression.location); this.graph.insertNode(node); const value = node.addOutput(); value.setType(type); @@ -2159,9 +2160,11 @@ pytorch.Execution = class extends python.Execution { const input = node.inputs()[0].node(); if (input.kind() === 'prim::TupleConstruct') { const value = input.inputs()[index]; - const node = value.node(); - if (node.kind() === 'prim::Constant') { - return pytorch.Utility.constant(node, 'value'); + const constant = value.node(); + if (constant.kind() === 'prim::Constant') { + state.push(node); + state.push(constant); + return pytorch.Utility.constant(constant, 'value'); } } } @@ -2325,8 +2328,8 @@ pytorch.Execution = class extends python.Execution { block(statements, context) { const torch = this.torch; statements = Array.prototype.slice.call(statements); - while (statements.length > 0) { - if (statements.length > 1) { + for (let i = 0; i < statements.length;) { + if (i < statements.length - 1) { const containsVariableReference = (statements, value) => { if (statements) { for (const statement of statements) { @@ -2340,30 +2343,45 @@ pytorch.Execution = class extends python.Execution { } return false; }; - const [assign, condition] = statements; + const assign = statements[i]; + const condition = statements[i + 1]; // _x = // if _x: // ... if (assign.type === '=' && condition.type === 'if' && assign.target.type === 'id' && condition.test.type === 'id' && assign.target.value === condition.test.value && - !containsVariableReference(statements.slice(2), condition.test.value) && - (!statements[1].body || !containsVariableReference(statements[1].body.statements), condition.test.value) && - (!statements[1].orelse || !containsVariableReference(statements[1].orelse.statements, condition.test.value))) { - statements.shift(); - statements[0] = { + !containsVariableReference(statements.slice(i + 2), condition.test.value) && + (!condition.body || !containsVariableReference(condition.body.statements), condition.test.value) && + (!condition.orelse || !containsVariableReference(condition.orelse.statements, condition.test.value))) { + statements.splice(i, 2, { + location: condition.location, type: 'if', test: assign.expression, body: condition.body, orelse: condition.orelse, - location: condition.location, - }; + }); } } - const [condition] = statements; + const condition = statements[i]; if (condition.type === 'if') { const state = []; let test = this.static(condition.test, context, state); + if (test === null) { + test = false; + } else if (typeof test === 'boolean') { + test = test === true; + } else if (Number.isInteger(test)) { + test = test !== 0; + } else if (typeof test === 'string') { + test = test && test.length > 0; + } + if (test === true) { + statements.splice(i, 1, ...condition.body.statements); + } else if (test === false) { + statements.splice(i, 1, ...condition.orelse.statements); + } + const count = new Map(); for (const node of state) { if (count.has(node)) { @@ -2377,33 +2395,20 @@ pytorch.Execution = class extends python.Execution { node.destroy(); } } - if (test === null) { - test = false; - } else if (typeof test === 'boolean') { - test = test === true; - } else if (Number.isInteger(test)) { - test = test !== 0; - } else if (typeof test === 'string') { - test = test && test.length > 0; - } - if (test === true) { - statements.shift(); - statements = condition.body.statements.concat(statements); - continue; - } - if (test === false) { - statements.shift(); - statements = condition.orelse.statements.concat(statements); + + if (test === true || test === false) { continue; } } - if (statements.length > 0) { - const statement = statements.shift(); + if (i < statements.length) { + const statement = statements[i]; if (statement.type === 'if') { - const test = this.expression(statement.test, context); + const condition = statement; + const test = this.expression(condition.test, context); if (test instanceof torch.Value && test.type() instanceof torch.BoolType) { const refs = new Set(); - for (const statement of statements) { + for (let j = i + 1; j < statements.length; j++) { + const statement = statements[j]; if (!statement.refs) { this.variables(statement, statement); } @@ -2446,7 +2451,7 @@ pytorch.Execution = class extends python.Execution { } return value.type(); }; - this.variables(statement, statement); + this.variables(condition, condition); const node = this._graph.create('prim::If'); node.setSourceRange(statement.location); this.graph.insertNode(node); @@ -2454,15 +2459,15 @@ pytorch.Execution = class extends python.Execution { const prev = this._graph.insertPoint(); const true_block = node.addBlock(); this._graph.setInsertPoint(true_block); - let vars = __variables(statement.body.statements.concat(statement.orelse.statements)); + let vars = __variables(condition.body.statements.concat(statement.orelse.statements)); vars = new Map(Array.from(vars).map((name) => [name, {}])); - this.block(statement.body.statements, context); + this.block(condition.body.statements, context); for (const [name, entry] of vars) { entry.body = context.get(name); } const false_block = node.addBlock(); this._graph.setInsertPoint(false_block); - this.block(statement.orelse.statements, context); + this.block(condition.orelse.statements, context); for (const [name, entry] of vars) { entry.orelse = context.get(name); } @@ -2502,6 +2507,7 @@ pytorch.Execution = class extends python.Execution { } value.setType(type); } + i++; continue; } throw new pytorch.Error("Unsupported condition."); @@ -2510,6 +2516,7 @@ pytorch.Execution = class extends python.Execution { if (value !== undefined) { return value; } + i++; } } return undefined; @@ -2601,6 +2608,7 @@ pytorch.Execution = class extends python.Execution { return super.call(target, name, args, context); } const torch = this.torch; + const builtins = this.builtins; if (name === '__new__') { const identifier = pytorch.Utility.target(target); if (identifier) { @@ -2773,7 +2781,7 @@ pytorch.Execution = class extends python.Execution { } else { const value = this.variable(v); value.value = v; - if (!value.type() && v instanceof this.builtins.dict) { + if (!value.type() && v instanceof builtins.dict) { value.setType(type); } input = value; @@ -2964,6 +2972,7 @@ pytorch.Execution = class extends python.Execution { isType(obj, type, N) { const torch = this.torch; + const builtins = this.builtins; switch (type.str()) { case 'Tensor': return !Array.isArray(obj) && (pytorch.Utility.isTensor(obj) || obj === null || @@ -3105,7 +3114,7 @@ pytorch.Execution = class extends python.Execution { return true; } } - if (obj instanceof this.builtins.dict) { + if (obj instanceof builtins.dict) { return true; } return false;