This repository has been archived by the owner on Jul 19, 2024. It is now read-only.
forked from facebookarchive/fb.resnet.torch
-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathinit.lua
129 lines (111 loc) · 4.03 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Generic model creating code. For the specific ResNet model see
-- models/resnet.lua
--
require 'nn'
require 'cunn'
require 'cudnn'
local M = {}
-- Loads or builds the network that M.setup will configure.
--
-- Source, in priority order:
--   1. `checkpoint` given: the file <opt.resume>/<checkpoint.modelFile>
--   2. opt.retrain ~= 'none': the file at path opt.retrain
--   3. otherwise: the constructor returned by require('models/' .. opt.netType)
--
-- Any outer nn.DataParallelTable is unwrapped to its first module so that
-- M.setup can re-wrap it for the current GPU count. Models read from disk are
-- converted with :cuda() because every later step in M.setup (optnet's sample
-- input, the replacement classifier, shareGradInput's CudaStorage, the CUDA
-- criterion) assumes CUDA tensors; a file serialized with CPU tensors would
-- otherwise reach those steps with the wrong tensor type.
local function loadModel(opt, checkpoint)
   local model
   local fromFile = true
   if checkpoint then
      local modelPath = paths.concat(opt.resume, checkpoint.modelFile)
      assert(paths.filep(modelPath), 'Saved model not found: ' .. modelPath)
      print('=> Resuming model from ' .. modelPath)
      model = torch.load(modelPath)
   elseif opt.retrain ~= 'none' then
      assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
      print('Loading model from file: ' .. opt.retrain)
      model = torch.load(opt.retrain)
   else
      fromFile = false
      -- print('=> Creating model from file: models/' .. opt.netType .. '.lua')
      model = require('models/' .. opt.netType)(opt)
   end

   -- First remove any DataParallelTable
   if torch.type(model) == 'nn.DataParallelTable' then
      model = model:get(1)
   end

   if fromFile then
      model = model:cuda()
   end
   return model
end

-- Replaces the last module of the container `model` (it must be an
-- nn.Linear) with a new nn.Linear that keeps the same input width, produces
-- `nClasses` outputs, has a zeroed bias and is moved to the GPU. Used when
-- fine-tuning on a dataset with a different number of classes.
local function replaceClassifier(model, nClasses)
   print(' => Replacing classifier with ' .. nClasses .. '-way classifier')
   -- Fail with a readable message instead of "attempt to get length of nil"
   assert(model.modules and #model.modules > 0,
      'expected a non-empty container model to replace the classifier of')
   local last = #model.modules
   local orig = model:get(last)
   assert(torch.type(orig) == 'nn.Linear',
      'expected last layer to be fully connected')
   local linear = nn.Linear(orig.weight:size(2), nClasses)
   linear.bias:zero()
   model:remove(last)
   model:add(linear:cuda())
end

