From 11999165bf1cbf1e9c09fcee990693a3345b2f62 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Wed, 7 Nov 2018 19:27:59 -0800 Subject: [PATCH] PyTorch float16 support (#133) --- src/pytorch.js | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/pytorch.js b/src/pytorch.js index 60789e2725..eb72cbb543 100755 --- a/src/pytorch.js +++ b/src/pytorch.js @@ -79,6 +79,7 @@ pytorch.ModelFactory = class { constructorTable['torch.nn.modules.batchnorm.BatchNorm1d'] = function () {}; constructorTable['torch.nn.modules.batchnorm.BatchNorm2d'] = function () {}; constructorTable['torch.nn.modules.batchnorm.BatchNorm3d'] = function () {}; + constructorTable['torch.nn.modules.container.ModuleList'] = function () {}; constructorTable['torch.nn.modules.container.Sequential'] = function () {}; constructorTable['torch.nn.modules.conv.Conv1d'] = function () {}; constructorTable['torch.nn.modules.conv.Conv2d'] = function () {}; @@ -121,6 +122,7 @@ pytorch.ModelFactory = class { constructorTable['torch.nn.parameter.Parameter'] = function(data, requires_grad) { this.data = data; this.requires_grad = requires_grad; }; constructorTable['torch.ByteStorage'] = function (size) { this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8'; }; constructorTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; }; + constructorTable['torch.HalfStorage'] = function (size) { this.size = size; this.dataTypeSize = 2; this.dataType = 'float16'; }; constructorTable['torch.FloatStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; }; constructorTable['torch.DoubleStorage'] = function (size) { this.size = size; this.dataTypeSize = 8; this.dataType = 'float64'; }; constructorTable['torch.FloatTensor'] = function () { @@ -602,6 +604,11 @@ pytorch.Tensor = class { context.index += 1; context.count++; break; + case 'float16': + results.push(context.dataView.getFloat16(context.index, this._littleEndian)); + context.index += 2; + context.count++; + break; case 'float32': results.push(context.dataView.getFloat32(context.index, this._littleEndian)); context.index += 4;