forked from twitter-archive/torch-ipc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path allreduce.lua
58 lines (53 loc) · 1.96 KB
/
allreduce.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
-- Command-line options. `lapp` is not required in this file, so it must be
-- supplied as a global by the runtime environment (presumably Penlight's lapp
-- as preloaded by Torch -- confirm). The spec text below is parsed at runtime
-- and doubles as the --help output.
-- NOTE(review): --verify is declared here but is not read anywhere else in
-- this script.
local opt = lapp([[
Options:
-h,--host (default '127.0.0.1') host name of the server
-p,--port (default 8080) port number of the server
-n,--numNodes (default 1) number of nodes
-x,--node (default 1) which node index is this?
-b,--base (default 2) power of 2 base of the tree of nodes
-d,--dimensions (default '1000,1000') comma delimited tensor dimensions
-i,--iterations (default 1000) number of send/recv iterations
--verify verify contents of transmission (slows things down)
--verbose print lots of network stats
--cuda use CUDA tensors
]])
-- Load our requires
local ipc = require 'libipc'
local sys = require 'sys'
local Tree = require 'ipc.Tree'
-- Load cutorch if CUDA was requested.
-- If it cannot be loaded, stop immediately and report the real reason.
-- Otherwise the later `t0:cuda()` call (the :cuda() method is presumably
-- provided by cutorch) would fail with a much less helpful
-- "attempt to call method" error.
if opt.cuda then
  print('loading cutorch...')
  local ok, err = pcall(require, 'cutorch')
  if not ok then
    error('--cuda was requested but cutorch failed to load: ' .. tostring(err))
  end
  print('cutorch loaded ok.')
end
-- Create a big tensor.
-- --dimensions is a comma-separated list of sizes (e.g. '1000,1000').
-- Each entry becomes one dimension of a random FloatTensor; the tensor is
-- moved to the GPU when --cuda is set.
--
-- Parsing uses the standard string.gmatch rather than the non-standard
-- string.split extension. Every entry is validated: a non-numeric entry
-- would make tonumber() return nil, leaving a hole in the table that
-- unpack() would then hand to torch.randn as garbage.
local dimensions = {}
for field in tostring(opt.dimensions):gmatch('[^,]+') do
  local size = tonumber(field) -- tonumber tolerates surrounding whitespace
  if not size or size < 1 or size % 1 ~= 0 then
    error(string.format("invalid tensor dimension '%s' in --dimensions '%s'",
      field, tostring(opt.dimensions)))
  end
  dimensions[#dimensions + 1] = size
end
if #dimensions == 0 then
  error(string.format("--dimensions '%s' does not contain any sizes",
    tostring(opt.dimensions)))
end

-- Lua 5.1/LuaJIT expose `unpack`; 5.2+ moved it to `table.unpack`.
local unpack = unpack or table.unpack
local t0 = torch.randn(unpack(dimensions)):float()
if opt.cuda then
  t0 = t0:cuda()
end
-- Create the tree of nodes.
-- Node 1 acts as the rendezvous server and calls server:clients() for the
-- remaining opt.numNodes - 1 nodes (presumably blocking until they have all
-- connected -- confirm against libipc). Every other node connects to it as a
-- client. Exactly one of `server` / `client` is non-nil on each node.
-- Tree also receives opt.port + opt.node, presumably a per-node port for its
-- own links (NOTE(review): confirm against ipc.Tree).
local client, server
if opt.node ~= 1 then
  client = ipc.client(opt.host, opt.port)
else
  server = ipc.server(opt.host, opt.port)
  -- Nothing to do per connected client; the callback is intentionally empty.
  server:clients(opt.numNodes - 1, function(_) end)
end
local tree = Tree(opt.node, opt.numNodes, opt.base, server, client,
  opt.host, opt.port + opt.node)
-- Iterate!
-- Benchmark: run opt.iterations all-reduce rounds over t0, then print the
-- elapsed time measured with sys.tic()/sys.toc().
--
-- The reducer combines two tensors with Torch's in-place `a:add(b)`, which
-- mutates and returns `a`. It is defined once here rather than as an
-- anonymous function inside the loop, so the hot loop no longer allocates a
-- new closure on every iteration.
-- NOTE(review): whether t0 itself ends up holding the reduced result depends
-- on ipc.Tree's allReduce contract -- confirm there.
-- NOTE(review): the --verify option is parsed but never consulted, so the
-- transmitted contents are not checked by this script.
local function sum_into(a, b)
  return a:add(b)
end

sys.tic()
for _ = 1, opt.iterations do
  tree.allReduce(t0, sum_into)
end
print('did '..opt.iterations..' in '..sys.toc()..' seconds')

-- With --verbose, also print the tree's network statistics.
if opt.verbose then
  tree.netStats()
end