Skip to content

Commit 62c6ef6

Browse files
committed
perf: fuse comprehension scope+eval and inline BinaryOp(ValidId,ValidId)
Fuse comprehension scope building with body evaluation, eliminating intermediate scope array allocation. For nested comprehensions like [x+y for x in arr for y in arr if x==y], this avoids allocating O(n²) intermediate scopes — only the O(n) matching results are materialized. When the innermost body is BinaryOp(ValidId,ValidId), inline scope lookups and numeric binary-op dispatch to avoid 3× visitExpr overhead per iteration. Falls back to general visitExpr for non-numeric types. Key changes: - visitCompFused: recursive fused scope+eval loop with ArrayBuilder - evalBinaryOpNumNum: @switch-dispatched Num×Num fast path - Non-numeric fallback uses existing visitExpr (no code duplication) Upstream: jit branch commits 3466461 (fuse) + 71545ba (inline)
1 parent 3ff6d95 commit 62c6ef6

3 files changed

Lines changed: 206 additions & 5 deletions

File tree

sjsonnet/src/sjsonnet/Evaluator.scala

Lines changed: 150 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,156 @@ class Evaluator(
189189
visitExpr(e.returned)(s)
190190
}
191191

192-
def visitComp(e: Comp)(implicit scope: ValScope): Val =
193-
Val.Arr(
194-
e.pos,
195-
visitComp(e.first :: e.rest.toList, Array(scope)).map(s => visitAsLazy(e.value)(s))
196-
)
192+
def visitComp(e: Comp)(implicit scope: ValScope): Val = {
193+
val results = new collection.mutable.ArrayBuilder.ofRef[Eval]
194+
results.sizeHint(16)
195+
visitCompFused(e.first :: e.rest.toList, scope, e.value, results)
196+
Val.Arr(e.pos, results.result())
197+
}
198+
199+
/**
200+
* Fused scope-building + body evaluation: eliminates intermediate scope array allocation. Instead
201+
* of first collecting all valid scopes into an Array[ValScope] and then mapping over them with
202+
* visitAsLazy, this method directly appends body results as it encounters valid scopes. For
203+
* nested comprehensions like `[x+y for x in arr for y in arr if x==y]`, this avoids allocating
204+
* O(n²) intermediate scopes — only the O(n) matching results are materialized.
205+
*
206+
* For innermost ForSpec with BinaryOp(ValidId,ValidId) body, inlines scope lookups and numeric
207+
* binary-op dispatch to avoid 3× visitExpr overhead per iteration.
208+
*/
209+
private def visitCompFused(
210+
specs: List[CompSpec],
211+
scope: ValScope,
212+
body: Expr,
213+
results: collection.mutable.ArrayBuilder.ofRef[Eval]
214+
): Unit = specs match {
215+
case (spec @ ForSpec(_, name, expr)) :: rest =>
216+
visitExpr(expr)(scope) match {
217+
case a: Val.Arr =>
218+
if (debugStats != null) debugStats.arrayCompIterations += a.length
219+
val lazyArr = a.asLazyArray
220+
if (rest.isEmpty) {
221+
// Innermost loop: try BinaryOp(ValidId,ValidId) fast path
222+
body match {
223+
case binOp: BinaryOp
224+
if binOp.lhs.tag == ExprTags.ValidId
225+
&& binOp.rhs.tag == ExprTags.ValidId =>
226+
// Fast path: reuse mutable scope, inline scope lookups + binary-op dispatch.
227+
// NOTE: Evaluates eagerly (not lazy). Both go-jsonnet and jrsonnet also
228+
// evaluate comprehensions eagerly, so this is compatible. Eagerness is
229+
// required for mutable scope reuse — a lazy thunk would capture the
230+
// mutable scope and see stale bindings from later iterations.
231+
val mutableScope = scope.extendBy(1)
232+
val slot = scope.bindings.length
233+
val bindings = mutableScope.bindings
234+
val lhsIdx = binOp.lhs.asInstanceOf[ValidId].nameIdx
235+
val rhsIdx = binOp.rhs.asInstanceOf[ValidId].nameIdx
236+
val op = binOp.op
237+
val bpos = binOp.pos
238+
var j = 0
239+
while (j < lazyArr.length) {
240+
bindings(slot) = lazyArr(j)
241+
val l = bindings(lhsIdx).value
242+
val r = bindings(rhsIdx).value
243+
(l, r) match {
244+
// Only dispatch to numeric fast path for ops it handles (0-16 except OP_in=11).
245+
// OP_in expects string+object, OP_&&/OP_|| need short-circuit semantics.
246+
case (ln: Val.Num, rn: Val.Num)
247+
if op <= Expr.BinaryOp.OP_| && op != Expr.BinaryOp.OP_in =>
248+
results += evalBinaryOpNumNum(op, ln, rn, bpos)
249+
case _ =>
250+
// Fallback to general evaluator for non-numeric types
251+
results += visitExpr(binOp)(mutableScope)
252+
}
253+
j += 1
254+
}
255+
case _ =>
256+
var j = 0
257+
while (j < lazyArr.length) {
258+
results += visitAsLazy(body)(scope.extendSimple(lazyArr(j)))
259+
j += 1
260+
}
261+
}
262+
} else {
263+
// Outer loop: recurse for remaining specs
264+
var j = 0
265+
while (j < lazyArr.length) {
266+
visitCompFused(rest, scope.extendSimple(lazyArr(j)), body, results)
267+
j += 1
268+
}
269+
}
270+
case r =>
271+
Error.fail(
272+
"In comprehension, can only iterate over array, not " + r.prettyName,
273+
spec
274+
)
275+
}
276+
case (spec @ IfSpec(offset, expr)) :: rest =>
277+
visitExpr(expr)(scope) match {
278+
case Val.True(_) =>
279+
if (rest.isEmpty) results += visitAsLazy(body)(scope)
280+
else visitCompFused(rest, scope, body, results)
281+
case Val.False(_) => // filtered out
282+
case other =>
283+
Error.fail(
284+
"Condition must be boolean, got " + other.prettyName,
285+
spec
286+
)
287+
}
288+
case Nil =>
289+
results += visitAsLazy(body)(scope)
290+
}
291+
292+
/**
293+
* Fast-path binary op evaluation for Num×Num operands within comprehension inner loops. Handles
294+
* the most common operations without visitExpr dispatch overhead.
295+
*/
296+
@inline private def evalBinaryOpNumNum(op: Int, ln: Val.Num, rn: Val.Num, pos: Position): Val = {
297+
val ld = ln.asDouble
298+
val rd = rn.asDouble
299+
(op: @switch) match {
300+
case Expr.BinaryOp.OP_+ => Val.Num(pos, ld + rd)
301+
case Expr.BinaryOp.OP_- =>
302+
val r = ld - rd
303+
if (r.isInfinite) Error.fail("overflow", pos)
304+
Val.Num(pos, r)
305+
case Expr.BinaryOp.OP_* =>
306+
val r = ld * rd
307+
if (r.isInfinite) Error.fail("overflow", pos)
308+
Val.Num(pos, r)
309+
case Expr.BinaryOp.OP_/ =>
310+
if (rd == 0) Error.fail("division by zero", pos)
311+
val r = ld / rd
312+
if (r.isInfinite) Error.fail("overflow", pos)
313+
Val.Num(pos, r)
314+
case Expr.BinaryOp.OP_% => Val.Num(pos, ld % rd)
315+
case Expr.BinaryOp.OP_< => Val.bool(pos, ld < rd)
316+
case Expr.BinaryOp.OP_> => Val.bool(pos, ld > rd)
317+
case Expr.BinaryOp.OP_<= => Val.bool(pos, ld <= rd)
318+
case Expr.BinaryOp.OP_>= => Val.bool(pos, ld >= rd)
319+
case Expr.BinaryOp.OP_== => Val.bool(pos, ld == rd)
320+
case Expr.BinaryOp.OP_!= => Val.bool(pos, ld != rd)
321+
case Expr.BinaryOp.OP_<< =>
322+
val ll = ld.toSafeLong(pos); val rr = rd.toSafeLong(pos)
323+
if (rr < 0) Error.fail("shift by negative exponent", pos)
324+
if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr)))
325+
Error.fail("numeric value outside safe integer range for bitwise operation", pos)
326+
Val.Num(pos, (ll << rr).toDouble)
327+
case Expr.BinaryOp.OP_>> =>
328+
val ll = ld.toSafeLong(pos); val rr = rd.toSafeLong(pos)
329+
if (rr < 0) Error.fail("shift by negative exponent", pos)
330+
Val.Num(pos, (ll >> rr).toDouble)
331+
case Expr.BinaryOp.OP_& =>
332+
Val.Num(pos, (ld.toSafeLong(pos) & rd.toSafeLong(pos)).toDouble)
333+
case Expr.BinaryOp.OP_^ =>
334+
Val.Num(pos, (ld.toSafeLong(pos) ^ rd.toSafeLong(pos)).toDouble)
335+
case Expr.BinaryOp.OP_| =>
336+
Val.Num(pos, (ld.toSafeLong(pos) | rd.toSafeLong(pos)).toDouble)
337+
case _ =>
338+
// Should be unreachable: caller filters to ops 0-16 except OP_in
339+
throw new AssertionError(s"Unexpected numeric binary op: $op")
340+
}
341+
}
197342

