From 33d9bccd633bfb376827609d124ae6083ccba7e2 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Mon, 13 Jan 2025 19:17:16 -0800 Subject: [PATCH] Update pytorch.js (#842) --- source/python.js | 322 ++++++++++++++++++++++++----------- source/pytorch-metadata.json | 27 ++- source/pytorch.js | 42 ++++- 3 files changed, 280 insertions(+), 111 deletions(-) diff --git a/source/python.js b/source/python.js index a264901e1d..b1017bc99f 100644 --- a/source/python.js +++ b/source/python.js @@ -1216,6 +1216,7 @@ python.Execution = class { const args = []; const keywords = []; this._tokenizer.expect('('); + let tuple = false; while (!this._tokenizer.eat(')')) { if (this._tokenizer.eat('\n')) { continue; @@ -1236,14 +1237,16 @@ python.Execution = class { } else { args.push(expr); } - if (!this._tokenizer.eat(',')) { + if (this._tokenizer.eat(',')) { + tuple = true; + } else { this._tokenizer.eat('\n'); this._tokenizer.expect(')'); break; } } if (stack.length === 0 && keywords.length === 0) { - if (args.length === 1) { + if (args.length === 1 && !tuple) { [node] = args; } else { node = new ast.Tuple(args); @@ -1262,8 +1265,10 @@ python.Execution = class { stack.push(this._expressions()); } else { const value = stack.pop(); + const position = this._position(); const slice = this._slice(); node = new ast.Subscript(value, slice); + this._mark(node, position); stack.push(node); } continue; @@ -1468,27 +1473,30 @@ python.Execution = class { const elts = []; let slice = [null, null, null]; let index = 0; + let valid = false; this._tokenizer.expect('['); - while (!this._tokenizer.eat(']')) { + while (true) { if (this._tokenizer.eat(':')) { index++; - } else if (this._tokenizer.peek().type !== ']') { - const expression = this._expression(); - if (expression === null) { - throw new python.Error(`Expected expression ${this._location()}`); + valid = true; + } else if (index > 2 || this._tokenizer.match(',') || this._tokenizer.match(']')) { + if (!valid || index > 2) { + throw new python.Error(`Invalid slice at ${this._location()}`); } - slice[index++] = expression; - } - if (index > 2 || this._tokenizer.match(',') || this._tokenizer.match(']')) { - if (index === 0) { - throw new python.Error(`Invalid slice.`); - } - elts.push(index === 1 ? slice[0] : new ast.Slice(slice[0], slice[1], slice[2])); + elts.push(index === 0 ? slice[0] : new ast.Slice(slice[0], slice[1], slice[2])); slice = [null, null, null]; index = 0; - if (!this._tokenizer.match(']')) { - this._tokenizer.expect(','); + if (this._tokenizer.eat(']')) { + break; } + this._tokenizer.expect(','); + } else { + const expression = this._expression(); + if (expression === null) { + throw new python.Error(`Expected expression ${this._location()}`); + } + slice[index] = expression; + valid = true; } } if (elts.length > 1) { @@ -5155,6 +5163,9 @@ python.Execution = class { }); this.registerType('torch._C.Self', class { }); + this.registerFunction('torch._C.toValues', (g, nvs) => { + return nvs.map((v) => v.value(g)); + }); this.registerType('torch._C.SimpleSelf', class extends torch._C.Self { constructor(classType) { super(); @@ -5168,7 +5179,7 @@ python.Execution = class { return this._classType; } }); - this.registerType('torch.jit.Function', class { + this.registerType('torch._C.Function', class { isGraphFunction() { return false; } @@ -5176,7 +5187,7 @@ python.Execution = class { return this.qualname().name(); } }); - this.registerType('torch.jit.BuiltinOpFunction', class extends torch.jit.Function { + this.registerType('torch._C.BuiltinOpFunction', class extends torch._C.Function { constructor(qualname, schema) { super(); this._name = qualname; @@ -5554,6 +5565,12 @@ python.Execution = class { this.inlineIfBody(n.blocks()[block_index]); this._made_change = true; } + replaceAndRemoveIfOutput(n, i, replacement) { + n.outputs()[i].replaceAllUsesWith(replacement); + n.eraseOutput(i); + n.blocks()[0].eraseOutput(i); + n.blocks()[1].eraseOutput(i); + } removeExtraIfOutputs(n) { torch._C.TORCH_CHECK(n.kind() === 'prim::If'); const [true_block, false_block] = n.blocks(); @@ -5897,12 +5914,12 @@ python.Execution = class { } torch._C.ConstantPooling(graph); }); - this.registerType('torch._C.GraphFunction', class extends torch.jit.Function { + this.registerType('torch._C.GraphFunction', class extends torch._C.Function { constructor(name, graph, function_creator, executor_execution_mode) { super(); this._name = name; this._graph = graph; - this._executor_execution_mode = executor_execution_mode; + this._executor_execution_mode = executor_execution_mode || null; this._function_creator = function_creator; this._force_no_amp = false; } @@ -8684,7 +8701,29 @@ python.Execution = class { this.types = [key, value]; } static create(key, value) { - return new torch.DictType(key, value); + let kind = key.kind(); + if (key instanceof torch._C.DynamicType) { + kind = key.dynamicKind(); + } + switch (kind) { + case 'AnyType': + case 'IntType': + case 'BoolType': + case 'FloatType': + case 'ComplexType': + case 'StringType': + case 'TensorType': + case 'DeviceObjType': + return new torch.DictType(key, value); + default: + throw new python.Error(`Invalid dict key type '${kind}'.`); + } + } + createWithContained(contained_types) { + if (contained_types.length !== 2) { + throw new python.Error('Expected 2 contained types.'); + } + return torch.DictType.create(contained_types[0], contained_types[1]); } getKeyType() { return this.types[0]; @@ -8695,15 +8734,15 @@ python.Execution = class { hasFreeVariables() { return this.getKeyType().hasFreeVariables() || this.getValueType().hasFreeVariables(); } - createWithContained(contained_types) { - if (contained_types.length !== 2) { - throw new python.Error('Expected 2 contained types.'); - } - return torch.DictType.create(contained_types[0], contained_types[1]); - } containedTypes() { return this.types; } + equals(rhs) { + if (rhs instanceof torch.DictType) { + return this.getKeyType().equals(rhs.getKeyType()) && this.getValueType().equals(rhs.getValueType()); + } + return false; + } str() { return `Dict(${this.getKeyType().str()}, ${this.getValueType().str()})`; } @@ -9558,6 +9597,8 @@ python.Execution = class { map.set('bool', torch.BoolType.get()); map.set('complex', torch.ComplexType.get()); map.set('str', torch.StringType.get()); + map.set('Device', torch.DeviceObjType.get()); + map.set('number', torch.NumberType.get()); map.set('None', torch.NoneType.get()); map.set('NoneType', torch.NoneType.get()); map.set('Any', torch.AnyType.get()); @@ -9637,7 +9678,7 @@ python.Execution = class { } } } - return this._resolver._cu.execution.type(expr); + throw new python.Error(`Unknown type name '${name}'.`); } parseBaseTypeName(expr) { if (expr instanceof ast.Name) { @@ -9935,6 +9976,19 @@ python.Execution = class { n.output().setType(output_type); return n; } + createTupleSlice(tup, beg, step_size, num_values) { + const new_vals = []; + const tt = tup.type().expect(torch.TupleType); + let i = beg; + for (let j = 0; j < num_values; j++) { + const idx = this.insertConstant(new torch._C.IValue(i, 'Int')); + const tupleIndex = this.insertNode(this.createTupleIndex(tup, idx, tt.elements()[i])); + new_vals.push(tupleIndex.output()); + i += step_size; + } + const n = this.createTuple(new_vals); + return n; + } createDict(key_type, value_type, keys, values) { if (keys.length !== values.length) { throw new python.Error('Invalid dictionary size.'); @@ -10261,20 +10315,16 @@ python.Execution = class { return this._kind; } schema() { - if (this._op === null) { - this._op = null; - const index = this._kind.indexOf('.'); - const name = index === -1 ? this._kind : this._kind.substring(0, index); - const overload_name = index === -1 ? '' : this._kind.substring(index + 1); - const candidates = torch._C.getAllOperatorsFor(name); - for (const candidate of candidates) { - if (candidate.schema().overload_name === overload_name) { - this._op = candidate; - break; - } - } + if (this._op) { + return this._op.schema(); } - return this._op ? this._op.schema() : null; + // Node::schema() throws while torch.Node.schema() does not. + const op = this.maybeOperator(); + if (op) { + return op.schema(); + } + return null; + // return this.getOperator().schema(); } hasNamedInput(name) { for (const argument of this.schema().arguments) { @@ -10345,7 +10395,7 @@ python.Execution = class { if (maybe) { return maybe; } - throw new python.Error('Operator not found.'); + throw new python.Error(`Schema not found for node '${this.kind()}'.`); } getOperation() { return this.getOperator().getOperation(this); @@ -10894,6 +10944,8 @@ python.Execution = class { this.tag = 'Tensor'; } else if (value instanceof torch.ScriptObject) { this.tag = 'Object'; + } else if (Array.isArray(value)) { + this.tag = 'GenericList'; } else { throw new python.Error('Unsupported type.'); } @@ -11260,6 +11312,7 @@ python.Execution = class { this._cu.define(qualified_classname, [], [], methods, method_resolvers, self, false, this._version); } importNamedTuple(qualified_name, named_tuple_def) { + const type_parser = new torch._C.ScriptTypeParser(this); const field_names = []; const field_types = []; const field_defaults = []; @@ -11267,10 +11320,13 @@ python.Execution = class { if (stmt instanceof ast.AnnAssign === false) { throw new python.Error('Unexpected statement in NamedTuple body.'); } + const assign = stmt; const target = this._cu.execution.identifier(stmt.target); - const annotation = this._cu.execution.type(stmt.annotation); + // const annotation = this._cu.execution.type(stmt.annotation); + const type = type_parser.parseTypeFromExpr(assign.annotation); field_names.push(target); - field_types.push(annotation); + // field_types.push(annotation); + field_types.push(type); } const tt = torch.TupleType.createNamed(qualified_name.qualifiedName(), field_names, field_types, field_defaults); this._cu.register_type(tt); @@ -11475,7 +11531,7 @@ python.Execution = class { for (const known_method of known_type.methods || []) { const schema = new torch.FunctionSchema(known_method); const name = new torch._C.QualifiedName(prefix, schema.name); - const fn = new torch.jit.BuiltinOpFunction(name, schema); + const fn = new torch._C.BuiltinOpFunction(name, schema); type.addMethod(fn); } if (known_type.attributes) { @@ -11831,14 +11887,15 @@ python.Execution = class { let retval = this.findInAnyFrame(ident); if (!retval) { torch._C.Environment.globals = torch._C.Environment.globals || new Map([ + ['print', new torch._C.PrintValue()], ['tuple', torch._C.SpecialFormValue.create('prim::TupleConstruct')], ['float', new torch._C.MagicMethod('__float__', new torch._C.CastValue(torch.FloatType.get(), 'aten::Float'))], ['int', new torch._C.MagicMethod('__int__', new torch._C.CastValue(torch.IntType.get(), 'aten::Int'))], ['bool', new torch._C.MagicMethod('__bool__', new torch._C.CastValue(torch.BoolType.get(), 'aten::Bool'))], ['str', new torch._C.MagicMethod('__str__', new torch._C.CastValue(torch.StringType.get(), 'aten::str'))], - ["getattr", torch._C.SpecialFormValue.create('prim::GetAttr')], - ["hasattr", torch._C.SpecialFormValue.create('prim::HasAttr')], - ["isinstance", torch._C.SpecialFormValue.create('prim::isinstance')], + ['getattr', torch._C.SpecialFormValue.create('prim::GetAttr')], + ['hasattr', torch._C.SpecialFormValue.create('prim::HasAttr')], + ['isinstance', torch._C.SpecialFormValue.create('prim::isinstance')], ['range', torch._C.SpecialFormValue.create('prim::range')], ]); if (torch._C.Environment.globals.has(ident)) { @@ -11997,38 +12054,38 @@ python.Execution = class { /* for (const scalar of ['float', 'int', 'complex']) { const env = new torch.C.TemplateEnv(); - env.s("Scalar", scalar); - this.loadSource(scalar_operators_source.format(env), "aten"); + env.s('Scalar', scalar); + this.loadSource(scalar_operators_source.format(env), 'aten'); } - for (auto scalar : {"float", "int"}) { + for (auto scalar : {'float', 'int'}) { const env = new torch.C.TemplateEnv(); - env.s("Scalar", scalar); - loadSource(scalar_operators_no_complex_source.format(env), "aten"); + env.s('Scalar', scalar); + loadSource(scalar_operators_no_complex_source.format(env), 'aten'); } using str_pair = std::pair; const std::vector name_len = { - str_pair("single", "1"), - str_pair("pair", "2"), - str_pair("triple", "3"), - str_pair("quadruple", "4"), + str_pair('single', '1'), + str_pair('pair', '2'), + str_pair('triple', '3'), + str_pair('quadruple', '4'), }; - for (const auto scalar : {"float", "int"}) { + for (const auto scalar : {'float', 'int'}) { for (const auto& pair : name_len) { const env = new torch.C.TemplateEnv(); - env.s("Scalar", scalar); - env.s("name", pair.first); - env.s("Length", pair.second); - this.loadSource(_ntuple_ops.format(env), "aten"); + env.s('Scalar', scalar); + env.s('name', pair.first); + env.s('Length', pair.second); + this.loadSource(_ntuple_ops.format(env), 'aten'); } } - for (auto rhs : {"number", "Tensor"}) { + for (auto rhs : {'number', 'Tensor'}) { at::jit::TemplateEnv env; - env.s("Rhs_Type", rhs); - this.loadSource(floordiv.format(env), "aten"); + env.s('Rhs_Type', rhs); + this.loadSource(floordiv.format(env), 'aten'); } - this.loadSource(aten_ops, "aten"); - this.loadSource(aten_ops_additional, "aten"); - this.loadSource(tensor_properties, "prim"); + this.loadSource(aten_ops, 'aten'); + this.loadSource(aten_ops_additional, 'aten'); + this.loadSource(tensor_properties, 'prim'); */ } loadSource(/* source, the_namespace */) { @@ -12493,7 +12550,6 @@ python.Execution = class { if (result) { return result; } - args[0].value().type().isSubtypeOf(schema.arguments[0].type); throw new python.Error(`No matching schema '${schema.name}' found.`); }); this.registerFunction('torch._C.matchSchemas', (schemas, loc, graph, args, kwargs, self, render_errors) => { @@ -12793,6 +12849,9 @@ python.Execution = class { } else if (val instanceof torch.ScriptObject) { n.ival_('value', val); type = val.type(); + } else if (Array.isArray(val) && val.every((item) => Number.isInteger(item))) { + n.ival_('value', val); + type = torch.ListType.create(torch.IntType.get()); } else { throw new python.Error(`Unsupported value type '${typeof val}'.`); } @@ -12961,11 +13020,11 @@ python.Execution = class { throw new python.Error('Not implemented.'); } /* else if (this._value.type() instanceof torch.EnumType) { const g = m.graph(); - if (field == "name") { + if (field == 'name') { const n = g.insertNode(g.createEnumName(value_)); return std::make_shared(n->output()); } - if (field == "value") { + if (field == 'value') { const n = g.insertNode(g.createEnumValue(value_)); return std::make_shared(n->output()); } @@ -12987,7 +13046,7 @@ python.Execution = class { return torch._C.SpecialFormValue.create('aten::index'); } if (this._value.type() instanceof torch._C._GeneratorType && (field === 'manual_seed' || field === 'initial_seed' || field === 'seed')) { - const builtin = torch._C.BuiltinFunction.tryCreate(`aten::${field}`, new torch._C.NamedValue(loc, "self", this._value)); + const builtin = torch._C.BuiltinFunction.tryCreate(`aten::${field}`, new torch._C.NamedValue(loc, 'self', this._value)); if (builtin) { return builtin; } @@ -13119,7 +13178,7 @@ python.Execution = class { this.registerType('torch._C.FunctionValue', class extends torch._C.SugaredValue { constructor(...args) { super(); - if (args.length === 1 && args[0] instanceof torch.jit.Function) { + if (args.length === 1 && args[0] instanceof torch._C.Function) { this._callees = [args[0]]; } else { throw new python.Error('Not implemented.'); @@ -13137,6 +13196,19 @@ python.Execution = class { return new torch._C.SimpleValue(output); } }); + this.registerType('torch._C.NoneValue', class extends torch._C.SugaredValue { + }); + this.registerType('torch._C.PrintValue', class extends torch._C.SugaredValue { + call(loc, m, args, kwargs /*, n_binders */) { + const g = m.graph(); + if (kwargs.length > 0) { + throw new python.Error(`Print doesn't accept any keyword arguments at ${loc}.`); + } + const lowered_inputs = torch._C.toValues(m.graph(), args); + g.insertNode(g.create('prim::Print', lowered_inputs, 0).setSourceRange(loc)); + return new torch._C.NoneValue(); + } + }); this.registerType('torch._C.SpecialFormValue', class extends torch._C.SugaredValue { constructor(form) { super(); @@ -13444,7 +13516,7 @@ python.Execution = class { } } return null; - } else if (value instanceof torch.jit.Function) { + } else if (value instanceof torch._C.Function) { const fn = value; if (!fn.isGraphFunction()) { return null; @@ -13459,6 +13531,40 @@ python.Execution = class { this._instance_name = instance_name; } }); + this.registerFunction('torch._C.slice_indices_adjust', (length, start, stop, step) => { + torch._C.TORCH_CHECK(step !== 0); + torch._C.TORCH_CHECK(step >= -Number.MAX_SAFE_INTEGER); // INT64_MAX + if (start._ === Number.MAX_SAFE_INTEGER) { + start._ = (step < 0) ? Number.MAX_SAFE_INTEGER : 0; + } + if (stop._ === Number.MAX_SAFE_INTEGER) { + stop._ = (step < 0) ? Number.MIN_SAFE_INTEGER : Number.MAX_SAFE_INTEGER; + } + if (start._ < 0) { + start._ += length; + if (start._ < 0) { + start._ = (step < 0) ? -1 : 0; + } + } else if (start._ >= length) { + start._ = (step < 0) ? length - 1 : length; + } + if (stop._ < 0) { + stop._ += length; + if (stop._ < 0) { + stop._ = (step < 0) ? -1 : 0; + } + } else if (stop._ >= length) { + stop._ = (step < 0) ? length - 1 : length; + } + if (step < 0) { + if (stop._ < start._) { + return Math.floor((start._ - stop._ - 1) / (-step) + 1); + } + } else if (start._ < stop._) { + return Math.floor((stop._ - start._ - 1) / step + 1); + } + return 0; + }); this.registerFunction('torch._C.createTupleUnpack', (v) => { if (v.node().kind() === 'prim::TupleConstruct') { return v.node().inputs(); @@ -13483,11 +13589,16 @@ python.Execution = class { } */ }); - this.registerFunction('torch._C.inlineCallTo', (to_replace, callee, callee_graph) => { - if (callee_graph === undefined || typeof callee_graph === 'boolean') { - callee_graph = callee_graph === undefined ? true : callee_graph; - callee_graph = callee_graph ? callee.optimized_graph() : callee.graph(); + this.registerFunction('torch._C.inlineCallTo', (to_replace, callee, arg) => { + if (arg === undefined || typeof arg === 'boolean') { + const inline_optimized_graph = arg === undefined ? true : arg; + const graph = inline_optimized_graph ? callee.optimized_graph() : callee.graph(); + return torch._C.inlineCallTo(to_replace, callee, graph); } + if (arg instanceof torch.Graph === false) { + throw new python.Error('Invalid argument.'); + } + const callee_graph = arg; const guard = new torch._C.WithInsertPoint(to_replace); const value_map = new Map(); const new_outputs = torch._C.insertGraph(to_replace.owningGraph(), callee_graph, to_replace.inputs(), value_map); @@ -13555,7 +13666,9 @@ python.Execution = class { } case 'prim::CallMethod': { const graphFunction = torch._C.tryToGraphFunction(cur); - torch._C.inlineCallTo(cur, graphFunction); + if (graphFunction) { + torch._C.inlineCallTo(cur, graphFunction); + } break; } default: { @@ -14016,8 +14129,8 @@ python.Execution = class { throw new python.Error('With item expression must return an object.'); } const rhsClass = rhs.type(); - const enterMethod = rhsClass.findMethod("__enter__"); - const exitMethod = rhsClass.findMethod("__exit__"); + const enterMethod = rhsClass.findMethod('__enter__'); + const exitMethod = rhsClass.findMethod('__exit__'); if (!enterMethod || !exitMethod) { throw new python.Error('Object returned by with item expression does not define __enter__ and __exit__ methods.'); } @@ -14428,7 +14541,7 @@ python.Execution = class { } emitToBool(loc, v) { let out = null; - const bool_cast = this.environment_stack.getSugaredVar("bool", loc); + const bool_cast = this.environment_stack.getSugaredVar('bool', loc); out = torch._C.asSimple(bool_cast.call(loc, this.method, [new torch._C.NamedValue(v)], [], 0)); if (!out) { throw new python.Error('Could not cast value to bool.'); @@ -14469,7 +14582,7 @@ python.Execution = class { } throw new python.Error('Not implemented.'); /* - const tmp_name = this.createTempName("$tmp_assign_"); + const tmp_name = this.createTempName('$tmp_assign_'); this.environment_stack.setSugaredVar(stmt.value, tmp_name, this.emitSugaredExpr(stmt.value, 1), annotated_type=null); const ident = new ast.Name(tmp_name); for (const expr of lhs_list) { @@ -14558,7 +14671,7 @@ python.Execution = class { const var = Starred(assignee).expr(); if (var.kind() != TK_VAR) { throw( - ErrorReport(var) << "Cannot pack a tuple into a non-variable"); + ErrorReport(var) << 'Cannot pack a tuple into a non-variable'); } size_t n_matched = outputs.size() - n_binders; ArrayRef> outputs_ref = outputs; @@ -15010,7 +15123,7 @@ python.Execution = class { adj_index = tuple_len + input_index; } if (!allow_out_of_bounds && (adj_index >= tuple_len || adj_index < 0)) { - throw new python.Error('Tuple index out of range.'); + throw new python.Error(`Tuple index out of range at ${loc}.`); } return adj_index; } @@ -15033,6 +15146,13 @@ python.Execution = class { } return this.graph.insertNode(this.graph.createTupleIndex(tuple_val, idx_val, output_type)).output(); } + getSliceInd(idx_val, loc) { + const ivalue = torch._C.toIValue(idx_val); + if (ivalue && ivalue.isInt()) { + return ivalue.toInt(); + } + throw new python.Error(`Tuple slice indices must be integer constants at '${loc}'.`); + } emitTupleSlice(loc, tuple_val, tuple_args) { const tuple_type = tuple_val.value(this.graph).type().expect(torch.TupleType); const tuple_len = tuple_type.elements().length; @@ -15043,16 +15163,16 @@ python.Execution = class { torch._C.TORCH_CHECK(val.isInt()); step_size = val.toInt(); } - let beg = Number.MAX_SAFE_INTEGER; // std::numeric_limits::max(); + let beg = { _: Number.MAX_SAFE_INTEGER }; // std::numeric_limits::max(); if (beg_val) { - beg = this.getAdjTupleIndex(loc, tuple_type, this.getSliceInd(beg_val.value(this.graph), loc), true); + beg = { _: this.getAdjTupleIndex(loc, tuple_type, this.getSliceInd(beg_val.value(this.graph), loc), true) }; } - let end = Number.MAX_SAFE_INTEGER; // std::numeric_limits::max(); + let end = { _: Number.MAX_SAFE_INTEGER }; // std::numeric_limits::max(); if (end_val) { - end = this.getAdjTupleIndex(loc, tuple_type, this.getSliceInd(end_val.value(this.graph), loc), true); + end = { _: this.getAdjTupleIndex(loc, tuple_type, this.getSliceInd(end_val.value(this.graph), loc), true) }; } const num_values = torch._C.slice_indices_adjust(tuple_len, beg, end, step_size); - return this.graph.insertNode(this.graph.createTupleSlice(tuple_val.value(this.graph), beg, step_size, num_values)).output(); + return this.graph.insertNode(this.graph.createTupleSlice(tuple_val.value(this.graph), beg._, step_size, num_values)).output(); } emitSliceOp(loc, sliceable, dim, start, end, step) { const args = []; @@ -15278,7 +15398,7 @@ python.Execution = class { n.addInput(zero); const new_iter = loop.bodyBlock().addInput().setType(IntType::get()); // unset unique name for jitter, its replacement does not have a name - loop.currentTripCount().setDebugName("").replaceAllUsesWith(new_iter); + loop.currentTripCount().setDebugName('').replaceAllUsesWith(new_iter); const inc_iter = g.insert(aten::add, {new_iter, one}); loop.bodyBlock().registerOutput(inc_iter); const less_than_max_trip = g.insert(aten::lt, {inc_iter, max_trip_count}); @@ -15982,7 +16102,7 @@ python.Execution = class { script_module._initializing = false; } get graph() { - // return this._c._get_method("forward").graph; + // return this._c._get_method('forward').graph; return this._c.graph; } get code_with_constants() { @@ -16841,10 +16961,10 @@ python.Execution = class { } const output_node = this.graph.output(outputs); if (serialized_graph.is_single_tensor_return) { - output_node.meta.set("val", output_node.args[0].meta.get('val')); + output_node.meta.set('val', output_node.args[0].meta.get('val')); } else { - /* output_node.meta["val"] = tuple( - arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + /* output_node.meta['val'] = tuple( + arg.meta['val'] if isinstance(arg, torch.fx.Node) else arg for arg in output_node.args[0] ) */ } @@ -18285,7 +18405,7 @@ python.Execution = class { if (typeof func === 'function') { return func.apply(callTarget, callArguments); } - throw new python.Error("Unsupported call expression."); + throw new python.Error('Unsupported call expression.'); } apply(method, args, context) { @@ -18364,19 +18484,19 @@ python.Execution = class { } else if (stmt instanceof ast.If) { const test = this.expression(stmt.test, context); if (test === true || test) { - const value = this.block(stmt.body.statements, context); + const value = this.block(stmt.body, context); if (value !== undefined) { return value; } } else if (test === false) { if (stmt.orelse) { - const value = this.block(stmt.orelse.statements, context); + const value = this.block(stmt.orelse, context); if (value !== undefined) { return value; } } } else { - throw new python.Error("Unsupported condition."); + throw new python.Error('Unsupported condition.'); } } else if (stmt instanceof ast.For) { if (stmt.target instanceof ast.Name && stmt.iter instanceof ast.Tuple === false) { @@ -18889,7 +19009,7 @@ python.BinaryReader = class { line() { const index = this._buffer.indexOf(0x0A, this._position); if (index === -1) { - throw new python.Error("Could not find end of line."); + throw new python.Error('Could not find end of line.'); } const size = index - this._position; const text = this.string(size, 'ascii'); @@ -18992,7 +19112,7 @@ python.StreamReader = class { position = this._fill(0); index = this._buffer.indexOf(0x0A, position); if (index === -1) { - throw new python.Error("Could not find end of line."); + throw new python.Error('Could not find end of line.'); } } const size = index - position; diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json index a8e08f92aa..830656696e 100755 --- a/source/pytorch-metadata.json +++ b/source/pytorch-metadata.json @@ -232,6 +232,24 @@ { "name": "aten::find(str self, str substr, int start=0, int end=-1) -> int" }, + { + "name": "aten::count(str self, str substr, int start=0, int end=-1) -> int" + }, + { + "name": "aten::count.int(int[] self, int el) -> int" + }, + { + "name": "aten::count.float(float[] self, float el) -> int" + }, + { + "name": "aten::count.bool(bool[] self, bool el) -> int" + }, + { + "name": "aten::count.Tensor(Tensor[] self, Tensor el) -> int" + }, + { + "name": "aten::count.str(str[] self, str el) -> int" + }, { "name": "aten::splitlines(str self, bool keepends=False) -> str[]" }, @@ -415,6 +433,9 @@ { "name": "aten::__contains__.float_list(float[] l, float item) -> bool" }, + { + "name": "aten::lower(str self) -> str" + }, { "name": "prim::type(Device self) -> str" }, @@ -5696,10 +5717,12 @@ "category": "Tensor" }, { - "name": "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]" + "name": "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]", + "category": "Tensor" }, { - "name": "aten::slice.str(str string, int? start=None, int? end=None, int step=1) -> str" + "name": "aten::slice.str(str string, int? start=None, int? end=None, int step=1) -> str", + "category": "Tensor" }, { "name": "aten::reciprocal(Tensor self) -> Tensor" diff --git a/source/pytorch.js b/source/pytorch.js index a196763ca4..cf6215f6bc 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -2173,6 +2173,9 @@ pytorch.Execution = class extends python.Execution { } variables(value, scope) { + if (this.to_ir) { + throw new pytorch.Error('Not implemented.'); + } if (!scope.refs) { scope.refs = new Set(); } @@ -2283,6 +2286,9 @@ pytorch.Execution = class extends python.Execution { } block(statements, context) { + if (!this.trace) { + return super.block(statements, context); + } const ast = this.ast; const torch = this.torch; statements = Array.prototype.slice.call(statements); @@ -2524,11 +2530,12 @@ pytorch.Execution = class extends python.Execution { this._resolver.resolveType(name); } } - if (!this.trace) { return super.statement(stmt, context); } - + if (this.to_ir) { + throw new pytorch.Error('Not implemented.'); + } switch (stmt.__class__.__name__) { case 'ClassDef': { super.statement(stmt, context); @@ -2557,6 +2564,9 @@ pytorch.Execution = class extends python.Execution { } type(expr) { + if (this.to_ir) { + throw new pytorch.Error('Not implemented.'); + } const ast = this.ast; const torch = this.torch; if (expr instanceof ast.Subscript && expr.value instanceof ast.Name) { @@ -2620,6 +2630,9 @@ pytorch.Execution = class extends python.Execution { } constant(constant) { + if (this.to_ir) { + throw new pytorch.Error('Not implemented.'); + } if (!this._constants.has(constant)) { const value = this._graph.insertConstant(constant); this._constants.set(constant, value); @@ -2631,6 +2644,9 @@ pytorch.Execution = class extends python.Execution { if (!this.trace) { return super.call(target, name, args, keywords, context); } + if (this.to_ir) { + throw new pytorch.Error('Not implemented.'); + } const ast = this.ast; const torch = this.torch; if (name === '__new__') { @@ -2710,8 +2726,8 @@ pytorch.Execution = class extends python.Execution { return super.call(target, name, args, keywords, context); } const [schema, evalArgs, evalKeywords] = overload; - const op = schema.overload_name ? `${schema.name}.${schema.overload_name}` : schema.name; - const node = this.create(op, range, 0); + // const op = schema.overload_name ? `${schema.name}.${schema.overload_name}` : schema.name; + const node = this.create(schema.name, range, 0); this._graph.insertNode(node); const referencedParameters = []; const parameters = schema.arguments; @@ -2947,6 +2963,9 @@ pytorch.Execution = class extends python.Execution { } isType(obj, type, N) { + if (this.to_ir) { + throw new pytorch.Error('Not implemented.'); + } const torch = this.torch; const builtins = this.builtins; switch (type.str()) { @@ -3106,6 +3125,9 @@ pytorch.Execution = class extends python.Execution { } getType(value) { // rename + if (this.to_ir) { + throw new pytorch.Error('Not implemented.'); + } const torch = this.torch; if (value === null || value === undefined) { return undefined; @@ -3131,6 +3153,9 @@ pytorch.Execution = class extends python.Execution { } _overload(target, name, args, keywords, context) { + if (this.to_ir) { + throw new pytorch.Error('Not implemented.'); + } const ast = this.ast; const torch = this.torch; const prefix = this.identifier(target); @@ -3408,6 +3433,7 @@ pytorch.Utility = class { case 'NoneType': return 'None'; case 'AnyListType': return 'list'; case 'AnyTupleType': return 'tuple'; + case 'ClassType': return type.annotation_str; default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`); } } @@ -4072,14 +4098,14 @@ pytorch.Metadata = class { } } for (const module of modules) { - // const existing = execution.register(`torch.ops.${module}`); + const existing = execution.register(`ops.${module}`); const namespace = new torch._ops._OpNamespace(module); - /* const created = */ execution.register(`torch.ops.${module}`, namespace); - /* for (const [name, obj] of Object.entries(existing)) { + const created = execution.register(`torch.ops.${module}`, namespace); + for (const [name, obj] of Object.entries(existing)) { if (!name.startsWith('__') && !(name in created)) { created[name] = obj; } - } */ + } } } };