From 0e977822009738ff03484dbccc4a730c3ca1e872 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Fri, 10 Jan 2025 21:00:02 -0800 Subject: [PATCH] Update pytorch.js (#842) --- source/python.js | 428 ++++++++++++++++++++++++++++++++++++++++------ source/pytorch.js | 4 +- 2 files changed, 375 insertions(+), 57 deletions(-) diff --git a/source/python.js b/source/python.js index 1b0648d5d1..618530f9e8 100644 --- a/source/python.js +++ b/source/python.js @@ -213,9 +213,10 @@ python.Execution = class { } }); this.registerType('ast.Constant', class extends ast.expr { - constructor(value) { + constructor(value, type) { super(); this.value = value; + this.type = type || null; } }); this.registerType('ast.Ellipsis', class extends ast.Constant { @@ -1320,40 +1321,42 @@ python.Execution = class { } const literal = this._literal(); if (literal) { - if (stack.length > 0 && literal.type === 'number' && (literal.value.startsWith('-') || literal.value.startsWith('+'))) { + if (stack.length > 0 && + (literal.type === 'int' || literal.type === 'float' || literal.type === 'complex') && + (literal.value.startsWith('-') || literal.value.startsWith('+'))) { const op = literal.value < 0 ? new ast.Sub() : new ast.Add(); const left = stack.pop(); const right = new ast.Constant(Math.abs(literal.value)); node = new ast.BinOp(left, op, right); stack.push(node); - } else if (stack.length === 1 && literal.type === 'string' && stack[0] instanceof ast.Constant && typeof stack[0].value === 'string') { + } else if (stack.length === 1 && literal.type === 'str' && stack[0] instanceof ast.Constant && typeof stack[0].value === 'string') { stack[0].value += literal.value.substring(1, literal.value.length - 1); } else { let value = literal.value; - if (literal.type === 'number') { + if (literal.type === 'int' || literal.type === 'float' || literal.type === 'complex') { switch (value) { case 'inf': value = Infinity; break; case '-inf': value = -Infinity; break; default: value = Number(value); break; } - } else if (literal.type === 'string') { + } else if (literal.type === 'str') { value = literal.value.substring(1, literal.value.length - 1); } else { throw new python.Error(`Invalid literal ${this._location()}`); } - const node = new ast.Constant(value); + const node = new ast.Constant(value, literal.type); stack.push(node); } continue; } if (this._eat('id', 'False')) { - const node = new ast.Constant(false); + const node = new ast.Constant(false, 'bool'); this._mark(node, position); stack.push(node); continue; } if (this._eat('id', 'True')) { - const node = new ast.Constant(true); + const node = new ast.Constant(true, 'bool'); this._mark(node, position); stack.push(node); continue; @@ -1515,7 +1518,7 @@ python.Execution = class { } _literal() { const token = this._tokenizer.peek(); - if (token.type === 'string' || token.type === 'number' || token.type === 'boolean') { + if (token.type === 'str' || token.type === 'bool' || token.type === 'int' || token.type === 'float' || token.type === 'complex') { this._tokenizer.read(); return token; } @@ -2019,7 +2022,7 @@ python.Execution = class { const radixText = this._text.substring(this._position, i); const radixParseText = radixText.indexOf('_') === -1 ? radixText : radixText.split('_').join(''); if (!isNaN(parseInt(radixParseText, radix))) { - return { type: 'number', value: radixText }; + return { type: 'int', value: radixText }; } } } @@ -2039,11 +2042,11 @@ python.Execution = class { } if (isDecimal) { if (this._get(i) === 'j' || this._get(i) === 'J' || this._get(i) === 'l' || this._get(i) === 'L') { - return { 'type': 'number', value: this._text.substring(this._position, i + 1) }; + return { 'type': 'complex', value: this._text.substring(this._position, i + 1) }; } const intText = this._text.substring(this._position, i); if (!isNaN(parseInt(intText, 10))) { - return { type: 'number', value: intText }; + return { type: 'int', value: intText }; } } i = this._position + sign; @@ -2079,12 +2082,12 @@ python.Execution = class { } if (i > (this._position + sign)) { if (this._get(i) === 'j' || this._get(i) === 'J') { - return { type: 'number', value: this._text.substring(this._position, i + 1) }; + return { type: 'complex', value: this._text.substring(this._position, i + 1) }; } const floatText = this._text.substring(this._position, i); const floatParseText = floatText.indexOf('_') === -1 ? floatText : floatText.split('_').join(''); if (!isNaN(parseFloat(floatParseText))) { - return { type: 'number', value: floatText }; + return { type: 'float', value: floatText }; } } } @@ -2215,7 +2218,7 @@ python.Execution = class { if (count === 1) { while (i < this._text.length) { if (this._text[i] === quote) { - return { type: 'string', value: this._text.substring(this._position, i + 1) }; + return { type: 'str', value: this._text.substring(this._position, i + 1) }; } else if (this._text[i] === '\\' && (this._get(i + 1) === quote || this._get(i + 1) === '\n' || this._get(i + 1) === '\\')) { i += 2; @@ -2228,7 +2231,7 @@ python.Execution = class { } else if (count === 3) { while (i < this._text.length) { if (this._get(i) === quote && this._get(i + 1) === quote && this._get(i + 2) === quote) { - return { type: 'string', value: this._text.substring(this._position, i + 3) }; + return { type: 'str', value: this._text.substring(this._position, i + 3) }; } else if (this._get(i) === '\\' && this._get(i + 1) === quote) { i += 2; continue; @@ -2242,7 +2245,7 @@ python.Execution = class { i++; while (i < this._text.length) { if (this._text[i] === '`') { - return { type: 'string', value: this._text.substring(this._position, i + 1) }; + return { type: 'str', value: this._text.substring(this._position, i + 1) }; } i++; } @@ -5236,23 +5239,20 @@ python.Execution = class { return this.mark(node); } if (outerNode.kind() === 'prim::Loop' || outerNode.kind() === 'c10::onnx::Loop') { - throw new python.Error('Not implemented.'); - /* const loop = new torch._C.LoopView(outerNode); - for (const auto i : c10::irange(loop.carriedOutputs().size())) { - if (outerNode.kind() == c10::onnx::Loop) { + for (let i = 0; i < loop.carriedOutputs().length; i++) { + if (outerNode.kind() === 'onnx::Loop') { this._liveValues.add(loop.bodyCarriedOutputs()[i]); continue; } - auto innerInput = loop.bodyCarriedInputs()[i]; - auto innerOutput = loop.bodyCarriedOutputs()[i]; - auto outerOutput = loop.carriedOutputs()[i]; - if (liveValues_.count(outerOutput) || innerInput->hasUses()) { + const innerInput = loop.bodyCarriedInputs()[i]; + const innerOutput = loop.bodyCarriedOutputs()[i]; + const outerOutput = loop.carriedOutputs()[i]; + if (this._liveValues.has(outerOutput) || innerInput.hasUses()) { this._liveValues.add(innerOutput); } } this._liveValues.add(loop.nextCond()); - */ } else { torch._C.AT_ASSERT(outerNode.outputs().length === node.inputs().length); for (let i = 0; i < outerNode.outputs().length; i++) { @@ -5266,6 +5266,16 @@ python.Execution = class { this._marked.add(node); return true; } + markLoop(node) { + torch._C.TORCH_INTERNAL_ASSERT(node.kind() === 'prim::Loop'); + let marked = false; + let anyMarked = false; + do { + marked = this.mark(node.blocks()[0]); + anyMarked = anyMarked || marked; + } while (marked); + return anyMarked; + } mark(...args) { if (args.length === 1 && args[0] instanceof torch.Block) { const [block] = args; @@ -7842,6 +7852,7 @@ python.Execution = class { this._is_module = is_module; this._attributes = []; this._attributeTypes = []; + this._properties = []; this._methods = new Map(); this._staticmethods = new Map(); this._constants = new Map(); @@ -7877,6 +7888,9 @@ python.Execution = class { } return method; } + methods() { + throw new python.Error('Not implemented.'); + } addStaticMethod(func) { this._staticmethods.set(func.name, func); } @@ -7931,14 +7945,19 @@ python.Execution = class { getAttributeName(slot) { return this._attributes[slot].name; } - hasConstant(/* name */) { - } - methods() { - throw new python.Error('Not implemented.'); - } addConstant(name, value) { this._constants.set(name, value); } + hasConstant(/* name */) { + } + getProperty(name) { + for (const prop of this._properties) { + if (name === prop.name) { + return prop; + } + } + return null; + } containedTypes() { return this._attributeTypes; } @@ -8489,23 +8508,28 @@ python.Execution = class { this.registerType('torch.DictType', class extends torch.Type { constructor(key, value) { super('DictType'); - this._key = key; - this._value = value; + this.types = [key, value]; } static create(key, value) { return new torch.DictType(key, value); } getKeyType() { - return this._key; + return this.types[0]; } getValueType() { - return this._value; + return this.types[1]; } hasFreeVariables() { return this.getKeyType().hasFreeVariables() || this.getValueType().hasFreeVariables(); } + createWithContained(contained_types) { + if (contained_types.length !== 2) { + throw new python.Error('Expected 2 contained types.'); + } + return torch.DictType.create(contained_types[0], contained_types[1]); + } containedTypes() { - throw new python.Error('Not implemented.'); + return this.types; } str() { return `Dict(${this.getKeyType().str()}, ${this.getValueType().str()})`; @@ -10634,10 +10658,11 @@ python.Execution = class { } }); this.registerType('torch._C.IValue', class { - constructor(value) { + constructor(value, tag) { this.value = value; - this.tag = 'None'; - if (typeof value === 'boolean') { + if (tag) { + this.tag = tag; + } else if (typeof value === 'boolean') { this.tag = 'Bool'; } else if (typeof value === 'string') { this.tag = 'String'; @@ -10661,6 +10686,12 @@ python.Execution = class { toTensor() { return this.value; } + isDouble() { + return this.tag === 'Double'; + } + toDouble() { + return this.value; + } isInt() { return this.tag === 'Int'; } @@ -10717,6 +10748,15 @@ python.Execution = class { return this._filename; } }); + this.registerType('torch._C.SourceRange', class { + constructor(node) { + this._node = node; + } + toString() { + const n = this._node; + return `${n.filename}:${n.lineno}:${n.col_offset}-${n.end_lineno}:${n.end_col_offset}`; + } + }); this.registerType('torch._C.QualifiedName', class { constructor(...args) { let name = null; @@ -10756,6 +10796,10 @@ python.Execution = class { this.registerType('torch._C.SourceImporter', class extends torch._C.Resolver { constructor(cu, constant_table, source_loader, version) { super(); + ast.AST.prototype.range = function() { + this._range = this._range || new torch._C.SourceRange(this); + return this._range; + }; this._cu = cu; this._constant_table = constant_table; this._source_loader = source_loader; @@ -11417,8 +11461,8 @@ python.Execution = class { if (!simple_parent) { throw new python.Error('Only reassignments to first-class values are allowed.'); } - const parent_type = this.unshapedType(simple_parent.type()); - as_simple_value = this.tryConvertToType(loc, this.b.owningGraph(), parent_type, as_simple_value, /*allow_conversions=*/true); + const parent_type = torch._C.unshapedType(simple_parent.type()); + as_simple_value = torch._C.tryConvertToType(loc, this.b.owningGraph(), parent_type, as_simple_value, /*allow_conversions=*/true); if (!as_simple_value.type().isSubtypeOf(parent_type)) { throw new python.Error('Incompatible types.'); } @@ -12126,7 +12170,7 @@ python.Execution = class { } } if (variants.length === 0) { - const oldSchemas = torch._C.loadPossibleHistoricOps(name.toQualString(), graph_version); + const oldSchemas = torch._C.loadPossibleHistoricOps(name, graph_version); upgrader_schemas.reserve(oldSchemas.size()); for (const old_schema_entry of oldSchemas) { const old_schema = torch._C.parseSchema(old_schema_entry); @@ -12266,9 +12310,7 @@ python.Execution = class { throw new python.Error('Unsupported value kind.'); }); this.registerFunction('torch._C.tryInsertConstant', (g, val, loc, scope) => { - const ivalue = false; - if (ivalue) { - val = new torch._C.IValue(val); // remove + if (val instanceof torch._C.IValue) { const n = g.create('prim::Constant'); if (val.isTensor()) { const ref = val.toTensor(); @@ -12365,10 +12407,10 @@ python.Execution = class { } else if (typeof val === 'boolean') { n.i_('value', val === true ? 1 : 0); type = torch.BoolType.get(); - } else if (Number.isInteger(val)) { + } else if ((!val.type && Number.isInteger(val)) || val.type === 'int') { n.i_('value', val); type = torch.IntType.get(); - } else if (typeof val === 'number') { + } else if ((!val.type && typeof val === 'number') || val.type === 'float') { n.f_('value', val); type = torch.FloatType.get(); } else if (val instanceof torch.Tensor) { @@ -12479,6 +12521,9 @@ python.Execution = class { } }); this.registerType('torch._C.SugaredValue', class { + shouldEmitUnrolled() { + return this.staticLen() !== null; + } }); this.registerType('torch._C.SimpleValue', class extends torch._C.SugaredValue { constructor(value) { @@ -12534,11 +12579,82 @@ python.Execution = class { if (prop) { return new torch._C.MethodValue(this._value, [prop.getter.name()]).call(loc, m, {}, {}, /*n_binders=*/1); } - } - if (this._value.type() instanceof torch.InterfaceType) { + } else if (this._value.type() instanceof torch.InterfaceType) { throw new python.Error('Not implemented.'); + } /* else if (this._value.type() instanceof torch.EnumType) { + const g = m.graph(); + if (field == "name") { + const n = g.insertNode(g.createEnumName(value_)); + return std::make_shared(n->output()); + } + if (field == "value") { + auto n = g.insertNode(g.createEnumValue(value_)); + return std::make_shared(n->output()); + } + } */ + if (field === 'type') { + const builtin = torch._C.BuiltinFunction.tryCreate('aten::to', new torch._C.NamedValue(loc, 'self', this._value)); + if (builtin) { + return builtin; + } } - throw new python.Error('Not implemented.'); + const builtin = torch.BuiltinFunction.tryCreate(`aten::${field}`, new torch._C.NamedValue(loc, 'self', this._value)); + if (builtin) { + return builtin; + } + if (this._value.type().isSubtypeOf(torch.TensorType.get()) && field === 'tolist') { + return new torch._C.SpecialFormValue.create('prim::tolist'); + } + if (this._value.type().isSubtypeOf(torch.TensorType.get()) && field === '__getitem__') { + return new torch._C.SpecialFormValue.create('aten::index'); + } + if (this._value.type() instanceof torch._C.GeneratorType && (field === 'manual_seed' || field === 'initial_seed' || field === 'seed')) { + const builtin = torch._C.BuiltinFunction.tryCreate(`aten::${field}`, new torch._C.NamedValue(loc, "self", this._value)); + if (builtin) { + return builtin; + } + } + throw new python.Error('Object has no attribute or method.'); + } + setAttr(loc, m, field, newValue) { + const type = this._value.type(); + if (type instanceof torch.ClassType === false) { + throw new python.Error('Cannot set attribute on non-class type.'); + } + const classType = type; + let expectedType = classType.findAttribute(field); + if (!expectedType) { + const isInitializing = m.name() === '__init__' && m.graph().inputs().length > 0 && m.graph().inputs()[0].type() === classType; + if (isInitializing) { + if (this.isRecursive(classType, newValue.type())) { + throw new python.Error('Classes that recursively contain instances of themselves are not supported.'); + } + classType.addAttribute(field, newValue.type()); + expectedType = newValue.type(); + const insertPoint = m.graph().insertPoint(); + const topLevelBlock = m.graph().block(); + if (insertPoint.owningBlock() !== topLevelBlock) { + throw new python.Error('First assignment cannot be in a control-flow block.'); + } + } else { + const prop = classType.getProperty(field); + if (prop && prop.setter) { + new torch._C.MethodValue(this._value, prop.setter.name()).call(loc, m, [newValue], [], /*n_binders=*/1); + return; + } + if (prop && !prop.setter) { + throw new python.Error('Tried to set read-only attribute.'); + } + throw new python.Error('Tried to set nonexistent attribute.'); + } + } + torch._C.AT_ASSERT(expectedType); + const newType = newValue.type(); + if (!newType.isSubtypeOf(expectedType)) { + throw new python.Error('Wrong type for attribute assignment.'); + } + const g = m.graph(); + g.insertNode(g.createSetAttr(this._value, field, newValue)); } getitem(loc, m, idx, type_hint) { const val = this.getValue(); @@ -12751,6 +12867,29 @@ python.Execution = class { } this._static_len = static_len; } + staticLen() { + return this._static_len; + } + iter() { + return this; + } + len(loc, m) { + if (this._static_len) { + return torch._C.insertConstant(m.graph(), this._static_len, loc); + } + if (this._has_only_end) { + return this._end; + } + const g = m.graph(); + return g.insert('aten::__range_length', [this._start, this._end, this._step], [], loc); + } + getitem(loc, m, idx /*, type_hint */) { + if (this._has_only_end) { + return new torch._C.SimpleValue(idx); + } + const g = m.graph(); + return new torch._C.SimpleValue(g.insert('aten::__derive_index', [idx, this._start, this._step], [], loc)); + } }); this.registerType('torch._C.SliceValue', class extends torch._C.SugaredValue { }); @@ -13307,8 +13446,10 @@ python.Execution = class { this.resolver = _resolver; this.integral_constants = new Map(); this.fp_constants = new Map(); + this.complex_constants = new Map(); this.exit_blocks = new Set(); this._typeParser = new torch._C.ScriptTypeParser(this.resolver); + this._loop_status = 'NOT_IN_LOOP'; this.environment_stack = null; this._def_stack = []; this._temp_name_count = 0; @@ -13400,6 +13541,10 @@ python.Execution = class { const stmt = stmts[i]; if (stmt instanceof ast.If) { this.emitIf(stmt); + } else if (stmt instanceof ast.While) { + this.emitWhile(stmt); + } else if (stmt instanceof ast.For) { + this.emitFor(stmt); } else if (stmt instanceof ast.Assign) { this.emitAssignment(stmt); } else if (stmt instanceof ast.Expr) { @@ -13413,9 +13558,79 @@ python.Execution = class { } } } + emitLoopCommon(range, emit_body, iter_val, targets, cond) { + let max_trip_count_val = null; + if (iter_val === null) { + max_trip_count_val = torch._C.materializeConstant(Number.MAX_SAFE_INTEGER /*std::numeric_limits::max()*/, this.graph, range, this.integral_constants); + } else { + max_trip_count_val = iter_val.len(range, this.method); + } + const n = this.graph.insertNode(this.create('prim::Loop', range, 0)); + const body_block = n.addBlock(); + { + const condition_block = n.addBlock(); + this.pushFrame(condition_block); + let out = null; + if (cond) { + const insert = new torch._C.WithInsertPoint(condition_block); + out = this.emitToBool(cond.range(), this.emitExpr(cond.value())); + insert.dispose(); + } else { + const insert = new torch._C.WithInsertPoint(n); + out = this.graph.insertConstant(true, range); + insert.dispose(); + } + condition_block.registerOutput(out); + this.popFrame(); + } + n.addInput(max_trip_count_val); + const loop_guard = new torch._C.WithLoopStatus(this, 'IN_LOOP'); + const trip_count = body_block.addInput().setType(torch.IntType.get()); + { + this.pushFrame(body_block); + const guard = new torch._C.WithInsertPoint(body_block); + if (iter_val !== null && targets) { + const cur_elem = iter_val.getitem(range, this.method, trip_count).asValue(range, this.method); + const sv = new torch._C.SimpleValue(cur_elem); + const target_exprs = targets; + this.validateAssignLhsExpr(target_exprs, range); + if (target_exprs.length > 1) { + throw new python.Error('Not implemented.'); + // const tl = torch.TupleLiteral.create(range, target_exprs); + // target_exprs = ListExpr.create(range, [tl]); + } + this.emitExprsAssign(target_exprs, [sv], range, /*n_binders=*/1); + } + emit_body(); + this.popFrame(); + guard.dispose(); + } + loop_guard.dispose(); + } + emitFor(...args) { + if (args.length === 1 && args[0] instanceof ast.For) { + const [stmt] = args; + const emit_body = () => this.emitStatements(stmt.body); + this.emitFor(stmt.target, stmt.iter, stmt.range(), emit_body); + } else if (args.length === 4) { + const [targets, itrs, loc, emit_body] = args; + if (itrs instanceof ast.Tuple) { + throw new python.Error('List of iterables is not supported currently.'); + } + const sv = this.emitSugaredExpr(itrs, 1); + const iterable = sv.iter(loc, this.method); + if (iterable.shouldEmitUnrolled()) { + this.emitUnrolledLoop(loc, emit_body, iterable, targets); + } else { + this.emitLoopCommon(loc, emit_body, iterable, [targets], null); + } + } else { + throw new python.Error('Not implemented.'); + } + } emitIf(stmt) { const cond_value = this.emitCondExpr(stmt.test); - this.emitIfElseBlocks(stmt, cond_value, stmt.body, stmt.orelse); + this.emitIfElseBlocks(stmt.range(), cond_value, stmt.body, stmt.orelse); } emitCondExpr(expr) { /* @@ -13682,10 +13897,24 @@ python.Execution = class { this.environment_stack.setSugaredVar(stmt, lhs.id, rhs_sugared_val, /*annotated_type=*/type); } else if (lhs instanceof ast.Tuple) { this.emitTupleAssign(lhs, rhs); + } else if (lhs instanceof ast.Attribute) { + this.emitSelectAssign(lhs, rhs, null, stmt.range()); } else { throw new python.Error('Unexpected expression on left-hand side of assignment.'); } } + emitSelectAssign(lhs, rhs, type, loc) { + if (!rhs) { + throw new python.Error('Expected RHS for assignment.'); + } + let type_hint = null; + if (type) { + type_hint = this._typeParser.parseTypeFromExpr(type); + } + const lhsObject = this.emitSugaredExpr(lhs.value, 1); + const rhsValue = this.emitSugaredExpr(rhs, 1, type_hint).asValue(rhs.range(), this.method); + lhsObject.setAttr(loc, this.method, lhs.attr, rhsValue); + } emitTupleAssign(...args) { if (args.length === 2) { const [tl, rhs] = args; @@ -13967,10 +14196,12 @@ python.Execution = class { return torch._C.insertConstant(this.graph, c.value, c); } emitConst(c) { - if (Number.isInteger(c.value)) { - return torch._C.materializeConstant(c.value, this.graph, c, this.integral_constants); - } else if (typeof c.value === 'number') { - return torch._C.materializeConstant(c.value, this.graph, c, this.fp_constants); + if (c.type === 'int') { + return torch._C.materializeConstant(new torch._C.IValue(c.value, 'Int'), this.graph, c.range(), this.integral_constants); + } else if (c.type === 'complex') { + return torch._C.materializeConstant(new torch._C.IValue(c.value, 'Complex'), this.graph, c.range(), this.complex_constants); + } else if (c.type === 'float') { + return torch._C.materializeConstant(new torch._C.IValue(c.value, 'Double'), this.graph, c.range(), this.fp_constants); } throw new python.Error(`Unsupported constant type.`); } @@ -14327,6 +14558,10 @@ python.Execution = class { this.environment_stack = this.environment_stack.next; return old_frame; } + addBlockInput(b, type, name) { + const g = b.owningGraph(); + g.createStore(name, b.addInput(name).setType(type)).insertAfter(b.param_node()); + } addBlockOutput(exit_block, type, name) { const insert = new torch._C.WithInsertPoint(exit_block); const g = exit_block.owningGraph(); @@ -14342,6 +14577,11 @@ python.Execution = class { const g = n.owningGraph(); g.createStore(name, out).insertAfter(n); } + addNodeInput(n, type, name) { + const g = n.owningGraph(); + const inp = g.createLoad(name, type).insertBefore(n).output(); + n.addInput(inp); + } addIfLoadStores(n) { const [true_block, false_block] = n.blocks(); const true_vars = this.addControlFlowLoadStores(true_block); @@ -14366,6 +14606,22 @@ python.Execution = class { this.addNodeOutput(n, unified, x); } } + addLoopLoadStores(n) { + const [body_block] = n.blocks(); + const loop_vars = this.addControlFlowLoadStores(body_block); + for (const name of loop_vars.definedVariables()) { + const parent_type = this.environment_stack.findInAnyFrame(name); + if (!parent_type) { + continue; + } + const block_type = loop_vars.findInThisFrame(name); + const unified_type = torch._C.unifyTypes(parent_type, block_type); + this.addNodeInput(n, parent_type, name); + this.addBlockInput(body_block, unified_type, name); + this.addBlockOutput(body_block, block_type, name); + this.addNodeOutput(n, unified_type, name); + } + } addControlFlowLoadStores(block) { this.pushFrame(block); for (const n of block.nodes()) { @@ -14407,6 +14663,37 @@ python.Execution = class { run(/* graph */) { } }); + this.registerType('torch._C.LoopView', class { + constructor(node) { + torch._C.AT_ASSERT(node.kind() === 'prim::Loop' || node.kind() === 'onnx::Loop'); + this._node = node; + } + bodyBlock() { + return this._node.blocks()[0]; + } + nextCond() { + return this.bodyBlock().outputs()[0]; + } + carriedOutputs() { + return this._node.outputs(); + } + bodyCarriedInputs() { + return this.bodyBlock().inputs().slice(1); + } + bodyCarriedOutputs() { + return this.bodyBlock().outputs().slice(1); + } + }); + this.registerType('torch._C.WithLoopStatus', class { + constructor(to_ir, new_status) { + this._to_ir = to_ir; + this._prev = this._to_ir._loop_status; + this._to_ir._loop_status = new_status; + } + dispose() { + this._to_ir._loop_status = this._prev; + } + }); this.registerFunction('torch._C.InlineLoopCondition', (/* graph */) => { }); this.registerType('torch._C.EraseLoadStores', class { @@ -14543,6 +14830,37 @@ python.Execution = class { this._target_block = block; } } + transformLoop(node) { + const loop = new torch._C.LoopView(node); + const body = loop.bodyBlock(); + const exit_pair = this.transformExits(body); + if (this.getExitStatus(exit_pair) === 'WONT' || this.getExitStatus(exit_pair) === 'THROWS') { + return this.constructWontExitPair(); + } + const insert = new torch._C.WithInsertPoint(body); + const new_if = this._graph.insertNode(this._graph.create('prim::If', 0)); + new_if.addInput(exit_pair.hasExited()); + new_if.addBlock().registerOutput(this._false_val); + new_if.addBlock().registerOutput(loop.nextCond()); + const new_condition = new_if.addOutput().setType(torch.BoolType.get()); + loop.bodyBlock().eraseOutput(0); + loop.bodyBlock().insertOutput(0, new_condition); + node.addInput(this._false_val); + body.addInput().setType(torch.BoolType.get()); + body.registerOutput(exit_pair.hasExited()); + const new_has_exited = node.addOutput().setType(torch.BoolType.get()); + for (const exit_value of exit_pair.exitValues()) { + const typ = exit_value.type(); + node.addInput(this.getUnitValue(typ)); + node.addOutput().setType(typ); + body.addInput().setType(typ); + body.registerOutput(exit_value); + } + const exit_vals = node.outputs().slice(node.outputs().size() - exit_pair.exitValues().size()); + const result = new torch._C.ExitPair(new_has_exited, exit_vals); + insert.dispose(); + return result; + } calcIfExitStatus(then_status, else_status) { if (then_status === 'THROWS') { return else_status; diff --git a/source/pytorch.js b/source/pytorch.js index 77a671d9cf..43926b67e2 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -555,7 +555,7 @@ pytorch.Node = class { } const sourceRange = node.sourceRange(); if (sourceRange) { - this.metadata.push(new pytorch.Argument('source', sourceRange.replace(/^at\s/, '').replace(/\.$/, ''))); + this.metadata.push(new pytorch.Argument('source', sourceRange.toString().replace(/^at\s/, '').replace(/\.$/, ''))); } } else if (torch && obj instanceof torch.fx.node.Node) { if (obj.op === 'call_function') { @@ -1533,7 +1533,7 @@ pytorch.Execution = class extends python.Execution { constructor(sources, metadata) { super(sources); - this.to_ir = false; + this.to_ir = true; this._metadata = metadata; const execution = this; const torch = this.torch;