Skip to content

Commit

Permalink
Add PyTorch test file (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 16, 2022
1 parent 25c351a commit 4570fba
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
34 changes: 16 additions & 18 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4275,15 +4275,18 @@ python.Execution = class {
return tensor;
});
this.registerFunction('torch._utils._rebuild_tensor', function (storage, storage_offset, size, stride) {
if (Array.isArray(storage) && storage.length === 5 && storage[0] === 'storage') {
const storage_type = storage[1];
const size = storage[4];
storage = new storage_type(size);
}
const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor');
const tensor = self.invoke(name, []);
tensor.__setstate__([ storage, storage_offset, size, stride ]);
return tensor;
});
this.registerFunction('torch._utils._rebuild_tensor_v2', function (storage, storage_offset, size, stride, requires_grad, backward_hooks) {
const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor');
const tensor = self.invoke(name, []);
tensor.__setstate__([ storage, storage_offset, size, stride ]);
const tensor = execution.invoke('torch._utils._rebuild_tensor', [ storage, storage_offset, size, stride ]);
tensor.requires_grad = requires_grad;
tensor.backward_hooks = backward_hooks;
return tensor;
Expand All @@ -4294,12 +4297,8 @@ python.Execution = class {
return obj;
});
this.registerFunction('torch._utils._rebuild_qtensor', function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor');
const tensor = self.invoke(name, []);
tensor.__setstate__([ storage, storage_offset, size, stride ]);
const tensor = execution.invoke('torch._utils._rebuild_tensor_v2', [ storage, storage_offset, size, stride, requires_grad, backward_hooks ]);
tensor.quantizer_params = quantizer_params;
tensor.requires_grad = requires_grad;
tensor.backward_hooks = backward_hooks;
return tensor;
});
this.registerFunction('torch._set_item', function(dict, key, value) {
Expand Down Expand Up @@ -4622,21 +4621,20 @@ python.Execution = class {
const module_source_map = new Map();
const deserialized_objects = new Map();
unpickler.persistent_load = (saved_id) => {
const typename = saved_id.shift();
const data = saved_id;
const typename = saved_id[0];
switch (typename) {
case 'module': {
const module = data[0];
const source = data[2];
const module = saved_id[1];
const source = saved_id[3];
module_source_map.set(module, source);
return data[0];
return saved_id[1];
}
case 'storage': {
const storage_type = data.shift();
const root_key = data.shift();
data.shift(); // location
const size = data.shift();
const view_metadata = data.shift();
const storage_type = saved_id[1];
const root_key = saved_id[2];
/// const location = saved_id[3];
const size = saved_id[4];
const view_metadata = saved_id[5];
if (!deserialized_objects.has(root_key)) {
const obj = new storage_type(size);
deserialized_objects.set(root_key, obj);
Expand Down
12 changes: 6 additions & 6 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -1229,20 +1229,20 @@ pytorch.Container.Zip.Script = class {
const execution = this.execution;
const unpickler = execution.invoke('pickle.Unpickler', [ data ]);
unpickler.persistent_load = (saved_id) => {
const typename = saved_id.shift();
const typename = saved_id[0];
switch (typename) {
case 'storage': {
const storage_type = saved_id.shift();
const root_key = saved_id.shift();
/* const location = */ saved_id.shift();
const size = saved_id.shift();
const storage_type = saved_id[1];
const root_key = saved_id[2];
// const location = saved_id[3];
const size = saved_id[4];
if (!loaded_storages.has(root_key)) {
const storage = new storage_type(size);
storage._set_cdata(storage_map.get(root_key));
loaded_storages.set(root_key, storage);
}
const storage = loaded_storages.get(root_key);
const view_metadata = saved_id.shift();
const view_metadata = saved_id[5];
if (view_metadata) {
const view_key = view_metadata.shift();
view_metadata.shift(); // view_offset
Expand Down
8 changes: 8 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4160,6 +4160,14 @@
"format": "PyTorch v1.6",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "data.pkl",
"source": "https://github.com/lutzroeder/netron/files/9795497/data.pkl.zip[data.pkl]",
"format": "Pickle",
"error": "Unknown type name '__torch__.torchvision.models.alexnet.___torch_mangle_30.AlexNet' in 'data.pkl'.",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "d2go.pt",
Expand Down

0 comments on commit 4570fba

Please sign in to comment.