-
Notifications
You must be signed in to change notification settings - Fork 283
/
Copy pathbase.lua
83 lines (75 loc) · 1.92 KB
/
base.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
--
-- Copyright (c) 2014, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the Apache 2 license found in the
-- LICENSE file in the root directory of this source tree.
--
-- Switch Dropout-type modules to evaluation mode by setting `train = false`.
-- `node` is either a single module (one with a `__typename`) or a plain Lua
-- array of modules. For an array, every element is visited via
-- `element:apply(g_disable_dropout)`. That is presumably nn.Module:apply,
-- which also walks submodules; confirm against the nn version in use.
-- A module counts as Dropout-type when its `__typename` contains "Dropout"
-- anywhere (so "nn.SpatialDropout" matches too).
function g_disable_dropout(node)
  local is_plain_list = type(node) == "table" and node.__typename == nil
  if not is_plain_list then
    if string.match(node.__typename, "Dropout") then
      node.train = false
    end
    return
  end
  for idx = 1, #node do
    node[idx]:apply(g_disable_dropout)
  end
end
-- Switch Dropout-type modules back to training mode by setting `train = true`.
-- `node` is either a single module (one with a `__typename`) or a plain Lua
-- array of modules. For an array, every element is visited via
-- `element:apply(g_enable_dropout)`. That is presumably nn.Module:apply,
-- which also walks submodules; confirm against the nn version in use.
-- A module counts as Dropout-type when its `__typename` contains "Dropout".
function g_enable_dropout(node)
  local is_plain_list = type(node) == "table" and node.__typename == nil
  if is_plain_list then
    for idx = 1, #node do
      node[idx]:apply(g_enable_dropout)
    end
  elseif string.match(node.__typename, "Dropout") then
    node.train = true
  end
end
-- Build T independent copies of `net` whose parameter and gradient tensors
-- are tied to the originals.
-- `net` is serialized once into an in-memory binary buffer, and each clone is
-- deserialized from that buffer with its own reader. Each clone's
-- `:parameters()` tensors are then pointed at `net`'s tensors with `:set()`.
-- @param net  module exposing :parameters(); that call may return nothing
--             for modules without parameters
-- @param T    number of clones to create
-- @return     array of T clones (clones[1] .. clones[T])
function g_cloneManyTimes(net, T)
  local clones = {}
  local params, gradParams = net:parameters()
  local mem = torch.MemoryFile("w"):binary()
  mem:writeObject(net)
  -- The serialized bytes do not change between iterations; fetch them once.
  local storage = mem:storage()
  for t = 1, T do
    -- We need to use a new reader for each clone.
    -- We don't want to use the pointers to already read objects.
    local reader = torch.MemoryFile(storage, "r"):binary()
    local clone = reader:readObject()
    reader:close()
    -- A parameter-less module returns nothing from :parameters(). Skip the
    -- tying step instead of failing on `#nil`.
    if params then
      local cloneParams, cloneGradParams = clone:parameters()
      for i = 1, #params do
        cloneParams[i]:set(params[i])
        cloneGradParams[i]:set(gradParams[i])
      end
    end
    clones[t] = clone
    -- Deliberately reclaim garbage after each clone to keep peak memory down.
    collectgarbage()
  end
  mem:close()
  return clones
end
-- Select the CUDA device, announce it, and seed the RNGs with seed 1.
-- @param args  optional array. Its first element is the device index handed
--              to cutorch.setDevice. Both a missing `args` and a missing
--              `args[1]` fall back to device 1 (previously a nil `args`
--              raised an index error).
function g_init_gpu(args)
  local gpuidx = (args and args[1]) or 1
  print(string.format("Using %s-th gpu", gpuidx))
  cutorch.setDevice(gpuidx)
  g_make_deterministic(1)
end
-- Seed the torch (CPU) and cutorch (GPU) random generators with `seed`, in
-- that order, then draw one uniform sample into a 1x1 CUDA tensor.
-- NOTE(review): the throwaway draw presumably forces the CUDA generator to
-- initialise immediately; confirm before removing it.
function g_make_deterministic(seed)
  torch.manualSeed(seed)
  cutorch.manualSeed(seed)
  local scratch = torch.zeros(1, 1)
  scratch = scratch:cuda()
  scratch:uniform()
end
-- Copy each element of `from` into the element at the same position of `to`,
-- in place, via `to[k]:copy(from[k])`.
-- Raises an assertion error if the two sequences differ in length.
function g_replace_table(to, from)
  local count = #to
  assert(count == #from)
  for k = 1, count do
    local dst, src = to[k], from[k]
    dst:copy(src)
  end
end
-- Render a number as a string with exactly three digits after the decimal point.
function g_f3(f)
  local fmt = "%.3f"
  return fmt:format(f)
end
-- Round `f` with torch.round and render the result as an integer string.
function g_d(f)
  local rounded = torch.round(f)
  return ("%d"):format(rounded)
end