-- Builds everything needed for training and returns `model, criterion`:
-- the (possibly DataParallelTable-wrapped) network and an
-- nn.CrossEntropyCriterion converted with :cuda().
--
-- opt fields read: resume, retrain, netType (where the model comes from);
-- optnet, dataset (memory optimisation); shareGradInput; resetClassifier,
-- nClasses; cudnn ('fastest' or 'deterministic'); nGPU.
-- checkpoint: a table with a `modelFile` field when resuming, otherwise nil.
--
-- Side effects: prints progress messages; opt.cudnn == 'fastest' sets the
-- global cudnn.fastest and cudnn.benchmark flags.
function M.setup(opt, checkpoint)
   local model = loadModel(opt, checkpoint)

   -- optnet is a general library for reducing memory usage in neural networks
   if opt.optnet then
      local optnet = require 'optnet'
      -- Sample spatial size: 224 for ImageNet, 32 for any other dataset
      local imsize = opt.dataset == 'imagenet' and 224 or 32
      local sampleInput = torch.zeros(4, 3, imsize, imsize):cuda()
      optnet.optimizeMemory(model, sampleInput,
         {inplace = false, mode = 'training'})
   end

   -- This is useful for fitting ResNet-50 on 4 GPUs, but requires that all
   -- containers override backwards to call backwards recursively on submodules
   if opt.shareGradInput then
      M.shareGradInput(model)
   end

   -- For resetting the classifier when fine-tuning on a different dataset.
   -- Skipped when resuming from a checkpoint (presumably because the
   -- classifier was already replaced when that run started).
   if opt.resetClassifier and not checkpoint then
      replaceClassifier(model, opt.nClasses)
   end

   -- Set the CUDNN flags
   if opt.cudnn == 'fastest' then
      cudnn.fastest = true
      cudnn.benchmark = true
   elseif opt.cudnn == 'deterministic' then
      -- Use a deterministic convolution implementation: call setMode(1, 1, 1)
      -- on every module that has a setMode method (NOTE(review): presumably
      -- the forward / backward-data / backward-filter algorithm choices as
      -- defined by cudnn.torch -- confirm).
      model:apply(function(m)
         if m.setMode then m:setMode(1, 1, 1) end
      end)
   end

   -- Wrap the model with DataParallelTable, if using more than one GPU
   if opt.nGPU > 1 then
      local gpus = torch.range(1, opt.nGPU):totable()
      -- Captured as upvalues so the function passed to :threads can apply
      -- the same cudnn flags as the main thread.
      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
      local dpt = nn.DataParallelTable(1, true, true)
         :add(model, gpus)
         :threads(function()
            local cudnn = require 'cudnn'
            cudnn.fastest, cudnn.benchmark = fastest, benchmark
         end)
      -- NOTE(review): presumably drops the wrapper's own gradInput buffer to
      -- save memory -- confirm.
      dpt.gradInput = nil
      model = dpt:cuda()
   end

   local criterion = nn.CrossEntropyCriterion():cuda()
   return model, criterion
end
-- Makes modules of the same kind share one CUDA storage for their gradInput,
-- reducing memory used during backprop.
--
-- Every module whose gradInput is a tensor (except nn.ConcatTable) gets a
-- zero-length CudaTensor view onto a storage shared by all modules with the
-- same key. The key is the module's Torch type name, optionally suffixed with
-- ':' .. module.__shareGradInputKey. nn.ConcatTable modules instead alternate
-- between two storages according to the parity of their position in
-- model:findModules('nn.ConcatTable').
function M.shareGradInput(model)
   -- Lazily allocated storages. Type-name keys are strings and the
   -- ConcatTable keys are the numbers 0 and 1, so the two sets never collide.
   local storages = {}

   -- Returns an empty CudaTensor view of the storage for `key`, allocating
   -- that storage on first request.
   local function sharedView(key)
      local storage = storages[key]
      if storage == nil then
         storage = torch.CudaStorage(1)
         storages[key] = storage
      end
      return torch.CudaTensor(storage, 1, 0)
   end

   -- Grouping key: Torch type name, qualified by an optional per-module tag.
   local function keyFor(module)
      local tag = module.__shareGradInputKey
      if tag then
         return torch.type(module) .. ':' .. tag
      end
      return torch.type(module)
   end

   -- Pass 1: every tensor gradInput outside nn.ConcatTable shares its group's
   -- storage.
   model:apply(function(module)
      local isConcat = torch.type(module) == 'nn.ConcatTable'
      if not isConcat and torch.isTensor(module.gradInput) then
         module.gradInput = sharedView(keyFor(module))
      end
   end)

   -- Pass 2: ConcatTables alternate between two storages by index parity.
   -- NOTE(review): presumably so that a ConcatTable nested directly inside
   -- another never shares its gradInput storage -- confirm.
   local concats = model:findModules('nn.ConcatTable')
   for idx = 1, #concats do
      concats[idx].gradInput = sharedView(idx % 2)
   end
end
return M