Skip to content

Commit

Permalink
Update pytorch.js (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jan 14, 2025
1 parent 092af4f commit cb5f27c
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 40 deletions.
157 changes: 123 additions & 34 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,7 @@ python.Execution = class {
const args = [];
const keywords = [];
this._tokenizer.expect('(');
let tuple = false;
while (!this._tokenizer.eat(')')) {
if (this._tokenizer.eat('\n')) {
continue;
Expand All @@ -1236,14 +1237,16 @@ python.Execution = class {
} else {
args.push(expr);
}
if (!this._tokenizer.eat(',')) {
if (this._tokenizer.eat(',')) {
tuple = true;
} else {
this._tokenizer.eat('\n');
this._tokenizer.expect(')');
break;
}
}
if (stack.length === 0 && keywords.length === 0) {
if (args.length === 1) {
if (args.length === 1 && !tuple) {
[node] = args;
} else {
node = new ast.Tuple(args);
Expand Down Expand Up @@ -5554,6 +5557,12 @@ python.Execution = class {
this.inlineIfBody(n.blocks()[block_index]);
this._made_change = true;
}
replaceAndRemoveIfOutput(n, i, replacement) {
n.outputs()[i].replaceAllUsesWith(replacement);
n.eraseOutput(i);
n.blocks()[0].eraseOutput(i);
n.blocks()[1].eraseOutput(i);
}
removeExtraIfOutputs(n) {
torch._C.TORCH_CHECK(n.kind() === 'prim::If');
const [true_block, false_block] = n.blocks();
Expand Down Expand Up @@ -8684,7 +8693,29 @@ python.Execution = class {
this.types = [key, value];
}
static create(key, value) {
return new torch.DictType(key, value);
let kind = key.kind();
if (key instanceof torch._C.DynamicType) {
kind = key.dynamicKind();
}
switch (kind) {
case 'AnyType':
case 'IntType':
case 'BoolType':
case 'FloatType':
case 'ComplexType':
case 'StringType':
case 'TensorType':
case 'DeviceObjType':
return new torch.DictType(key, value);
default:
throw new python.Error(`Invalid dict key type '${kind}'.`);
}
}
createWithContained(contained_types) {
if (contained_types.length !== 2) {
throw new python.Error('Expected 2 contained types.');
}
return torch.DictType.create(contained_types[0], contained_types[1]);
}
getKeyType() {
return this.types[0];
Expand All @@ -8695,15 +8726,15 @@ python.Execution = class {
hasFreeVariables() {
return this.getKeyType().hasFreeVariables() || this.getValueType().hasFreeVariables();
}
createWithContained(contained_types) {
if (contained_types.length !== 2) {
throw new python.Error('Expected 2 contained types.');
}
return torch.DictType.create(contained_types[0], contained_types[1]);
}
containedTypes() {
return this.types;
}
equals(rhs) {
if (rhs instanceof torch.DictType) {
return this.getKeyType().equals(rhs.getKeyType()) && this.getValueType().equals(rhs.getValueType());
}
return false;
}
str() {
return `Dict(${this.getKeyType().str()}, ${this.getValueType().str()})`;
}
Expand Down Expand Up @@ -9558,6 +9589,8 @@ python.Execution = class {
map.set('bool', torch.BoolType.get());
map.set('complex', torch.ComplexType.get());
map.set('str', torch.StringType.get());
map.set('Device', torch.DeviceObjType.get());
map.set('number', torch.NumberType.get());
map.set('None', torch.NoneType.get());
map.set('NoneType', torch.NoneType.get());
map.set('Any', torch.AnyType.get());
Expand Down Expand Up @@ -9637,7 +9670,7 @@ python.Execution = class {
}
}
}
return this._resolver._cu.execution.type(expr);
throw new python.Error(`Unknown type name '${name}'.`);
}
parseBaseTypeName(expr) {
if (expr instanceof ast.Name) {
Expand Down Expand Up @@ -9935,6 +9968,19 @@ python.Execution = class {
n.output().setType(output_type);
return n;
}
createTupleSlice(tup, beg, step_size, num_values) {
const new_vals = [];
const tt = tup.type().expect(torch.TupleType);
let i = beg;
for (let j = 0; j < num_values; j++) {
const idx = this.insertConstant(new torch._C.IValue(i, 'Int'));
const tupleIndex = this.insertNode(this.createTupleIndex(tup, idx, tt.elements()[i]));
new_vals.push(tupleIndex.output());
i += step_size;
}
const n = this.createTuple(new_vals);
return n;
}
createDict(key_type, value_type, keys, values) {
if (keys.length !== values.length) {
throw new python.Error('Invalid dictionary size.');
Expand Down Expand Up @@ -10261,20 +10307,16 @@ python.Execution = class {
return this._kind;
}
schema() {
if (this._op === null) {
this._op = null;
const index = this._kind.indexOf('.');
const name = index === -1 ? this._kind : this._kind.substring(0, index);
const overload_name = index === -1 ? '' : this._kind.substring(index + 1);
const candidates = torch._C.getAllOperatorsFor(name);
for (const candidate of candidates) {
if (candidate.schema().overload_name === overload_name) {
this._op = candidate;
break;
}
}
if (this._op) {
return this._op.schema();
}
return this._op ? this._op.schema() : null;
// Node::schema() throws while torch.Node.schema() does not.
const op = this.maybeOperator();
if (op) {
return op.schema();
}
return null;
// return this.getOperator().schema();
}
hasNamedInput(name) {
for (const argument of this.schema().arguments) {
Expand Down Expand Up @@ -10345,7 +10387,7 @@ python.Execution = class {
if (maybe) {
return maybe;
}
throw new python.Error('Operator not found.');
throw new python.Error(`Schema not found for node '${this.kind()}'.`);
}
getOperation() {
return this.getOperator().getOperation(this);
Expand Down Expand Up @@ -11260,17 +11302,21 @@ python.Execution = class {
this._cu.define(qualified_classname, [], [], methods, method_resolvers, self, false, this._version);
}
importNamedTuple(qualified_name, named_tuple_def) {
const type_parser = new torch._C.ScriptTypeParser(this);
const field_names = [];
const field_types = [];
const field_defaults = [];
for (const stmt of named_tuple_def.body) {
if (stmt instanceof ast.AnnAssign === false) {
throw new python.Error('Unexpected statement in NamedTuple body.');
}
const assign = stmt;
const target = this._cu.execution.identifier(stmt.target);
const annotation = this._cu.execution.type(stmt.annotation);
// const annotation = this._cu.execution.type(stmt.annotation);
const type = type_parser.parseTypeFromExpr(assign.annotation);
field_names.push(target);
field_types.push(annotation);
// field_types.push(annotation);
field_types.push(type);
}
const tt = torch.TupleType.createNamed(qualified_name.qualifiedName(), field_names, field_types, field_defaults);
this._cu.register_type(tt);
Expand Down Expand Up @@ -12493,7 +12539,6 @@ python.Execution = class {
if (result) {
return result;
}
args[0].value().type().isSubtypeOf(schema.arguments[0].type);
throw new python.Error(`No matching schema '${schema.name}' found.`);
});
this.registerFunction('torch._C.matchSchemas', (schemas, loc, graph, args, kwargs, self, render_errors) => {
Expand Down Expand Up @@ -12793,6 +12838,9 @@ python.Execution = class {
} else if (val instanceof torch.ScriptObject) {
n.ival_('value', val);
type = val.type();
} else if (Array.isArray(val) && val.every((item) => Number.isInteger(item))) {
n.ival_('value', val);
type = torch.ListType.create(torch.IntType.get());
} else {
throw new python.Error(`Unsupported value type '${typeof val}'.`);
}
Expand Down Expand Up @@ -13459,6 +13507,40 @@ python.Execution = class {
this._instance_name = instance_name;
}
});
this.registerFunction('torch._C.slice_indices_adjust', (length, start, stop, step) => {
torch._C.TORCH_CHECK(step !== 0);
torch._C.TORCH_CHECK(step >= -Number.MAX_SAFE_INTEGER); // INT64_MAX
if (start._ === Number.MAX_SAFE_INTEGER) {
start._ = (step < 0) ? Number.MAX_SAFE_INTEGER : 0;
}
if (stop._ === Number.MAX_SAFE_INTEGER) {
stop._ = (step < 0) ? Number.MIN_SAFE_INTEGER : Number.MAX_SAFE_INTEGER;
}
if (start._ < 0) {
start._ += length;
if (start._ < 0) {
start._ = (step < 0) ? -1 : 0;
}
} else if (start._ >= length) {
start._ = (step < 0) ? length - 1 : length;
}
if (stop._ < 0) {
stop._ += length;
if (stop._ < 0) {
stop._ = (step < 0) ? -1 : 0;
}
} else if (stop._ >= length) {
stop._ = (step < 0) ? length - 1 : length;
}
if (step < 0) {
if (stop._ < start._) {
return Math.floor((start._ - stop._ - 1) / (-step) + 1);
}
} else if (start._ < stop._) {
return Math.floor((stop._ - start._ - 1) / step + 1);
}
return 0;
});
this.registerFunction('torch._C.createTupleUnpack', (v) => {
if (v.node().kind() === 'prim::TupleConstruct') {
return v.node().inputs();
Expand Down Expand Up @@ -15033,6 +15115,13 @@ python.Execution = class {
}
return this.graph.insertNode(this.graph.createTupleIndex(tuple_val, idx_val, output_type)).output();
}
getSliceInd(idx_val, loc) {
const ivalue = torch._C.toIValue(idx_val);
if (ivalue && ivalue.isInt()) {
return ivalue.toInt();
}
throw new python.Error(`Tuple slice indices must be integer constants at '${loc}'.`);
}
emitTupleSlice(loc, tuple_val, tuple_args) {
const tuple_type = tuple_val.value(this.graph).type().expect(torch.TupleType);
const tuple_len = tuple_type.elements().length;
Expand All @@ -15043,16 +15132,16 @@ python.Execution = class {
torch._C.TORCH_CHECK(val.isInt());
step_size = val.toInt();
}
let beg = Number.MAX_SAFE_INTEGER; // std::numeric_limits<int64_t>::max();
let beg = { _: Number.MAX_SAFE_INTEGER }; // std::numeric_limits<int64_t>::max();
if (beg_val) {
beg = this.getAdjTupleIndex(loc, tuple_type, this.getSliceInd(beg_val.value(this.graph), loc), true);
beg = { _: this.getAdjTupleIndex(loc, tuple_type, this.getSliceInd(beg_val.value(this.graph), loc), true) };
}
let end = Number.MAX_SAFE_INTEGER; // std::numeric_limits<int64_t>::max();
let end = { _: Number.MAX_SAFE_INTEGER }; // std::numeric_limits<int64_t>::max();
if (end_val) {
end = this.getAdjTupleIndex(loc, tuple_type, this.getSliceInd(end_val.value(this.graph), loc), true);
end = { _: this.getAdjTupleIndex(loc, tuple_type, this.getSliceInd(end_val.value(this.graph), loc), true) };
}
const num_values = torch._C.slice_indices_adjust(tuple_len, beg, end, step_size);
return this.graph.insertNode(this.graph.createTupleSlice(tuple_val.value(this.graph), beg, step_size, num_values)).output();
return this.graph.insertNode(this.graph.createTupleSlice(tuple_val.value(this.graph), beg._, step_size, num_values)).output();
}
emitSliceOp(loc, sliceable, dim, start, end, step) {
const args = [];
Expand Down Expand Up @@ -18364,13 +18453,13 @@ python.Execution = class {
} else if (stmt instanceof ast.If) {
const test = this.expression(stmt.test, context);
if (test === true || test) {
const value = this.block(stmt.body.statements, context);
const value = this.block(stmt.body, context);
if (value !== undefined) {
return value;
}
} else if (test === false) {
if (stmt.orelse) {
const value = this.block(stmt.orelse.statements, context);
const value = this.block(stmt.orelse, context);
if (value !== undefined) {
return value;
}
Expand Down
27 changes: 25 additions & 2 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,24 @@
{
"name": "aten::find(str self, str substr, int start=0, int end=-1) -> int"
},
{
"name": "aten::count(str self, str substr, int start=0, int end=-1) -> int"
},
{
"name": "aten::count.int(int[] self, int el) -> int"
},
{
"name": "aten::count.float(float[] self, float el) -> int"
},
{
"name": "aten::count.bool(bool[] self, bool el) -> int"
},
{
"name": "aten::count.Tensor(Tensor[] self, Tensor el) -> int"
},
{
"name": "aten::count.str(str[] self, str el) -> int"
},
{
"name": "aten::splitlines(str self, bool keepends=False) -> str[]"
},
Expand Down Expand Up @@ -415,6 +433,9 @@
{
"name": "aten::__contains__.float_list(float[] l, float item) -> bool"
},
{
"name": "aten::lower(str self) -> str"
},
{
"name": "prim::type(Device self) -> str"
},
Expand Down Expand Up @@ -5696,10 +5717,12 @@
"category": "Tensor"
},
{
"name": "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]"
"name": "aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> t[]",
"category": "Tensor"
},
{
"name": "aten::slice.str(str string, int? start=None, int? end=None, int step=1) -> str"
"name": "aten::slice.str(str string, int? start=None, int? end=None, int step=1) -> str",
"category": "Tensor"
},
{
"name": "aten::reciprocal(Tensor self) -> Tensor"
Expand Down
Loading

0 comments on commit cb5f27c

Please sign in to comment.