Skip to content

Commit

Permalink
Add PyTorch test files (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 22, 2023
1 parent 141722a commit f1b2eb4
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 112 deletions.
2 changes: 1 addition & 1 deletion source/pickle.js
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pickle.Node = class {
this.attributes.push(attribute);
} else {
stack = stack || new Set();
if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj.__class__ && obj.__class__.__module__ === value[0].__class__.__module__ && obj.__class__.__name__ === value[0].__class__.__name__)) {
if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ === value[0].__class__.__module__ && obj.__class__.__name__ === value[0].__class__.__name__)) {
const values = value.filter((value) => !stack.has(value));
const nodes = values.map((value) => {
stack.add(value);
Expand Down
57 changes: 29 additions & 28 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,7 @@ python.Execution = class {
this.register('xgboost');
this.registerType('builtins.dict', dict);
this.registerType('builtins.ellipsis', class {});
this.registerType('builtins.list', class {});
this.registerType('builtins.list', class extends Array {});
this.registerType('builtins.number', class {});
this.registerFunction('builtins.__import__', function(name, globals, locals, fromlist, level) {
return execution.__import__(name, globals, locals, fromlist, level);
Expand Down Expand Up @@ -2901,7 +2901,7 @@ python.Execution = class {
break;
}
case 93: // EMPTY_LIST ']'
stack.push([]);
stack.push(execution.invoke('builtins.list', []));
break;
case 41: // EMPTY_TUPLE ')'
stack.push([]);
Expand Down Expand Up @@ -4211,10 +4211,10 @@ python.Execution = class {
this.registerType('torchvision.models.convnext.ConvNeXt', class {});
this.registerType('torchvision.models.convnext.CNBlock', class {});
this.registerType('torchvision.models.convnext.LayerNorm2d', class {});
this.registerType('torchvision.models.densenet.DenseNet', class {});
this.registerType('torchvision.models.densenet._DenseBlock', class {});
this.registerType('torchvision.models.densenet._DenseLayer', class {});
this.registerType('torchvision.models.densenet._Transition', class {});
this.registerType('torchvision.models.densenet.DenseNet', class extends torch.nn.modules.module.Module {});
this.registerType('torchvision.models.densenet._DenseBlock', class extends torch.nn.modules.container.ModuleDict {});
this.registerType('torchvision.models.densenet._DenseLayer', class extends torch.nn.modules.module.Module {});
this.registerType('torchvision.models.densenet._Transition', class extends torch.nn.modules.container.Sequential {});
this.registerType('torchvision.models.detection._utils.BalancedPositiveNegativeSampler', class {});
this.registerType('torchvision.models.detection._utils.BoxCoder', class {});
this.registerType('torchvision.models.detection._utils.Matcher', class {});
Expand Down Expand Up @@ -4896,8 +4896,7 @@ python.Execution = class {
const module_source_map = new Map();
const deserialized_objects = new Map();
unpickler.persistent_load = (saved_id) => {
const typename = saved_id[0];
switch (typename) {
switch (saved_id[0]) {
case 'module': {
const module = saved_id[1];
const source = saved_id[3];
Expand All @@ -4906,13 +4905,12 @@ python.Execution = class {
}
case 'storage': {
const storage_type = saved_id[1];
const root_key = saved_id[2];
/// const location = saved_id[3];
const key = saved_id[2];
const size = saved_id[4];
const view_metadata = saved_id[5];
if (!deserialized_objects.has(root_key)) {
if (!deserialized_objects.has(key)) {
const obj = new storage_type(size);
deserialized_objects.set(root_key, obj);
deserialized_objects.set(key, obj);
}
if (view_metadata) {
const view_key = view_metadata.shift();
Expand All @@ -4924,10 +4922,10 @@ python.Execution = class {
}
return deserialized_objects.get(view_key);
}
return deserialized_objects.get(root_key);
return deserialized_objects.get(key);
}
default: {
throw new python.Error("Unsupported persistent load type '" + typename + "'.");
throw new python.Error("Unsupported persistent load type '" + saved_id[0] + "'.");
}
}
};
Expand All @@ -4951,21 +4949,24 @@ python.Execution = class {
}
const loaded_storages = new Map();
const persistent_load = (saved_id) => {
const typename = saved_id[0];
if (typename !== 'storage') {
throw new python.Error("Unsupported persistent load type '" + typename + "'.");
}
const storage_type = saved_id[1];
const key = saved_id[2];
const numel = saved_id[4];
if (!loaded_storages.has(key)) {
const storage = new storage_type(numel);
const name = 'data/' + key;
const stream = entries.get(name);
storage._set_cdata(stream);
loaded_storages.set(key, storage);
switch (saved_id[0]) {
case 'storage': {
const storage_type = saved_id[1];
const key = saved_id[2];
const numel = saved_id[4];
if (!loaded_storages.has(key)) {
const storage = new storage_type(numel);
const name = 'data/' + key;
const stream = entries.get(name);
storage._set_cdata(stream);
loaded_storages.set(key, storage);
}
return loaded_storages.get(key);
}
default: {
throw new python.Error("Unsupported persistent load type '" + saved_id[0] + "'.");
}
}
return loaded_storages.get(key);
};
const data_file = entries.get('data.pkl');
const unpickler = execution.invoke('pickle.Unpickler', [ data_file ]);
Expand Down
Loading

0 comments on commit f1b2eb4

Please sign in to comment.