Skip to content

Commit

Permalink
Add PyTorch test file (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 21, 2023
1 parent 141722a commit 3c6ff61
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 55 deletions.
2 changes: 1 addition & 1 deletion source/pickle.js
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pickle.Node = class {
this.attributes.push(attribute);
} else {
stack = stack || new Set();
if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj.__class__ && obj.__class__.__module__ === value[0].__class__.__module__ && obj.__class__.__name__ === value[0].__class__.__name__)) {
if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ === value[0].__class__.__module__ && obj.__class__.__name__ === value[0].__class__.__name__)) {
const values = value.filter((value) => !stack.has(value));
const nodes = values.map((value) => {
stack.add(value);
Expand Down
49 changes: 25 additions & 24 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,7 @@ python.Execution = class {
this.register('xgboost');
this.registerType('builtins.dict', dict);
this.registerType('builtins.ellipsis', class {});
this.registerType('builtins.list', class {});
this.registerType('builtins.list', class extends Array {});
this.registerType('builtins.number', class {});
this.registerFunction('builtins.__import__', function(name, globals, locals, fromlist, level) {
return execution.__import__(name, globals, locals, fromlist, level);
Expand Down Expand Up @@ -2901,7 +2901,7 @@ python.Execution = class {
break;
}
case 93: // EMPTY_LIST ']'
stack.push([]);
stack.push(execution.invoke('builtins.list', []));
break;
case 41: // EMPTY_TUPLE ')'
stack.push([]);
Expand Down Expand Up @@ -4896,8 +4896,7 @@ python.Execution = class {
const module_source_map = new Map();
const deserialized_objects = new Map();
unpickler.persistent_load = (saved_id) => {
const typename = saved_id[0];
switch (typename) {
switch (saved_id[0]) {
case 'module': {
const module = saved_id[1];
const source = saved_id[3];
Expand All @@ -4906,13 +4905,12 @@ python.Execution = class {
}
case 'storage': {
const storage_type = saved_id[1];
const root_key = saved_id[2];
/// const location = saved_id[3];
const key = saved_id[2];
const size = saved_id[4];
const view_metadata = saved_id[5];
if (!deserialized_objects.has(root_key)) {
if (!deserialized_objects.has(key)) {
const obj = new storage_type(size);
deserialized_objects.set(root_key, obj);
deserialized_objects.set(key, obj);
}
if (view_metadata) {
const view_key = view_metadata.shift();
Expand All @@ -4924,10 +4922,10 @@ python.Execution = class {
}
return deserialized_objects.get(view_key);
}
return deserialized_objects.get(root_key);
return deserialized_objects.get(key);
}
default: {
throw new python.Error("Unsupported persistent load type '" + typename + "'.");
throw new python.Error("Unsupported persistent load type '" + saved_id[0] + "'.");
}
}
};
Expand All @@ -4951,21 +4949,24 @@ python.Execution = class {
}
const loaded_storages = new Map();
const persistent_load = (saved_id) => {
const typename = saved_id[0];
if (typename !== 'storage') {
throw new python.Error("Unsupported persistent load type '" + typename + "'.");
}
const storage_type = saved_id[1];
const key = saved_id[2];
const numel = saved_id[4];
if (!loaded_storages.has(key)) {
const storage = new storage_type(numel);
const name = 'data/' + key;
const stream = entries.get(name);
storage._set_cdata(stream);
loaded_storages.set(key, storage);
switch (saved_id[0]) {
case 'storage': {
const storage_type = saved_id[1];
const key = saved_id[2];
const numel = saved_id[4];
if (!loaded_storages.has(key)) {
const storage = new storage_type(numel);
const name = 'data/' + key;
const stream = entries.get(name);
storage._set_cdata(stream);
loaded_storages.set(key, storage);
}
return loaded_storages.get(key);
}
default: {
throw new python.Error("Unsupported persistent load type '" + saved_id[0] + "'.");
}
}
return loaded_storages.get(key);
};
const data_file = entries.get('data.pkl');
const unpickler = execution.invoke('pickle.Unpickler', [ data_file ]);
Expand Down
87 changes: 59 additions & 28 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,21 @@ pytorch.ModelFactory = class {

async open(context, target) {
const metadata = await pytorch.Metadata.open(context);
const container = target;
container.on('resolve', (_, name) => {
target.on('resolve', (_, name) => {
context.exception(new pytorch.Error("Unknown type name '" + name + "'."), false);
});
await container.read(metadata);
return new pytorch.Model(metadata, container);
await target.read(metadata);
return new pytorch.Model(metadata, target);
}
};

pytorch.Model = class {

constructor(metadata, container) {
this.format = container.format;
this.producer = container.producer || '';
constructor(metadata, target) {
this.format = target.format;
this.producer = target.producer || '';
this.graphs = [];
for (const entry of container.modules) {
for (const entry of target.modules) {
const graph = new pytorch.Graph(metadata, entry[0], entry[1]);
this.graphs.push(graph);
}
Expand Down Expand Up @@ -632,7 +631,10 @@ pytorch.Tensor = class {
stream.seek(position);
return values;
}
return this._data.peek();
if (this._data) {
return this._data.peek();
}
return null;
}

decode() {
Expand Down Expand Up @@ -827,23 +829,51 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
const name = obj.__class__.__module__ + '.' + obj.__class__.__name__;
if (name.startsWith('__torch__.')) {
return new pytorch.Container.data_pkl(obj);
return new pytorch.Container.data_pkl('', obj);
}
}
if (obj && obj instanceof Map && Array.from(obj).every((entry) => entry[0] === '_metadata' || pytorch.Utility.isTensor(entry[1]))) {
return new pytorch.Container.data_pkl('map<name,tensor>', obj);
}
if (obj && pytorch.Utility.isTensor(obj)) {
return new pytorch.Container.data_pkl('tensor', obj);
}

return null;
}

constructor(data) {
constructor(type, data) {
super();
this._type = type;
this._data = data;
}

get format() {
return 'PyTorch Pickle';
switch (this._type) {
case 'tensor': return 'PyTorch Tensor';
case 'map<name,tensor>': return 'PyTorch Pickle Weights';
default: return 'PyTorch Pickle';
}
}

get modules() {
throw new pytorch.Error("PyTorch standalone 'data.pkl' not supported.");
switch (this._type) {
case 'tensor':
case 'map<name,tensor>': {
if (this._data) {
this._modules = pytorch.Utility.findWeights(this._data);
delete this._data;
}
if (!this._modules) {
throw new pytorch.Error('File does not contain root module or state dictionary.');
}
return this._modules;
}
default: {
throw new pytorch.Error("PyTorch standalone 'data.pkl' not supported.");
}
}

}
};

Expand Down Expand Up @@ -1033,14 +1063,11 @@ pytorch.Execution = class extends python.Execution {
this.storage_context = new torch._C.DeserializationStorageContext();
const unpickler = new pickle.Unpickler(stream);
unpickler.persistent_load = (saved_id) => {
const typename = saved_id.shift();
const data = saved_id;
switch (typename) {
switch (saved_id[0]) {
case 'storage': {
const storage_type = saved_id[0];
const key = saved_id[1];
/* const location = saved_id[2]; */
const size = saved_id[3];
const storage_type = saved_id[1];
const key = saved_id[2];
const size = saved_id[4];
if (!this.storage_context.has_storage(key)) {
const storage = new storage_type(size);
const stream = this.zip_reader.getRecord('.data/' + key + '.storage');
Expand All @@ -1051,22 +1078,22 @@ pytorch.Execution = class extends python.Execution {
return this.storage_context.get_storage(key);
}
case 'reduce_package': {
if (data.length === 2) {
const func = data[0];
const args = data[1];
if (saved_id.length === 2) {
const func = saved_id[1];
const args = saved_id[2];
return execution.invoke(func, args);
}
const reduce_id = data[0];
const func = data[1];
const args = data[2];
const reduce_id = saved_id[1];
const func = saved_id[2];
const args = saved_id[3];
if (!loaded_reduces.has(reduce_id)) {
const value = execution.invoke(func, [ this ].concat(args));
loaded_reduces.set(reduce_id, value);
}
return loaded_reduces.get(reduce_id);
}
default: {
throw new pytorch.Error("Unknown package typename '" + typename + "'.");
throw new pytorch.Error("Unknown package typename '" + saved_id[0] + "'.");
}
}
};
Expand Down Expand Up @@ -3434,7 +3461,11 @@ pytorch.Utility = class {

static _convertTensor(obj) {
if (obj && pytorch.Utility.isTensor(obj)) {
const module = pytorch.Utility.module();
const module = {};
module.__class__ = {
__module__: obj.__class__.__module__,
__name__: obj.__class__.__name__
};
module._parameters = new Map();
module._parameters.set('value', obj);
return new Map([ [ '', { _modules: new Map([ [ '', module ] ]) } ] ]);
Expand Down
24 changes: 22 additions & 2 deletions source/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -3320,7 +3320,7 @@ view.Tensor = class {
switch (this._layout) {
case 'sparse':
case 'sparse.coo': {
return !this._values || this.indices || this._values.values.length === 0;
return !this._values || this.indices || this._values.values === null || this._values.values.length === 0;
}
default: {
switch (this._encoding) {
Expand Down Expand Up @@ -4908,7 +4908,27 @@ view.ModelContext = class {
// continue regardless of error
}
if (unpickler) {
unpickler.persistent_load = (saved_id) => saved_id;
const storages = new Map();
unpickler.persistent_load = (saved_id) => {
if (Array.isArray(saved_id) && saved_id.length > 3) {
switch (saved_id[0]) {
case 'storage': {
const storage_type = saved_id[1];
const key = saved_id[2];
const size = saved_id[4];
if (!storages.has(key)) {
const storage = new storage_type(size);
storages.set(key, storage);
}
return storages.get(key);
}
default: {
throw new python.Error("Unsupported persistent load type '" + saved_id[0] + "'.");
}
}
}
throw new view.Error("Unsupported 'persistent_load'.");
};
try {
const obj = unpickler.load();
this._content.set(type, obj);
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5353,6 +5353,13 @@
"format": "TorchScript v1.0",
"link": "https://github.com/KinglittleQ/SuperPoint_SLAM"
},
{
"type": "pytorch",
"target": "tensors.data.pkl",
"source": "https://github.com/lutzroeder/netron/files/13061412/tensors.data.pkl.zip[tensors.data.pkl]",
"format": "PyTorch Pickle Weights",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "test.8bit.pth",
Expand Down

0 comments on commit 3c6ff61

Please sign in to comment.