From f1b2eb452de228955ef8be61ff6bf7d3e2924488 Mon Sep 17 00:00:00 2001
From: Lutz Roeder
Date: Sat, 21 Oct 2023 11:10:34 -0700
Subject: [PATCH] Add PyTorch test files (#720)

---
 source/pickle.js  |   2 +-
 source/python.js  |  57 +++++++-------
 source/pytorch.js | 197 +++++++++++++++++++++++++++-------------------
 source/view.js    |  24 +++++-
 test/models.json  |  28 +++++++
 5 files changed, 196 insertions(+), 112 deletions(-)

diff --git a/source/pickle.js b/source/pickle.js
index 1a5eaa4f5d..f83f05c147 100644
--- a/source/pickle.js
+++ b/source/pickle.js
@@ -109,7 +109,7 @@ pickle.Node = class {
                 this.attributes.push(attribute);
             } else {
                 stack = stack || new Set();
-                if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj.__class__ && obj.__class__.__module__ === value[0].__class__.__module__ && obj.__class__.__name__ === value[0].__class__.__name__)) {
+                if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ === value[0].__class__.__module__ && obj.__class__.__name__ === value[0].__class__.__name__)) {
                     const values = value.filter((value) => !stack.has(value));
                     const nodes = values.map((value) => {
                         stack.add(value);
diff --git a/source/python.js b/source/python.js
index d1e35d20db..c8904e6738 100644
--- a/source/python.js
+++ b/source/python.js
@@ -1651,7 +1651,7 @@ python.Execution = class {
         this.register('xgboost');
         this.registerType('builtins.dict', dict);
         this.registerType('builtins.ellipsis', class {});
-        this.registerType('builtins.list', class {});
+        this.registerType('builtins.list', class extends Array {});
         this.registerType('builtins.number', class {});
         this.registerFunction('builtins.__import__', function(name, globals, locals, fromlist, level) {
             return execution.__import__(name, globals, locals, fromlist, level);
@@ -2901,7 +2901,7 @@ python.Execution = class {
                     break;
                 }
                 case 93: // EMPTY_LIST ']'
-                    stack.push([]);
+                    stack.push(execution.invoke('builtins.list', []));
                     break;
                 case 41: // EMPTY_TUPLE ')'
                     stack.push([]);
@@ -4211,10 +4211,10 @@ python.Execution = class {
         this.registerType('torchvision.models.convnext.ConvNeXt', class {});
         this.registerType('torchvision.models.convnext.CNBlock', class {});
         this.registerType('torchvision.models.convnext.LayerNorm2d', class {});
-        this.registerType('torchvision.models.densenet.DenseNet', class {});
-        this.registerType('torchvision.models.densenet._DenseBlock', class {});
-        this.registerType('torchvision.models.densenet._DenseLayer', class {});
-        this.registerType('torchvision.models.densenet._Transition', class {});
+        this.registerType('torchvision.models.densenet.DenseNet', class extends torch.nn.modules.module.Module {});
+        this.registerType('torchvision.models.densenet._DenseBlock', class extends torch.nn.modules.container.ModuleDict {});
+        this.registerType('torchvision.models.densenet._DenseLayer', class extends torch.nn.modules.module.Module {});
+        this.registerType('torchvision.models.densenet._Transition', class extends torch.nn.modules.container.Sequential {});
         this.registerType('torchvision.models.detection._utils.BalancedPositiveNegativeSampler', class {});
         this.registerType('torchvision.models.detection._utils.BoxCoder', class {});
         this.registerType('torchvision.models.detection._utils.Matcher', class {});
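Note on the two `builtins.list` hunks above: registering the type as a subclass of `Array` (and constructing it through `execution.invoke` in the `EMPTY_LIST` opcode) means unpickled Python lists now pass `Array.isArray` and expose native array methods such as `every`, which the hardened `pickle.Node` guard at the top of this patch relies on. A minimal sketch of the idea, using a hypothetical `PythonList` name:

    // A Python list modeled as an Array subclass keeps native array behavior
    // while remaining a distinct, registered type.
    class PythonList extends Array {}

    const value = new PythonList();
    value.push(1, 2, 3);
    console.log(Array.isArray(value));                            // true
    console.log(value instanceof PythonList);                     // true
    console.log(value.every((item) => typeof item === 'number')); // true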
@@ -4896,8 +4896,7 @@ python.Execution = class {
             const module_source_map = new Map();
             const deserialized_objects = new Map();
             unpickler.persistent_load = (saved_id) => {
-                const typename = saved_id[0];
-                switch (typename) {
+                switch (saved_id[0]) {
                     case 'module': {
                         const module = saved_id[1];
                         const source = saved_id[3];
@@ -4906,13 +4905,12 @@
                     }
                     case 'storage': {
                         const storage_type = saved_id[1];
-                        const root_key = saved_id[2];
-                        /// const location = saved_id[3];
+                        const key = saved_id[2];
                         const size = saved_id[4];
                         const view_metadata = saved_id[5];
-                        if (!deserialized_objects.has(root_key)) {
+                        if (!deserialized_objects.has(key)) {
                             const obj = new storage_type(size);
-                            deserialized_objects.set(root_key, obj);
+                            deserialized_objects.set(key, obj);
                         }
                         if (view_metadata) {
                             const view_key = view_metadata.shift();
@@ -4924,10 +4922,10 @@
                             }
                             return deserialized_objects.get(view_key);
                         }
-                        return deserialized_objects.get(root_key);
+                        return deserialized_objects.get(key);
                     }
                     default: {
-                        throw new python.Error("Unsupported persistent load type '" + typename + "'.");
+                        throw new python.Error("Unsupported persistent load type '" + saved_id[0] + "'.");
                     }
                 }
             };
@@ -4951,21 +4949,24 @@
             }
             const loaded_storages = new Map();
             const persistent_load = (saved_id) => {
-                const typename = saved_id[0];
-                if (typename !== 'storage') {
-                    throw new python.Error("Unsupported persistent load type '" + typename + "'.");
-                }
-                const storage_type = saved_id[1];
-                const key = saved_id[2];
-                const numel = saved_id[4];
-                if (!loaded_storages.has(key)) {
-                    const storage = new storage_type(numel);
-                    const name = 'data/' + key;
-                    const stream = entries.get(name);
-                    storage._set_cdata(stream);
-                    loaded_storages.set(key, storage);
+                switch (saved_id[0]) {
+                    case 'storage': {
+                        const storage_type = saved_id[1];
+                        const key = saved_id[2];
+                        const numel = saved_id[4];
+                        if (!loaded_storages.has(key)) {
+                            const storage = new storage_type(numel);
+                            const name = 'data/' + key;
+                            const stream = entries.get(name);
+                            storage._set_cdata(stream);
+                            loaded_storages.set(key, storage);
+                        }
+                        return loaded_storages.get(key);
+                    }
+                    default: {
+                        throw new python.Error("Unsupported persistent load type '" + saved_id[0] + "'.");
+                    }
                 }
-                return loaded_storages.get(key);
             };
             const data_file = entries.get('data.pkl');
             const unpickler = execution.invoke('pickle.Unpickler', [ data_file ]);
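Both persistent-load rewrites above drop the destructive `shift()`/`typename` prelude and switch directly on `saved_id[0]`, so every payload element keeps a fixed offset: `storage_type` at `[1]`, `key` at `[2]`, the unused location at `[3]`, and the element count at `[4]`. A self-contained sketch of the dedup-by-key pattern, with a hypothetical `FloatStorage` standing in for the real storage classes:

    // torch.save emits one persistent id per tensor storage, shaped as
    // [ 'storage', storage_type, key, location, size, view_metadata ].
    const deserialized_objects = new Map();
    const persistent_load = (saved_id) => {
        switch (saved_id[0]) {
            case 'storage': {
                const storage_type = saved_id[1];
                const key = saved_id[2];
                const size = saved_id[4];
                if (!deserialized_objects.has(key)) {
                    deserialized_objects.set(key, new storage_type(size));
                }
                return deserialized_objects.get(key); // same key, same storage
            }
            default: {
                throw new Error("Unsupported persistent load type '" + saved_id[0] + "'.");
            }
        }
    };

    class FloatStorage { constructor(size) { this.size = size; } }
    const saved_id = [ 'storage', FloatStorage, '0', 'cpu', 4, null ];
    console.log(persistent_load(saved_id) === persistent_load(saved_id)); // true

Tensors that share a storage therefore reference one deserialized object, which is what makes the key-based cache (rather than a fresh storage per tensor) necessary.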
diff --git a/source/pytorch.js b/source/pytorch.js
index 38a9f8eecb..d8ab7519f2 100644
--- a/source/pytorch.js
+++ b/source/pytorch.js
@@ -14,22 +14,21 @@ pytorch.ModelFactory = class {

     async open(context, target) {
         const metadata = await pytorch.Metadata.open(context);
-        const container = target;
-        container.on('resolve', (_, name) => {
+        target.on('resolve', (_, name) => {
             context.exception(new pytorch.Error("Unknown type name '" + name + "'."), false);
         });
-        await container.read(metadata);
-        return new pytorch.Model(metadata, container);
+        await target.read(metadata);
+        return new pytorch.Model(metadata, target);
     }
 };

 pytorch.Model = class {

-    constructor(metadata, container) {
-        this.format = container.format;
-        this.producer = container.producer || '';
+    constructor(metadata, target) {
+        this.format = target.format;
+        this.producer = target.producer || '';
         this.graphs = [];
-        for (const entry of container.modules) {
+        for (const entry of target.modules) {
             const graph = new pytorch.Graph(metadata, entry[0], entry[1]);
             this.graphs.push(graph);
         }
@@ -131,7 +130,7 @@ pytorch.Graph = class {
                     const key = pair[0];
                     const value = pair[1];
                     if (value) {
-                        const type = value.__class__.__module__ + '.' + value.__class__.__name__;
+                        const type = value.__class__ ? value.__class__.__module__ + '.' + value.__class__.__name__ : null;
                         switch (type) {
                             case 'torch.nn.modules.container.Sequential':
                                 groups.push(key);
@@ -632,7 +631,10 @@
             stream.seek(position);
             return values;
         }
-        return this._data.peek();
+        if (this._data) {
+            return this._data.peek();
+        }
+        return null;
     }

     decode() {
@@ -824,26 +826,80 @@ pytorch.Container.data_pkl = class extends pytorch.Container {

     static open(context) {
         const obj = context.open('pkl');
-        if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
-            const name = obj.__class__.__module__ + '.' + obj.__class__.__name__;
-            if (name.startsWith('__torch__.')) {
-                return new pytorch.Container.data_pkl(obj);
+        if (obj) {
+            if (obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
+                const name = obj.__class__.__module__ + '.' + obj.__class__.__name__;
+                if (name.startsWith('__torch__.')) {
+                    return new pytorch.Container.data_pkl('', obj);
+                }
+            }
+            if (pytorch.Utility.isTensor(obj)) {
+                return new pytorch.Container.data_pkl('tensor', obj);
+            }
+            if (obj instanceof Map) {
+                const entries = Array.from(obj).filter((entry) => entry[0] === '_metadata' || pytorch.Utility.isTensor(entry[1]));
+                if (entries.length > 0) {
+                    return new pytorch.Container.data_pkl('tensor<>', obj);
+                }
+            } else if (!Array.isArray(obj)) {
+                const entries = Object.entries(obj).filter((entry) => entry[0] === '_metadata' || pytorch.Utility.isTensor(entry[1]));
+                if (entries.length > 0) {
+                    return new pytorch.Container.data_pkl('tensor<>', obj);
+                }
+            }
+            for (const key of [ '', 'model', 'net' ]) {
+                const module = key === '' ? obj : obj[key];
+                if (module && module._modules && pytorch.Utility.isInstance(module._modules, 'collections.OrderedDict')) {
+                    return new pytorch.Container.data_pkl('module', module);
+                }
             }
         }
         return null;
     }

-    constructor(data) {
+    constructor(type, data) {
         super();
+        this._type = type;
         this._data = data;
     }

     get format() {
-        return 'PyTorch Pickle';
+        switch (this._type) {
+            case 'module': return 'PyTorch';
+            case 'tensor': return 'PyTorch Tensor';
+            case 'tensor<>': return 'PyTorch Pickle Weights';
+            default: return 'PyTorch Pickle';
+        }
     }

     get modules() {
-        throw new pytorch.Error("PyTorch standalone 'data.pkl' not supported.");
+        switch (this._type) {
+            case 'module': {
+                if (this._data) {
+                    this._modules = pytorch.Utility.findModule(this._data);
+                    delete this._data;
+                }
+                if (!this._modules) {
+                    throw new pytorch.Error('File does not contain root module or state dictionary.');
+                }
+                return this._modules;
+            }
+            case 'tensor':
+            case 'tensor<>': {
+                if (this._data) {
+                    this._modules = pytorch.Utility.findWeights(this._data);
+                    delete this._data;
+                }
+                if (!this._modules) {
+                    throw new pytorch.Error('File does not contain root module or state dictionary.');
+                }
+                return this._modules;
+            }
+            default: {
+                throw new pytorch.Error("PyTorch standalone 'data.pkl' not supported.");
+            }
+        }
     }
 };
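The rewritten `data_pkl.open` now classifies a bare pickle into one of three container types before any metadata is read: a single tensor ('tensor'), a flat collection of named tensors ('tensor<>'), or a module tree reachable through `_modules` ('module'). A simplified sketch of that decision order, with a hypothetical `classify` helper and a stand-in tensor predicate (the real code additionally accepts `_metadata` entries and verifies `collections.OrderedDict`):

    const classify = (obj, isTensor) => {
        if (isTensor(obj)) {
            return 'tensor';                                   // bare tensor pickle
        }
        if (obj instanceof Map) {
            return Array.from(obj).some((entry) => isTensor(entry[1])) ? 'tensor<>' : null;
        }
        if (obj && typeof obj === 'object' && !Array.isArray(obj)) {
            if (Object.values(obj).some((value) => isTensor(value))) {
                return 'tensor<>';                             // plain state-dict object
            }
            for (const key of [ '', 'model', 'net' ]) {        // common wrapper keys
                const module = key === '' ? obj : obj[key];
                if (module && module._modules) {
                    return 'module';                           // pickled root module
                }
            }
        }
        return null;
    };

    const isTensor = (value) => value && value.__class__ && value.__class__.__name__.endsWith('Tensor');
    const tensor = { __class__: { __module__: 'torch', __name__: 'Tensor' } };
    console.log(classify({ weight: tensor }, isTensor)); // 'tensor<>'

The `format` getter then maps each class to its display string, so the same container serves 'PyTorch', 'PyTorch Tensor', and 'PyTorch Pickle Weights' files.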
@@ -1033,14 +1089,11 @@ pytorch.Execution = class extends python.Execution {
             this.storage_context = new torch._C.DeserializationStorageContext();
             const unpickler = new pickle.Unpickler(stream);
             unpickler.persistent_load = (saved_id) => {
-                const typename = saved_id.shift();
-                const data = saved_id;
-                switch (typename) {
+                switch (saved_id[0]) {
                     case 'storage': {
-                        const storage_type = saved_id[0];
-                        const key = saved_id[1];
-                        /* const location = saved_id[2]; */
-                        const size = saved_id[3];
+                        const storage_type = saved_id[1];
+                        const key = saved_id[2];
+                        const size = saved_id[4];
                         if (!this.storage_context.has_storage(key)) {
                             const storage = new storage_type(size);
                             const stream = this.zip_reader.getRecord('.data/' + key + '.storage');
@@ -1051,14 +1104,14 @@
                         }
                         return this.storage_context.get_storage(key);
                     }
                     case 'reduce_package': {
-                        if (data.length === 2) {
-                            const func = data[0];
-                            const args = data[1];
+                        if (saved_id.length === 2) {
+                            const func = saved_id[1];
+                            const args = saved_id[2];
                             return execution.invoke(func, args);
                         }
-                        const reduce_id = data[0];
-                        const func = data[1];
-                        const args = data[2];
+                        const reduce_id = saved_id[1];
+                        const func = saved_id[2];
+                        const args = saved_id[3];
                         if (!loaded_reduces.has(reduce_id)) {
                             const value = execution.invoke(func, [ this ].concat(args));
                             loaded_reduces.set(reduce_id, value);
@@ -1066,7 +1119,7 @@
                         }
                         return loaded_reduces.get(reduce_id);
                     }
                     default: {
-                        throw new pytorch.Error("Unknown package typename '" + typename + "'.");
+                        throw new pytorch.Error("Unknown package typename '" + saved_id[0] + "'.");
                     }
                 }
             };
@@ -3319,12 +3372,6 @@ pytorch.Utility = class {
         return (a.type === 'id' && b.type === 'id' && a.value === b.value);
     }

-    static module() {
-        const module = {};
-        module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
-        return module;
-    }
-
     static format(name, value) {
         // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
         // kProducedFileFormatVersion
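For 'reduce_package' records, the id-tagged form must resolve repeated references to a single reconstructed object, which is why the handler above caches by `reduce_id` in `loaded_reduces`. The memoization in isolation, with a hypothetical reducer function:

    const loaded_reduces = new Map();
    const load_reduce = (reduce_id, func, args) => {
        if (!loaded_reduces.has(reduce_id)) {
            loaded_reduces.set(reduce_id, func(...args)); // reconstruct once
        }
        return loaded_reduces.get(reduce_id);             // reuse thereafter
    };

    const make = (name) => ({ name });
    const a = load_reduce('r0', make, [ 'shared' ]);
    const b = load_reduce('r0', make, [ 'shared' ]);
    console.log(a === b); // true: both references see the same object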
@@ -3395,53 +3442,42 @@
         return null;
     }

-    static findWeights(root) {
-        if (!root) {
-            return null;
-        }
-        if (root instanceof Map) {
-            const obj = {};
-            for (const pair of root) {
-                const key = pair[0];
-                const value = pair[1];
-                obj[key] = value;
-            }
-            root = obj;
-        }
-        const keys = !Array.isArray(root) ? Object.keys(root) : [];
-        if (keys.length > 1) {
-            keys.splice(0, keys.length);
-        }
-        keys.push(...[
-            'state_dict', 'state_dict_stylepredictor', 'state_dict_ghiasi',
-            'state', 'model_state', 'model', 'model_state_dict', 'model_dict', 'net_dict',
-            'generator', 'discriminator', 'g_state', 'module', 'params',
-            'weights', 'network_weights', 'network', 'net', 'netG', 'net_states',
-            'runner', ''
-        ]);
-        for (const key of keys) {
-            const obj = key === '' ? root : root[key];
-            let graphs = null;
-            graphs = graphs || pytorch.Utility._convertTensor(obj);
-            graphs = graphs || pytorch.Utility._convertObjectList(obj);
-            graphs = graphs || pytorch.Utility._convertStateDict(obj);
-            if (graphs) {
-                return graphs;
+    static findWeights(obj) {
+        if (obj) {
+            if (pytorch.Utility.isTensor(obj)) {
+                const module = {};
+                module.__class__ = {
+                    __module__: obj.__class__.__module__,
+                    __name__: obj.__class__.__name__
+                };
+                module._parameters = new Map();
+                module._parameters.set('value', obj);
+                return new Map([ [ '', { _modules: new Map([ [ '', module ] ]) } ] ]);
+            }
+            const keys = !Array.isArray(obj) ? Object.keys(obj) : [];
+            if (keys.length > 1) {
+                keys.splice(0, keys.length);
+            }
+            keys.push(...[
+                'state_dict', 'state_dict_stylepredictor', 'state_dict_ghiasi',
+                'state', 'model_state', 'model', 'model_state_dict', 'model_dict', 'net_dict',
+                'generator', 'discriminator', 'g_state', 'module', 'params',
+                'weights', 'network_weights', 'network', 'net', 'netG', 'net_states',
+                'runner', ''
+            ]);
+            for (const key of keys) {
+                const value = key === '' ? obj : obj[key];
+                let graphs = null;
+                graphs = graphs || pytorch.Utility._convertObjectList(value);
+                graphs = graphs || pytorch.Utility._convertStateDict(value);
+                if (graphs) {
+                    return graphs;
+                }
             }
         }
         return null;
     }

-    static _convertTensor(obj) {
-        if (obj && pytorch.Utility.isTensor(obj)) {
-            const module = pytorch.Utility.module();
-            module._parameters = new Map();
-            module._parameters.set('value', obj);
-            return new Map([ [ '', { _modules: new Map([ [ '', module ] ]) } ] ]);
-        }
-        return null;
-    }
-
     static _convertObjectList(obj) {
         if (obj && Array.isArray(obj)) {
             if (obj.every((item) => typeof item === 'number' || typeof item === 'string')) {
@@ -3617,8 +3653,7 @@
             }
             layer_name = keys.join(separator);
             if (!layers.has(layer_name)) {
-                const module = pytorch.Utility.module();
-                layers.set(layer_name, module);
+                layers.set(layer_name, {});
             }
             const layer = layers.get(layer_name);
             if (pytorch.Utility.isTensor(value)) {
diff --git a/source/view.js b/source/view.js
index 645e60de1e..66a85140e7 100644
--- a/source/view.js
+++ b/source/view.js
@@ -3320,7 +3320,7 @@ view.Tensor = class {
         switch (this._layout) {
             case 'sparse':
             case 'sparse.coo': {
-                return !this._values || this.indices || this._values.values.length === 0;
+                return !this._values || this.indices || this._values.values === null || this._values.values.length === 0;
             }
             default: {
                 switch (this._encoding) {
@@ -4908,7 +4908,27 @@ view.ModelContext = class {
                     // continue regardless of error
                 }
                 if (unpickler) {
-                    unpickler.persistent_load = (saved_id) => saved_id;
+                    const storages = new Map();
+                    unpickler.persistent_load = (saved_id) => {
+                        if (Array.isArray(saved_id) && saved_id.length > 3) {
+                            switch (saved_id[0]) {
+                                case 'storage': {
+                                    const storage_type = saved_id[1];
+                                    const key = saved_id[2];
+                                    const size = saved_id[4];
+                                    if (!storages.has(key)) {
+                                        const storage = new storage_type(size);
+                                        storages.set(key, storage);
+                                    }
+                                    return storages.get(key);
+                                }
+                                default: {
+                                    throw new python.Error("Unsupported persistent load type '" + saved_id[0] + "'.");
+                                }
+                            }
+                        }
+                        throw new view.Error("Unsupported 'persistent_load'.");
+                    };
                     try {
                         const obj = unpickler.load();
                         this._content.set(type, obj);
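The view.js change replaces the old pass-through `persistent_load` with one that materializes placeholder storages, so a sniffing `unpickler.load()` can complete and expose the object graph without reading tensor payloads, and unknown persistent ids now fail fast instead of leaking raw tuples into that graph. The same pattern in isolation, with a hypothetical `StubStorage` in place of the real storage constructor carried in `saved_id[1]`:

    class StubStorage {
        constructor(size) {
            this.size = size; // no payload: sniffing only needs the object graph
        }
    }

    const storages = new Map();
    const persistent_load = (saved_id) => {
        if (Array.isArray(saved_id) && saved_id.length > 3 && saved_id[0] === 'storage') {
            const key = saved_id[2];
            if (!storages.has(key)) {
                storages.set(key, new StubStorage(saved_id[4]));
            }
            return storages.get(key);
        }
        throw new Error("Unsupported 'persistent_load'.");
    };

    console.log(persistent_load([ 'storage', null, '0', 'cpu', 4 ]).size); // 4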
Weights", + "link": "https://github.com/lutzroeder/netron/issues/720" + }, { "type": "pytorch", "target": "test.8bit.pth", @@ -5578,6 +5599,13 @@ "format": "PyTorch v1.6", "link": "https://github.com/lutzroeder/netron/issues/720" }, + { + "type": "pytorch", + "target": "yolov5n.tensor.data.pkl", + "source": "https://github.com/lutzroeder/netron/files/13064842/yolov5n.tensor.data.pkl.zip[yolov5n.tensor.data.pkl]", + "format": "PyTorch Pickle Weights", + "link": "https://github.com/lutzroeder/netron/issues/720" + }, { "type": "pytorch", "target": "yolov5n.torchscript",