Skip to content

Commit 96bf6a3

Browse files
author
Kevin Peng
committed
First draft of ElasticAverageCollideBinder
1 parent 5f84579 commit 96bf6a3

1 file changed

Lines changed: 273 additions & 0 deletions

File tree

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
package BIDMach.allreduce.binder
2+
3+
import java.util.ArrayDeque
4+
import java.util.concurrent.atomic.AtomicInteger
5+
import java.util.logging.Logger
6+
import scala.util.Random
7+
8+
import BIDMach.allreduce.binder.AllreduceBinder.{DataSink, DataSource}
9+
//import BIDMach.models.Model
10+
import BIDMach.updaters.Grad
11+
import BIDMat.{Mat, FMat, GMat}
12+
13+
14+
/**
15+
* Linearize input model mats, and elastic-average update to the same model.
16+
* Perform momentum exchange among several nodes in a cluster, preserving total energy of the nodes.
17+
*
18+
* @param model
19+
* @param alphaFromIter
20+
*/
21+
// FIXME: should get rndseed, node num and # nodes from worker
22+
class ElasticAverageCollideBinder(updater: Grad, alphaFromIter: Int => Float, hardness: Float, rndseed: Long, inode: Int,
23+
nnodes: Int, logger: Logger) extends AllreduceBinder {
24+
25+
val model = updater.model
26+
// Keeping track of elastic updates
27+
var tic = System.currentTimeMillis()
28+
val reduceCount = new AtomicInteger()
29+
30+
val random = new Random(rndseed)
31+
// TODO: make these GMats when applicable
32+
val rawRandVecs = new Array[Array[FMat]](nnodes)
33+
val randVecs = new Array[Array[FMat]](nnodes)
34+
val randVecSqNorms = new Array[Array[Float]](nnodes)
35+
var rvOffset = 0
36+
// TODO: think about GMats too
37+
val aelem = FMat(1, 1)
38+
39+
// TODO: make this more efficient by making use of functionality in SciFunctions etc.
40+
def genRandomVector(out: FMat) = {
41+
var i = 0
42+
val len = out.length
43+
while (i < len) {
44+
out.data(i) = random.nextGaussian().toFloat
45+
}
46+
}
47+
48+
def dotprod(a:Mat, b:Mat):Float = {
49+
aelem ~ a.contents dot b.contents
50+
aelem.dv.toFloat;
51+
}
52+
53+
// TODO: is synchronization necessary to get updater momentum lengths
54+
def initRandVecs = {
55+
if (rawRandVecs(0) eq null) {
56+
for (i <- 0 until nnodes) {
57+
rawRandVecs(i) = new Array(updater.momentum.length)
58+
59+
for ((pm, j) <- updater.momentum.iterator.zipWithIndex) {
60+
val fmat = FMat.make(pm.dims)
61+
genRandomVector(fmat.contents())
62+
pm match {
63+
case _: GMat => rawRandVecs(i)(j) = GMat(fmat)
64+
case _: FMat => rawRandVecs(i)(j) = fmat
65+
}
66+
}
67+
}
68+
69+
for (i <- 0 until nnodes) {
70+
randVecs(i) = new Array(updater.momentum.length)
71+
randVecSqNorms(i) = new Array(updater.momentum.length)
72+
for (j <- 0 until updater.momentum.length) {
73+
randVecs(i)(j) = rawRandVecs(i)(j) - rawRandVecs((i + 1) % nnodes)(j)
74+
randVecSqNorms(i)(j) = dotprod(randVecs(i)(j), randVecs(i)(j))
75+
}
76+
}
77+
}
78+
}
79+
80+
def rotateRndVecs = {
81+
val prevOffset = (rvOffset + nnodes - 1) % nnodes
82+
83+
for (randMat <- rawRandVecs(rvOffset)) {
84+
randMat match {
85+
case gmat: GMat =>
86+
val fmat = FMat.make(randMat.dims)
87+
genRandomVector(fmat)
88+
gmat <-- fmat
89+
case fmat: FMat => genRandomVector(fmat)
90+
}
91+
}
92+
93+
for (offset <- Array(prevOffset, rvOffset)) {
94+
val nextOffset = (offset + 1) % nnodes
95+
for ((v1, v2) <- randVecs(offset) zip randVecs(nextOffset)) {
96+
v1 ~ v1 - v2
97+
}
98+
for ((v, i) <- randVecs(offset).iterator.zipWithIndex) {
99+
randVecSqNorms(offset)(i) = dotprod(v, v)
100+
}
101+
}
102+
103+
rvOffset += 1
104+
if (rvOffset == nnodes) rvOffset = 0
105+
}
106+
107+
override lazy val totalDataSize: Int = {
108+
var ret = 0
109+
updater.momentum.synchronized {
110+
// Momentum mats
111+
for (p <- updater.momentum) ret += p.length
112+
// Squared magnitudes of momentum mats
113+
ret += updater.momentum.length
114+
// Dot product of momentum mats and random mats
115+
ret += updater.momentum.length
116+
}
117+
// Model mats
118+
model.modelmats.synchronized {
119+
for (mat <- model.modelmats) ret += mat.length
120+
}
121+
ret
122+
}
123+
124+
override def dataSource: DataSource = inputRequest => {
125+
initRandVecs
126+
127+
val ret: Array[Float] = new Array[Float](totalDataSize)
128+
var current = totalDataSize
129+
val myRandVecs = randVecs((rvOffset + inode) % nnodes)
130+
131+
// TODO: do we need to lock on the model and updater mats
132+
133+
// backward traversing model mats, assuming forward traversal by the training model
134+
for (mm <- model.modelmats.reverseIterator) {
135+
current -= mm.length
136+
mm match {
137+
case gmat: GMat => GMat.GPUtoCPUarraycopy(gmat.pdata, 0, ret, current, gmat.length, "ElasticAverageBinder dataSource")
138+
case fmat: FMat => System.arraycopy(fmat.contents().data, 0, ret, current, fmat.length)
139+
}
140+
}
141+
142+
// dot product of momentum and random vectors
143+
// backward traversing update mats, assuming forward traversal by updater
144+
for ((pm, r) <- updater.momentum.reverseIterator zip myRandVecs.reverseIterator) {
145+
current -= 1
146+
ret(current) = dotprod(pm, r)
147+
}
148+
149+
// squared norm of momentums
150+
for (pm <- updater.momentum.reverseIterator) {
151+
current -= 1
152+
ret(current) = dotprod(pm, pm)
153+
}
154+
155+
// backward traversing update mats, assuming forward traversal by updater
156+
for (pm <- updater.momentum.reverseIterator) {
157+
current -= pm.length
158+
pm match {
159+
case gmat: GMat => GMat.GPUtoCPUarraycopy(gmat.pdata, 0, ret, current, gmat.length, "ElasticAverageBinder dataSource")
160+
case fmat: FMat => System.arraycopy(fmat.contents().data, 0, ret, current, fmat.length)
161+
}
162+
}
163+
164+
assert(current == 0, "current should be zero after iteration")
165+
166+
AllReduceInput(ret)
167+
168+
}
169+
170+
171+
172+
override def dataSink: DataSink = reducedOutput => {
173+
174+
reduceCount.synchronized {
175+
val currentCount: Int = reduceCount.getAndIncrement()
176+
val updateCounts = 10
177+
if (currentCount % updateCounts == 0) {
178+
val toc = System.currentTimeMillis()
179+
if (currentCount > 0) {
180+
logger.info(f"elastic_updates/s=${updateCounts/((toc - tic) / 1.0e3)}%2.2f, total_updates=$currentCount")
181+
}
182+
tic = toc
183+
}
184+
}
185+
val reducedData = reducedOutput.data
186+
187+
assert(reducedData.length == totalDataSize, "Reduced output should be same length as input")
188+
189+
// backward traversing model mats, assuming forward traversal by the training model
190+
// using while instead of for loop due to performance
191+
var current = totalDataSize
192+
val alpha = alphaFromIter(reducedOutput.iteration)
193+
194+
for (mm <- model.modelmats.reverseIterator) {
195+
current -= mm.length
196+
mm.synchronized {
197+
mm match {
198+
case gmat: GMat =>
199+
val gReduced = GMat.make(gmat.dims)
200+
GMat.CPUtoGPUarraycopy(reducedData, current, gReduced.pdata, 0, gmat.length, "ElasticAverageCollideBinder dataSink")
201+
gReduced ~ gReduced / aelem.set(nnodes)
202+
gmat ~ gmat * aelem.set(1 - alpha)
203+
gReduced ~ gReduced * aelem.set(alpha)
204+
gmat ~ gReduced + gmat
205+
gReduced.free()
206+
case fmat: FMat =>
207+
val fReduced = FMat.make(fmat.dims)
208+
System.arraycopy(reducedData, current, fReduced.contents().data, 0, fmat.length)
209+
fReduced ~ fReduced / aelem.set(nnodes)
210+
fmat ~ fmat * aelem.set(1 - alpha)
211+
fReduced ~ fReduced * aelem.set(alpha)
212+
fmat ~ fReduced + fmat
213+
}
214+
}
215+
}
216+
217+
val sumPmR = new Array[Float](updater.modelmats.length)
218+
current -= updater.modelmats.length
219+
System.arraycopy(reducedData, current, sumPmR, 0, updater.modelmats.length)
220+
221+
val sumPmPm = new Array[Float](updater.modelmats.length)
222+
current -= updater.modelmats.length
223+
System.arraycopy(reducedData, current, sumPmPm, 0, updater.modelmats.length)
224+
225+
val meanP = new Array[Mat](updater.modelmats.length)
226+
for (i <- updater.modelmats.length - 1 to 0 by -1) {
227+
current -= updater.modelmats(i).length
228+
val pbar = updater.modelmats(i) match {
229+
case _: GMat =>
230+
val pbar = GMat.make(updater.modelmats(i).dims)
231+
GMat.CPUtoGPUarraycopy(reducedData, current, pbar.pdata, 0, updater.modelmats(i).length, "ElasticAverageCollideBinder dataSink")
232+
pbar
233+
case _: FMat =>
234+
val pbar = FMat.make(updater.modelmats(i).dims)
235+
System.arraycopy(reducedData, current, pbar.contents().data, 0, updater.modelmats(i).length)
236+
pbar
237+
}
238+
pbar ~ pbar / aelem.set(nnodes)
239+
meanP(i) = pbar
240+
}
241+
242+
assert(current == 0, "current should be zero after iteration")
243+
244+
for (j <- updater.modelmats.length - 1 to 0 by -1) {
245+
// TODO: not hold the lock for 1293579813753 years, but also avoid data races
246+
updater.modelmats(j) synchronized {
247+
val x = meanP(j) - updater.modelmats(j)
248+
x ~ x * aelem.set(hardness)
249+
x ~ x + updater.modelmats(j)
250+
251+
val sumXR = (1 - hardness) * sumPmR(j)
252+
val sumXXminusPmPm = hardness * (hardness - 2) * (sumPmPm(j) - nnodes * dotprod(meanP(j), meanP(j)))
253+
254+
val twoSumXR = 2 * sumXR
255+
val sumRR = randVecSqNorms.map(_(j)).reduce(_ + _)
256+
// Discriminant should always be positive for any hardness in [0, 1]
257+
val discr = twoSumXR*twoSumXR - 4*sumRR*sumXXminusPmPm
258+
val epsilon = 1e-36f
259+
val beta = if (Mat.myrand.nextFloat() < 0.5f) {
260+
(-twoSumXR + math.sqrt(discr).toFloat) / (2 * sumRR + epsilon)
261+
} else {
262+
(-twoSumXR - math.sqrt(discr).toFloat) / (2 * sumRR + epsilon)
263+
}
264+
265+
updater.modelmats(j) ~ x - aelem.set(beta) * randVecs((rvOffset + inode) % nnodes)(j)
266+
}
267+
}
268+
269+
rotateRndVecs
270+
}
271+
272+
}
273+

0 commit comments

Comments
 (0)