Skip to content

Commit

Permalink
Add PyTorch test file (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 15, 2021
1 parent 2270de5 commit e1dd636
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
40 changes: 30 additions & 10 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -3384,17 +3384,15 @@ pytorch.Utility = class {
return null;
}

static _convertObjectList(list) {
if (list && Array.isArray(list) && list.every((obj) => obj && Object.keys(obj).filter((key) => pytorch.Utility.isTensor(obj[key]).length > 0))) {
const layers = [];
for (const obj of list) {
static _convertObjectList(obj) {
if (obj && Array.isArray(obj)) {
if (obj.every((item) => typeof item === 'number' || typeof item === 'string')) {
const layers = [];
const type = obj.__class__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : '?';
const layer = { type: type, states: [], attributes: [] };
if (obj instanceof Map) {
return null;
}
for (const key of Object.keys(obj)) {
const value = obj[key];
for (let i = 0; i < obj.length; i++) {
const key = i.toString();
const value = obj[i];
if (pytorch.Utility.isTensor(value)) {
layer.states.push({ name: key, arguments: [ { id: '', value: value } ] });
}
Expand All @@ -3403,8 +3401,30 @@ pytorch.Utility = class {
}
}
layers.push(layer);
return [ { layers: layers } ];
}
if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) {
const layers = [];
for (const item of obj) {
const type = item.__class__ ? item.__class__.__module__ + '.' + item.__class__.__name__ : '?';
const layer = { type: type, states: [], attributes: [] };
if (item instanceof Map) {
return null;
}
for (const entry of Object.entries(item)) {
const key = entry[0];
const value = entry[1];
if (pytorch.Utility.isTensor(value)) {
layer.states.push({ name: key, arguments: [ { id: '', value: value } ] });
}
else {
layer.attributes.push({ name: key, value: value });
}
}
layers.push(layer);
}
return [ { layers: layers } ];
}
return [ { layers: layers } ];
}
return null;
}
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4334,6 +4334,13 @@
"format": "TorchScript v1.0",
"link": "https://github.com/ApolloAuto/apollo"
},
{
"type": "pytorch",
"target": "labels.pth",
"source": "https://github.com/lutzroeder/netron/files/7350657/labels.pth.zip[labels.pth]",
"format": "PyTorch v1.6",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "lane_scanning_vehicle_model.pt",
Expand Down

0 comments on commit e1dd636

Please sign in to comment.