Skip to content

Commit

Permalink
Add PyTorch PixelShuffle (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 23, 2018
1 parent ea979d3 commit eb99b24
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,6 @@
"package": "torch.nn.modules.activation"
}
},
{
"name": "Threshold",
"schema": {
"category": "Activation",
"package": "torch.nn.modules.activation"
}
},
{
"name": "ReLU",
"schema": {
Expand Down Expand Up @@ -529,6 +522,12 @@
"package": "torch.nn.modules.padding"
}
},
{
"name": "PixelShuffle",
"schema": {
"package": "torch.nn.modules.pixelshuffle"
}
},
{
"name": "InstanceNorm1d",
"schema": {
Expand Down
2 changes: 2 additions & 0 deletions src/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ pytorch.ModelFactory = class {
constructorTable['torch.nn.modules.padding.ConstantPad1d'] = function () {};
constructorTable['torch.nn.modules.padding.ConstantPad2d'] = function () {};
constructorTable['torch.nn.modules.padding.ConstantPad3d'] = function () {};
constructorTable['torch.nn.modules.pixelshuffle.PixelShuffle'] = function () {};
constructorTable['torch.nn.modules.pooling.AvgPool1d'] = function () {};
constructorTable['torch.nn.modules.pooling.AvgPool2d'] = function () {};
constructorTable['torch.nn.modules.pooling.AvgPool3d'] = function () {};
Expand All @@ -138,6 +139,7 @@ pytorch.ModelFactory = class {
constructorTable['torch.nn.modules.upsampling.Upsample'] = function() {};
constructorTable['torch.nn.parallel.data_parallel.DataParallel'] = function() {};
constructorTable['torch.nn.parameter.Parameter'] = function(data, requires_grad) { this.data = data; this.requires_grad = requires_grad; };
constructorTable['torch.nn.utils.spectral_norm.SpectralNorm'] = function () {};
constructorTable['torch.nn.utils.weight_norm.WeightNorm'] = function () {};
constructorTable['torchvision.models.alexnet.AlexNet'] = function () {};
constructorTable['torchvision.models.densenet.DenseNet'] = function () {};
Expand Down

0 comments on commit eb99b24

Please sign in to comment.