Skip to content

Commit

Permalink
PyTorch float16 support (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 8, 2018
1 parent 3b9b7d7 commit 1199916
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pytorch.ModelFactory = class {
constructorTable['torch.nn.modules.batchnorm.BatchNorm1d'] = function () {};
constructorTable['torch.nn.modules.batchnorm.BatchNorm2d'] = function () {};
constructorTable['torch.nn.modules.batchnorm.BatchNorm3d'] = function () {};
constructorTable['torch.nn.modules.container.ModuleList'] = function () {};
constructorTable['torch.nn.modules.container.Sequential'] = function () {};
constructorTable['torch.nn.modules.conv.Conv1d'] = function () {};
constructorTable['torch.nn.modules.conv.Conv2d'] = function () {};
Expand Down Expand Up @@ -121,6 +122,7 @@ pytorch.ModelFactory = class {
constructorTable['torch.nn.parameter.Parameter'] = function(data, requires_grad) { this.data = data; this.requires_grad = requires_grad; };
constructorTable['torch.ByteStorage'] = function (size) { this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8'; };
constructorTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; };
constructorTable['torch.HalfStorage'] = function (size) { this.size = size; this.dataTypeSize = 2; this.dataType = 'float16'; };
constructorTable['torch.FloatStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; };
constructorTable['torch.DoubleStorage'] = function (size) { this.size = size; this.dataTypeSize = 8; this.dataType = 'float64'; };
constructorTable['torch.FloatTensor'] = function () {
Expand Down Expand Up @@ -602,6 +604,11 @@ pytorch.Tensor = class {
context.index += 1;
context.count++;
break;
case 'float16':
results.push(context.dataView.getFloat16(context.index, this._littleEndian));
context.index += 2;
context.count++;
break;
case 'float32':
results.push(context.dataView.getFloat32(context.index, this._littleEndian));
context.index += 4;
Expand Down

0 comments on commit 1199916

Please sign in to comment.