diff --git a/src/pytorch-metadata.json b/src/pytorch-metadata.json index 4e025df53a..553bc1fd02 100755 --- a/src/pytorch-metadata.json +++ b/src/pytorch-metadata.json @@ -156,13 +156,6 @@ "package": "torch.nn.modules.activation" } }, - { - "name": "Threshold", - "schema": { - "category": "Activation", - "package": "torch.nn.modules.activation" - } - }, { "name": "ReLU", "schema": { @@ -529,6 +522,12 @@ "package": "torch.nn.modules.padding" } }, + { + "name": "PixelShuffle", + "schema": { + "package": "torch.nn.modules.pixelshuffle" + } + }, { "name": "InstanceNorm1d", "schema": { diff --git a/src/pytorch.js b/src/pytorch.js index 32d5b75e36..4e020f9169 100755 --- a/src/pytorch.js +++ b/src/pytorch.js @@ -120,6 +120,7 @@ pytorch.ModelFactory = class { constructorTable['torch.nn.modules.padding.ConstantPad1d'] = function () {}; constructorTable['torch.nn.modules.padding.ConstantPad2d'] = function () {}; constructorTable['torch.nn.modules.padding.ConstantPad3d'] = function () {}; + constructorTable['torch.nn.modules.pixelshuffle.PixelShuffle'] = function () {}; constructorTable['torch.nn.modules.pooling.AvgPool1d'] = function () {}; constructorTable['torch.nn.modules.pooling.AvgPool2d'] = function () {}; constructorTable['torch.nn.modules.pooling.AvgPool3d'] = function () {}; @@ -138,6 +139,7 @@ pytorch.ModelFactory = class { constructorTable['torch.nn.modules.upsampling.Upsample'] = function() {}; constructorTable['torch.nn.parallel.data_parallel.DataParallel'] = function() {}; constructorTable['torch.nn.parameter.Parameter'] = function(data, requires_grad) { this.data = data; this.requires_grad = requires_grad; }; + constructorTable['torch.nn.utils.spectral_norm.SpectralNorm'] = function () {}; constructorTable['torch.nn.utils.weight_norm.WeightNorm'] = function () {}; constructorTable['torchvision.models.alexnet.AlexNet'] = function () {}; constructorTable['torchvision.models.densenet.DenseNet'] = function () {};