From 0e081f1eef97e05f652296fbc299e52872ba10a5 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 26 Jan 2025 10:56:30 -0800 Subject: [PATCH] Update pytorch.js (#842) --- source/python.js | 352 ++++++++++++++++++++++++++++++++++++++++++++-- source/pytorch.js | 8 +- source/view.js | 2 +- 3 files changed, 346 insertions(+), 16 deletions(-) diff --git a/source/python.js b/source/python.js index ae59fe1a66..cc7a64e3ac 100644 --- a/source/python.js +++ b/source/python.js @@ -1352,6 +1352,7 @@ python.Execution = class { throw new python.Error(`Invalid literal ${this._location()}`); } const node = new ast.Constant(value, literal.type); + this._mark(node, position); stack.push(node); } continue; @@ -6311,8 +6312,206 @@ python.Execution = class { throw new python.Error('Not implemented.'); } }); - this.registerFunction('torch._C.handleBlock', () =>{ - // + this.registerFunction('torch._C.handleBlock', (/* block, initial_state */) =>{ + /* + const autocast_stack = []; + let incompatible_amp = null; + const current_state = () => autocast_stack.length === 0 ? initial_state : autocast_stack.top().context; + for (const node of block.nodes()) { + switch (node.kind()) { + case 'prim::CallFunction': + if (current_state() === initial_state) { + if (current_state()) { + torch._C.castTensorInputs(node, 'aten::_autocast_to_full_precision', current_state()); + } + break; + } + torch._C.TORCH_INTERNAL_ASSERT(!incompatible_amp.has_value() || incompatible_amp.value(), "Calls are not expected with AMP & JIT"); + incompatible_amp = true; + break; + case 'prim::CallMethod': + if (current_state() === initial_state) { + if (current_state()) { + torch._C.castTensorInputs(node, 'aten::_autocast_to_full_precision', current_state()); + } + break; + } + if (node.input(0).type() instanceof torch.ClassType) { + const class_type = node.input(0).type(); + const name = node.s('name'); + const fn = class_type.getMethod(name); + if (!fn.isGraphFunction()) { + torch._C.TORCH_INTERNAL_ASSERT(!incompatible_amp.has_value() || incompatible_amp.value()); + incompatible_amp = true; + } + } else { + torch._C.TORCH_INTERNAL_ASSERT(!incompatible_amp.has_value() || incompatible_amp.value()); + incompatible_amp = true; + } + break; + case 'prim::Enter': { + const autocast_scope = torch._C.parseAutocast(node.input(), current_state()); + if (autocast_scope) { + if (node.hasUses()) { + torch._C.TORCH_CHECK(false, "`with autocast() as ...` is not supported"); + } + torch._C.TORCH_INTERNAL_ASSERT(!incompatible_amp.has_value() || !incompatible_amp.value()); + incompatible_amp = false; + autocast_stack.push(autocast_scope); + } + break; + } + case 'prim::Exit': { + if (torch._C.isAutocastNode(node.input(0))) { + torch._C.TORCH_INTERNAL_ASSERT(!autocast_stack.empty()); + torch._C.TORCH_INTERNAL_ASSERT(autocast_stack.top().instance === node.input()); + torch._C.TORCH_INTERNAL_ASSERT(!incompatible_amp.has_value() || !incompatible_amp.value()); + incompatible_amp = false; + autocast_stack.pop(); + } + break; + } + case 'aten::is_autocast_enabled': { + torch._C.updateAutocastEnabledCheck(node, current_state().gpu_enabled); + break; + } + case 'aten::is_autocast_cpu_enabled': { + torch._C.updateAutocastEnabledCheck(node, current_state().cpu_enabled); + break; + } + case 'aten::_convolution': + case 'aten::conv1d': + case 'aten::conv2d': + case 'aten::conv3d': + case 'aten::conv_tbc': + case 'aten::conv_transpose1d': + case 'aten::convolution': + case 'aten::cudnn_convolution': + case 'aten::cudnn_convolution_transpose': + case 'aten::prelu': + case 'aten::addmm': + case 'aten::addmv': + case 'aten::addr': + case 'aten::matmul': + case 'aten::mm': + case 'aten::mv': + case 'aten::linear': + case 'aten::addbmm': + case 'aten::baddbmm': + case 'aten::bmm': + case 'aten::chain_matmul': + case 'aten::_thnn_fused_lstm_cell': + case 'aten::_thnn_fused_gru_cell': + case 'aten::lstm_cell': + case 'aten::gru_cell': + case 'aten::rnn_tanh_cell': + case 'aten::rnn_relu_cell': { + if (!node.schema().is_mutable()) { + torch._C.castTensorInputs(node, 'aten::_autocast_to_reduced_precision', current_state()); + } + break; + } + case 'aten::native_layer_norm': + case 'aten::acos': + case 'aten::asin': + case 'aten::cosh': + case 'aten::erfinv': + case 'aten::exp': + case 'aten::expm1': + case 'aten::log': + case 'aten::log10': + case 'aten::log2': + case 'aten::log1p': + case 'aten::reciprocal': + case 'aten::rsqrt': + case 'aten::sinh': + case 'aten::tan': + case 'aten::pow': + case 'aten::softplus': + case 'aten::gelu': + case 'aten::layer_norm': + case 'aten::group_norm': + case 'aten::frobenius_norm': + case 'aten::nuclear_norm': + case 'aten::cosine_similarity': + case 'aten::cosine_embedding_loss': + case 'aten::nll_loss': + case 'aten::nll_loss2d': + case 'aten::hinge_embedding_loss': + case 'aten::kl_div': + case 'aten::l1_loss': + case 'aten::smooth_l1_loss': + case 'aten::mse_loss': + case 'aten::margin_ranking_loss': + case 'aten::multilabel_margin_loss': + case 'aten::soft_margin_loss': + case 'aten::triplet_margin_loss': + case 'aten::multi_margin_loss': + case 'aten::binary_cross_entropy_with_logits': + case 'aten::dist': + case 'aten::pdist': + case 'aten::cdist': + case 'aten::renorm': + case 'aten::logsumexp': { + if (!node.schema().is_mutable()) { + torch._C.castTensorInputs(node, 'aten::_autocast_to_full_precision', current_state()); + } + break; + } + case 'aten::prod': + case 'aten::log_softmax': + case 'aten::cumprod': + case 'aten::cumsum': + case 'aten::sum': { + if (!node.schema().is_mutable() && !torch._C.hasExplicitDtypeArgument(node)) { + torch._C.castTensorInputs(node, 'aten::_autocast_to_full_precision', current_state()); + } + break; + } + case 'aten::softmax': { + if (!node.schema().is_mutable() && !torch._C.hasExplicitDtypeArgument(node)) { + const context = current_state(); + context.cpu_enabled = false; + torch._C.castTensorInputs(node, 'aten::_autocast_to_full_precision', context); + } + break; + } + case 'aten::addcdiv': + case 'aten::addcmul': + case 'aten::atan2': + case 'aten::bilinear': + case 'aten::cat': + case 'aten::cross': + case 'aten::dot': + case 'aten::equal': + case 'aten::index_put': + case 'aten::stack': + case 'aten::tensordot': + case 'aten::add': + case 'aten::sub': + case 'aten::mul': + case 'aten::div': { + if (!node.schema().is_mutable()) { + torch._C.castInputsToWidestType(node, current_state()); + } + break; + } + case 'aten::binary_cross_entropy': { + if (current_state()) { + torch._C.TORCH_CHECK(false, "Unsafe to autocast"); + } + break; + } + default: { + break; + } + } + for (const sub_block of node.blocks()) { + torch._C.handleBlock(sub_block, current_state()); + } + } + torch._C.TORCH_INTERNAL_ASSERT(autocast_stack.length === 0); + */ }); this.registerFunction('torch._C.autocastEnabled', () => { return true; @@ -11167,7 +11366,10 @@ python.Execution = class { return this; } sourceRange() { - return this._source_range || new torch._C.SourceRange(); + if (this._source_range) { + return this._source_range; + } + return new torch._C.SourceRange(); } print_attributes(out, ignore_subgraph) { ignore_subgraph = ignore_subgraph || false; @@ -11599,6 +11801,8 @@ python.Execution = class { return expr.body[0].value; } }); + this.registerType('torch._C.StringCordView', class { + }); this.registerType('torch._C.Source', class { constructor(text_view, filename, starting_line_no, gen_ranges /*, copies_str */) { if (text_view instanceof Uint8Array) { @@ -11612,13 +11816,38 @@ python.Execution = class { this._filename = filename; this._starting_line_no = starting_line_no; this._gen_ranges = gen_ranges; + this.calc_line_start_offsets(); } text_str() { return this._text_view; } + size() { + return this._text_view.length; + } filename() { return this._filename; } + calc_line_start_offsets() { + let pos = 0; + this._line_starting_offsets = [0]; + while ((pos = this._text_view.indexOf('\n', pos)) !== -1) { + this._line_starting_offsets.push(pos); + pos += 1; + } + } + offset_for_line(line) { + return this._line_starting_offsets[line]; + } + lineno_for_offset(offset) { + const iter = this._line_starting_offsets.findIndex((value) => value > offset); + return (iter === -1 ? this._line_starting_offsets.length : iter) - 1; + } + lineno_to_source_lineno(lineno) { + if (this._filename) { + return lineno + this._starting_line_no; + } + return lineno; + } findSourceRangeThatGenerated(range) { if (!this._gen_ranges) { return null; @@ -11627,18 +11856,37 @@ python.Execution = class { } }); this.registerType('torch._C.SourceRange', class { - constructor(node) { - this._node = node; + constructor(...args) { + if (args.length === 0) { + this._source_view = null; + } else if (args.length === 2) { + let node = null; + [this._source_view, node] = args; + this._start = this._source_view.offset_for_line(node.lineno - 1) + node.col_offset; + this._end = this._source_view.offset_for_line(node.end_lineno - 1) + node.end_col_offset; + } else if (args.length === 3) { + [this._source_view, this._start, this._end] = args; + } else { + throw new python.Error('Not implemented.'); + } } source() { - return null; + return this._source_view; } file_line_col() { - return null; + if (!this._source_view || this.source().filename() === null) { + return null; + } + const lineno = this._source_view.lineno_for_offset(this._start); + const col_offset = this._start - this._source_view.offset_for_line(lineno); + return [this._source_view.filename(), this._source_view.lineno_to_source_lineno(lineno), col_offset]; + } + start() { + return this._start; } toString() { - const n = this._node; - return n ? `${n.filename}:${n.lineno}:${n.col_offset}` : ''; + const loc = this.file_line_col(); + return loc ? `${loc[0]}:${loc[1]}:${loc[2]}` : ''; } }); this.registerType('torch._C.QualifiedName', class { @@ -11700,8 +11948,17 @@ python.Execution = class { this._source_loader = source_loader; this._version = version; this._loaded_sources = new Set(); + this._sources = new Map(); + const sources = this._sources; ast.AST.prototype.range = function() { - this._range = this._range || new torch._C.SourceRange(this); + if (!this._range) { + if (sources.has(this.filename)) { + const source_view = sources.get(this.filename); + this._range = new torch._C.SourceRange(source_view, this); + } else { + this._range = new torch._C.SourceRange(); + } + } return this._range; }; this._to_be_defined = new Map(); @@ -11914,6 +12171,7 @@ python.Execution = class { if (!src) { return; } + this._sources.set(src.filename(), src); const p = new torch._C.Parser(src); const L = p.parse(); // const p = this._cu.execution.parse(src.filename(), src.text_str(), null); @@ -12009,6 +12267,45 @@ python.Execution = class { return this._otherResolver.resolveType(name, loc); } }); + this.registerType('torch._C.SourceRangeDeserializer', class { + constructor(text_table) { + this.cached_sources = new Map(); + this._text_table = text_table || []; + } + deserialize(iv) { + torch._C.TORCH_INTERNAL_ASSERT(iv.length === 3); + const [file, start, end] = iv; + const source = this.deserialize_source(file); + return new torch._C.SourceRange(source, start, end); + } + deserialize_source(iv) { + const tup = iv; + if (this.cached_sources.has(tup)) { + return this.cached_sources.get(tup); + } + let source = null; + const tup_elems = tup; + torch._C.TORCH_INTERNAL_ASSERT(tup_elems.length === 3); + if (this._text_table.length > 0) { + const [textIndex, fnameIndex, starting_line_no] = tup_elems; + torch._C.TORCH_CHECK(fnameIndex < this._text_table.length); + const filename = this._text_table[fnameIndex]; + const pieces = []; + const strs = []; + for (const i of textIndex) { + pieces.push(this._text_table[i]); + strs.push(this._text_table[i]); + } + // const str_cord = new torch._C.StringCordView(pieces, strs); + source = new torch._C.Source(pieces.join(''), filename, starting_line_no); + } else { + const [text, filename, starting_line_no] = tup_elems; + source = new torch._C.Source(text, filename, starting_line_no); + } + this.cached_sources.set(tup, source); + return source; + } + }); this.registerType('torch._C.SourceRangeUnpickler', class { }); this.registerType('torch._C.ConcreteSourceRangeUnpickler', class extends torch._C.SourceRangeUnpickler { @@ -12019,15 +12316,42 @@ python.Execution = class { this.unpickled_records = null; } unpickle() { - if (this.unpickled_records !== null) { + if (this.unpickled_records) { return; } - /* const unpickler = new pickle.Unpickler(this.data); - const ivalues = unpickler.load(); */ + const unpickler = new pickle.Unpickler(this.data); + const ivalues = unpickler.load(); + torch._C.TORCH_CHECK(ivalues.length > 0); this.unpickled_records = []; + let lines = null; + if (ivalues[0] === 'FORMAT_WITH_STRING_TABLE') { + this.deserializer = new torch._C.SourceRangeDeserializer(ivalues[1]); + /* eslint-disable prefer-destructuring */ + lines = ivalues[2]; + /* eslint-enable prefer-destructuring */ + } else { + this.deserializer = new torch._C.SourceRangeDeserializer(); + lines = ivalues; + } + for (const tup_elems of lines) { + const [offset, range] = tup_elems; + const source_range = this.deserializer.deserialize(range); + this.unpickled_records.push([offset, source_range]); + } } - findSourceRangeThatGenerated(/* range */) { + findSourceRangeThatGenerated(range) { this.unpickle(); + const start = range.start(); + const records = this.unpickled_records; + const size = range.source().size(); + for (let i = 0; i < records.length; i++) { + const [offset, range] = this.unpickled_records[i]; + const next = i < records.length - 1 ? records[i + 1][0] : size; + if (start >= offset && start < next) { + return range; + } + } + return null; } }); this.registerFunction('torch._C.qualifierToArchivePath', (qualifier, export_prefix) => { diff --git a/source/pytorch.js b/source/pytorch.js index d9e3500d82..4d5d1c6ba3 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -555,7 +555,13 @@ pytorch.Node = class { } const sourceRange = node.sourceRange(); if (sourceRange) { - this.metadata.push(new pytorch.Argument('source', sourceRange.toString().replace(/^at\s/, '').replace(/\.$/, ''))); + this.metadata.push(new pytorch.Argument('source', sourceRange.toString().replace(/^at\s/, '').replace(/\.$/, ''), 'attribute')); + if (sourceRange.source()) { + const orig = sourceRange.source().findSourceRangeThatGenerated(sourceRange); + if (orig) { + this.metadata.push(new pytorch.Argument('generated', orig.toString(), 'attribute')); + } + } } } else if (torch && obj instanceof torch.fx.node.Node) { if (obj.op === 'call_function') { diff --git a/source/view.js b/source/view.js index 83c1d80237..a2c801f7ef 100644 --- a/source/view.js +++ b/source/view.js @@ -2855,7 +2855,7 @@ view.TextView = class extends view.Control { super(context); this.element = this.createElement('div', 'sidebar-item-value'); let className = 'sidebar-item-value-line'; - if (value) { + if (value !== null && value !== undefined) { const list = Array.isArray(value) ? value : [value]; for (const item of list) { const line = this.createElement('div', className);