From 4570fba2f1ec768d5fd88e145354e3f114d6ca62 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 16 Oct 2022 13:35:16 -0700 Subject: [PATCH] Add PyTorch test file (#720) --- source/python.js | 34 ++++++++++++++++------------------ source/pytorch.js | 12 ++++++------ test/models.json | 8 ++++++++ 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/source/python.js b/source/python.js index e2eda19674..f9e5d53e8a 100644 --- a/source/python.js +++ b/source/python.js @@ -4275,15 +4275,18 @@ python.Execution = class { return tensor; }); this.registerFunction('torch._utils._rebuild_tensor', function (storage, storage_offset, size, stride) { + if (Array.isArray(storage) && storage.length === 5 && storage[0] === 'storage') { + const storage_type = storage[1]; + const size = storage[4]; + storage = new storage_type(size); + } const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor'); const tensor = self.invoke(name, []); tensor.__setstate__([ storage, storage_offset, size, stride ]); return tensor; }); this.registerFunction('torch._utils._rebuild_tensor_v2', function (storage, storage_offset, size, stride, requires_grad, backward_hooks) { - const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor'); - const tensor = self.invoke(name, []); - tensor.__setstate__([ storage, storage_offset, size, stride ]); + const tensor = execution.invoke('torch._utils._rebuild_tensor', [ storage, storage_offset, size, stride ]); tensor.requires_grad = requires_grad; tensor.backward_hooks = backward_hooks; return tensor; @@ -4294,12 +4297,8 @@ python.Execution = class { return obj; }); this.registerFunction('torch._utils._rebuild_qtensor', function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) { - const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor'); - const tensor = self.invoke(name, []); - tensor.__setstate__([ storage, storage_offset, size, stride ]); + const tensor = execution.invoke('torch._utils._rebuild_tensor_v2', [ storage, storage_offset, size, stride, requires_grad, backward_hooks ]); tensor.quantizer_params = quantizer_params; - tensor.requires_grad = requires_grad; - tensor.backward_hooks = backward_hooks; return tensor; }); this.registerFunction('torch._set_item', function(dict, key, value) { @@ -4622,21 +4621,20 @@ python.Execution = class { const module_source_map = new Map(); const deserialized_objects = new Map(); unpickler.persistent_load = (saved_id) => { - const typename = saved_id.shift(); - const data = saved_id; + const typename = saved_id[0]; switch (typename) { case 'module': { - const module = data[0]; - const source = data[2]; + const module = saved_id[1]; + const source = saved_id[3]; module_source_map.set(module, source); - return data[0]; + return saved_id[1]; } case 'storage': { - const storage_type = data.shift(); - const root_key = data.shift(); - data.shift(); // location - const size = data.shift(); - const view_metadata = data.shift(); + const storage_type = saved_id[1]; + const root_key = saved_id[2]; + /// const location = saved_id[3]; + const size = saved_id[4]; + const view_metadata = saved_id[5]; if (!deserialized_objects.has(root_key)) { const obj = new storage_type(size); deserialized_objects.set(root_key, obj); diff --git a/source/pytorch.js b/source/pytorch.js index 49654a982e..a54bd58684 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -1229,20 +1229,20 @@ pytorch.Container.Zip.Script = class { const execution = this.execution; const unpickler = execution.invoke('pickle.Unpickler', [ data ]); unpickler.persistent_load = (saved_id) => { - const typename = saved_id.shift(); + const typename = saved_id[0]; switch (typename) { case 'storage': { - const storage_type = saved_id.shift(); - const root_key = saved_id.shift(); - /* const location = */ saved_id.shift(); - const size = saved_id.shift(); + const storage_type = saved_id[1]; + const root_key = saved_id[2]; + // const location = saved_id[3]; + const size = saved_id[4]; if (!loaded_storages.has(root_key)) { const storage = new storage_type(size); storage._set_cdata(storage_map.get(root_key)); loaded_storages.set(root_key, storage); } const storage = loaded_storages.get(root_key); - const view_metadata = saved_id.shift(); + const view_metadata = saved_id[5]; if (view_metadata) { const view_key = view_metadata.shift(); view_metadata.shift(); // view_offset diff --git a/test/models.json b/test/models.json index 63c5c25fd6..7b9dae82c1 100644 --- a/test/models.json +++ b/test/models.json @@ -4160,6 +4160,14 @@ "format": "PyTorch v1.6", "link": "https://github.com/lutzroeder/netron/issues/720" }, + { + "type": "pytorch", + "target": "data.pkl", + "source": "https://github.com/lutzroeder/netron/files/9795497/data.pkl.zip[data.pkl]", + "format": "Pickle", + "error": "Unknown type name '__torch__.torchvision.models.alexnet.___torch_mangle_30.AlexNet' in 'data.pkl'.", + "link": "https://github.com/lutzroeder/netron/issues/720" + }, { "type": "pytorch", "target": "d2go.pt",