198343
def visitArr(e: Arr)(implicit scope: ValScope): Val =
199344
Val.Arr(e.pos, e.value.map(visitAsLazy))
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Regression test: all binary operators in comprehensions with ValidId operands
2+
local strs = ["hello", "world"];
3+
local nums = [1, 2, 3];
4+
local arrs = [[1, 2], [3, 4]];
5+
6+
// String concatenation
7+
local str_concat = [a + b for a in strs for b in strs];
8+
9+
// Numeric arithmetic
10+
local num_add = [a + b for a in nums for b in nums];
11+
local num_sub = [a - b for a in [10, 20] for b in [3, 5]];
12+
local num_mul = [a * b for a in [2, 3] for b in [4, 5]];
13+
local num_div = [a / b for a in [10, 20] for b in [2, 5]];
14+
local num_mod = [a % b for a in [10, 7] for b in [3, 4]];
15+
16+
// Comparison operators
17+
local cmp_lt = [a < b for a in nums for b in nums];
18+
local cmp_eq = [a == b for a in nums for b in nums];
19+
local cmp_ne = [a != b for a in nums for b in nums];
20+
21+
// Bitwise operators
22+
local bw_and = [a & b for a in [3, 5] for b in [6, 7]];
23+
local bw_or = [a | b for a in [3, 5] for b in [6, 7]];
24+
local bw_xor = [a ^ b for a in [3, 5] for b in [6, 7]];
25+
local bw_shl = [a << b for a in [1, 2] for b in [1, 2]];
26+
local bw_shr = [a >> b for a in [8, 16] for b in [1, 2]];
27+
28+
// String formatting
29+
local str_fmt = [a % b for a in ["val=%d", "x=%d"] for b in [42, 99]];
30+
31+
// Array concatenation
32+
local arr_concat = [a + b for a in arrs for b in arrs];
33+
34+
// 'in' operator
35+
local objs = [{a: 1}, {b: 2}];
36+
local in_test = [a in b for a in ["a", "b"] for b in objs];
37+
38+
std.assertEqual(str_concat, ["hellohello", "helloworld", "worldhello", "worldworld"]) &&
39+
std.assertEqual(num_add, [2, 3, 4, 3, 4, 5, 4, 5, 6]) &&
40+
std.assertEqual(num_sub, [7, 5, 17, 15]) &&
41+
std.assertEqual(num_mul, [8, 10, 12, 15]) &&
42+
std.assertEqual(num_div, [5, 2, 10, 4]) &&
43+
std.assertEqual(num_mod, [1, 2, 1, 3]) &&
44+
std.assertEqual(cmp_lt, [false, true, true, false, false, true, false, false, false]) &&
45+
std.assertEqual(cmp_eq, [true, false, false, false, true, false, false, false, true]) &&
46+
std.assertEqual(cmp_ne, [false, true, true, true, false, true, true, true, false]) &&
47+
std.assertEqual(bw_and, [2, 3, 4, 5]) &&
48+
std.assertEqual(bw_or, [7, 7, 7, 7]) &&
49+
std.assertEqual(bw_xor, [5, 4, 3, 2]) &&
50+
std.assertEqual(bw_shl, [2, 4, 4, 8]) &&
51+
std.assertEqual(bw_shr, [4, 2, 8, 4]) &&
52+
std.assertEqual(str_fmt, ["val=42", "val=99", "x=42", "x=99"]) &&
53+
std.assertEqual(arr_concat, [[1, 2, 1, 2], [1, 2, 3, 4], [3, 4, 1, 2], [3, 4, 3, 4]]) &&
54+
std.assertEqual(in_test, [true, false, false, true]) &&
55+
true
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
true

0 commit comments

Comments
 (0)