Skip to content

Commit 9b5caef

Browse files
committed
perf: comprehension fuse scope+eval and inline BinaryOp(ValidId,ValidId) fast path
Fuse comprehension scope building with body evaluation, eliminating redundant scope allocation in the innermost loop. When the body is BinaryOp(ValidId,ValidId), inline the scope lookups and binary-op dispatch entirely, avoiding 3× visitExpr overhead per iteration.

Key changes:
- ValScope.extendMutable(): creates a scope with one extra mutable slot for reuse across iterations (safe because results are eagerly evaluated, not captured in lazy thunks)
- visitCompInline: split by rest (Nil vs non-Nil), with BinaryOp fast path for innermost loops
- evalBinaryOpNumNum: @switch-dispatched Num×Num fast path covering all comparison, arithmetic, modulo, bitwise, and shift operators with full safety checks (overflow, division-by-zero, safe integer range)
- visitBinaryOpValues: polymorphic fallback for non-Num operands covering string concat/format, object merge, array concat, equality, and 'in'

Benchmark: comparison2 -53.1% (74.1 → 34.8 ms/op), zero regressions across 35 benchmarks.

Upstream: jit branch commits 3466461 (fuse) + 71545ba (inline)
1 parent 3ff6d95 commit 9b5caef

4 files changed

Lines changed: 289 additions & 5 deletions

File tree

sjsonnet/src/sjsonnet/Evaluator.scala

Lines changed: 223 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,229 @@ class Evaluator(
189189
visitExpr(e.returned)(s)
190190
}
191191

192-
def visitComp(e: Comp)(implicit scope: ValScope): Val =
193-
Val.Arr(
194-
e.pos,
195-
visitComp(e.first :: e.rest.toList, Array(scope)).map(s => visitAsLazy(e.value)(s))
196-
)
192+
/**
 * Evaluates an array comprehension. The spec list (the first spec followed by the remaining ones)
 * is handed to visitCompInline, which appends one entry per generated element to a shared
 * builder; the collected entries become the elements of the resulting array.
 */
def visitComp(e: Comp)(implicit scope: ValScope): Val = {
  val elems = new collection.mutable.ArrayBuilder.ofRef[Eval]
  // Reserve room up front so the builder does not have to regrow for the first 256 elements.
  elems.sizeHint(256)
  val specs = e.first :: e.rest.toList
  visitCompInline(specs, scope, e.value, elems)
  Val.Arr(e.pos, elems.result())
}
199+
200+
/**
 * Fused scope-building + body evaluation for comprehensions: walks the spec list recursively and
 * appends every generated element straight into `results`, so no intermediate array of scopes is
 * built.
 *
 * For an innermost ForSpec whose body is BinaryOp(ValidId, ValidId) the operands are read directly
 * from a reusable scope array and the operator is dispatched inline, skipping the visitExpr calls
 * for the body and for both operands.
 *
 * Elements stay lazy with respect to failures: when the inline evaluation of an element raises an
 * Error, that element is stored as a lazy thunk instead, so the failure only surfaces if the
 * element is actually accessed. This matches every other path of this method, all of which store
 * `visitAsLazy(body)`.
 *
 * @param specs
 *   remaining comprehension specs, outermost first
 * @param scope
 *   scope in which the head of `specs` is evaluated
 * @param body
 *   the comprehension's element expression
 * @param results
 *   builder receiving one entry per generated element
 */
private def visitCompInline(
    specs: List[CompSpec],
    scope: ValScope,
    body: Expr,
    results: collection.mutable.ArrayBuilder.ofRef[Eval]): Unit = specs match {
  case (spec @ ForSpec(_, _, expr)) :: rest =>
    visitExpr(expr)(scope) match {
      case a: Val.Arr =>
        if (debugStats != null) debugStats.arrayCompIterations += a.length
        val lazyArr = a.asLazyArray
        rest match {
          case Nil =>
            // Innermost loop — try the BinaryOp(ValidId, ValidId) fast path.
            body match {
              case binOp: BinaryOp
                  if binOp.lhs.tag == ExprTags.ValidId
                    && binOp.rhs.tag == ExprTags.ValidId =>
                // One scope array is shared by all iterations; only the loop variable's slot
                // (the first index past the parent bindings) is rewritten each time round.
                val mutableScope = scope.extendMutable()
                val slot = scope.bindings.length
                val bindings = mutableScope.bindings
                val lhsIdx = binOp.lhs.asInstanceOf[ValidId].nameIdx
                val rhsIdx = binOp.rhs.asInstanceOf[ValidId].nameIdx
                val op = binOp.op
                val bpos = binOp.pos
                var j = 0
                while (j < lazyArr.length) {
                  val item = lazyArr(j)
                  bindings(slot) = item
                  val element: Eval =
                    try {
                      val l = bindings(lhsIdx).value
                      val r = bindings(rhsIdx).value
                      (l, r) match {
                        case (ln: Val.Num, rn: Val.Num) => evalBinaryOpNumNum(op, ln, rn, bpos)
                        case _                          => visitBinaryOpValues(op, l, r, bpos)
                      }
                    } catch {
                      // Eager evaluation failed: defer the failure by storing the same lazy
                      // element the generic path would have produced. It needs its own
                      // immutable scope because the shared array keeps being overwritten.
                      case _: sjsonnet.Error =>
                        visitAsLazy(body)(scope.extendSimple(item))
                    }
                  results += element
                  j += 1
                }
              case _ =>
                var j = 0
                while (j < lazyArr.length) {
                  val newScope = scope.extendSimple(lazyArr(j))
                  results += visitAsLazy(body)(newScope)
                  j += 1
                }
            }
          case _ =>
            // More specs follow: bind the loop variable and recurse for each item.
            var j = 0
            while (j < lazyArr.length) {
              val newScope = scope.extendSimple(lazyArr(j))
              visitCompInline(rest, newScope, body, results)
              j += 1
            }
        }
      case r =>
        Error.fail(
          "In comprehension, can only iterate over array, not " + r.prettyName,
          spec
        )
    }
  case (spec @ IfSpec(_, expr)) :: rest =>
    visitExpr(expr)(scope) match {
      case Val.True(_) =>
        // With no specs left, the recursive call reaches the `Nil` case below and emits the body.
        visitCompInline(rest, scope, body, results)
      case Val.False(_) => // filtered out
      case other =>
        Error.fail(
          "Condition must be boolean, got " + other.prettyName,
          spec
        )
    }
  case Nil =>
    results += visitAsLazy(body)(scope)
}
287+
288+
/**
 * Evaluates a binary operator whose operands are both numbers, without going through visitExpr.
 *
 * Handled inline: the six comparison/equality operators, `+ - * /` (failing with "overflow" on an
 * infinite result and with "division by zero" on a zero divisor), `%`, the shifts (failing on a
 * negative shift count, and for `<<` when the range check below trips) and the bitwise operators
 * `& ^ |` on the operands converted with `toSafeLong`. Any other operator code is forwarded to
 * visitBinaryOpValues.
 *
 * @param op
 *   operator code, one of the `Expr.BinaryOp.OP_*` constants
 * @param ln
 *   left operand
 * @param rn
 *   right operand
 * @param pos
 *   position attached to the result and to any error raised
 */
@inline private def evalBinaryOpNumNum(op: Int, ln: Val.Num, rn: Val.Num, pos: Position): Val = {
  val lhs = ln.asDouble
  val rhs = rn.asDouble

  // Wraps an arithmetic result, rejecting infinities.
  def finite(value: Double): Val = {
    if (value.isInfinite) Error.fail("overflow", pos)
    Val.Num(pos, value)
  }

  // Wraps the result of an integer (bitwise/shift) operation.
  def integral(value: Long): Val = Val.Num(pos, value.toDouble)

  (op: @switch) match {
    case Expr.BinaryOp.OP_<  => Val.bool(pos, lhs < rhs)
    case Expr.BinaryOp.OP_>  => Val.bool(pos, lhs > rhs)
    case Expr.BinaryOp.OP_<= => Val.bool(pos, lhs <= rhs)
    case Expr.BinaryOp.OP_>= => Val.bool(pos, lhs >= rhs)
    case Expr.BinaryOp.OP_== => Val.bool(pos, lhs == rhs)
    case Expr.BinaryOp.OP_!= => Val.bool(pos, lhs != rhs)

    case Expr.BinaryOp.OP_+ => finite(lhs + rhs)
    case Expr.BinaryOp.OP_- => finite(lhs - rhs)
    case Expr.BinaryOp.OP_* => finite(lhs * rhs)
    case Expr.BinaryOp.OP_/ =>
      // The zero-divisor check comes first so it wins over the overflow check.
      if (rhs == 0) Error.fail("division by zero", pos)
      finite(lhs / rhs)
    case Expr.BinaryOp.OP_% => Val.Num(pos, lhs % rhs)

    case Expr.BinaryOp.OP_<< =>
      val base = lhs.toSafeLong(pos)
      val count = rhs.toSafeLong(pos)
      if (count < 0) Error.fail("shift by negative exponent", pos)
      if (count >= 1 && math.abs(base) >= (1L << (63 - count)))
        Error.fail("numeric value outside safe integer range for bitwise operation", pos)
      integral(base << count)
    case Expr.BinaryOp.OP_>> =>
      val base = lhs.toSafeLong(pos)
      val count = rhs.toSafeLong(pos)
      if (count < 0) Error.fail("shift by negative exponent", pos)
      integral(base >> count)

    case Expr.BinaryOp.OP_& =>
      val a = lhs.toSafeLong(pos)
      val b = rhs.toSafeLong(pos)
      integral(a & b)
    case Expr.BinaryOp.OP_^ =>
      val a = lhs.toSafeLong(pos)
      val b = rhs.toSafeLong(pos)
      integral(a ^ b)
    case Expr.BinaryOp.OP_| =>
      val a = lhs.toSafeLong(pos)
      val b = rhs.toSafeLong(pos)
      integral(a | b)

    // Anything not covered above goes through the polymorphic evaluator.
    case _ => visitBinaryOpValues(op, ln, rn, pos)
  }
}
340+
341+
/**
 * Applies a binary operator to two already-evaluated values.
 *
 * Supported combinations:
 *   - `<`, `>`, `<=`, `>=`: two numbers, two strings (compared by code point) or two arrays
 *   - `==`, `!=`: any values, except that comparing two functions is an error
 *   - `+`: numbers (an infinite result is an "overflow" error), string with anything (the
 *     non-string side is stringified), object merge, array concatenation
 *   - `%`: numeric remainder, or string formatting when the left side is a string
 *   - `in`: string key looked up in an object
 *
 * Every other operator or operand combination is reported through failBinOp.
 *
 * @param op
 *   operator code, one of the `Expr.BinaryOp.OP_*` constants
 * @param l
 *   left operand value
 * @param r
 *   right operand value
 * @param pos
 *   position attached to the result and to any error raised
 */
private def visitBinaryOpValues(op: Int, l: Val, r: Val, pos: Position): Val = {
  // Relational test on two doubles, selected by `op` (only called for the four ordering ops).
  def holdsForNums(a: Double, b: Double): Boolean =
    if (op == Expr.BinaryOp.OP_<) a < b
    else if (op == Expr.BinaryOp.OP_>) a > b
    else if (op == Expr.BinaryOp.OP_<=) a <= b
    else a >= b

  // Same test applied to the sign of a three-way comparison result.
  def holdsForSign(c: Int): Boolean =
    if (op == Expr.BinaryOp.OP_<) c < 0
    else if (op == Expr.BinaryOp.OP_>) c > 0
    else if (op == Expr.BinaryOp.OP_<=) c <= 0
    else c >= 0

  // Shared implementation of the four ordering operators.
  def ordering: Val = (l, r) match {
    case (x: Val.Num, y: Val.Num) =>
      Val.bool(pos, holdsForNums(x.asDouble, y.asDouble))
    case (x: Val.Str, y: Val.Str) =>
      Val.bool(pos, holdsForSign(Util.compareStringsByCodepoint(x.str, y.str)))
    case (x: Val.Arr, y: Val.Arr) =>
      Val.bool(pos, holdsForSign(compare(x, y)))
    case _ => failBinOp(l, op, r, pos)
  }

  // Equality is undefined between two functions.
  def rejectFunctionPair(): Unit =
    if (l.isInstanceOf[Val.Func] && r.isInstanceOf[Val.Func])
      Error.fail("cannot test equality of functions", pos)

  (op: @switch) match {
    case Expr.BinaryOp.OP_< | Expr.BinaryOp.OP_> | Expr.BinaryOp.OP_<= | Expr.BinaryOp.OP_>= =>
      ordering

    case Expr.BinaryOp.OP_== =>
      rejectFunctionPair()
      Val.bool(pos, equal(l, r))

    case Expr.BinaryOp.OP_!= =>
      rejectFunctionPair()
      Val.bool(pos, !equal(l, r))

    case Expr.BinaryOp.OP_+ =>
      (l, r) match {
        case (x: Val.Num, y: Val.Num) =>
          val sum = x.asDouble + y.asDouble
          if (sum.isInfinite) Error.fail("overflow", pos)
          Val.Num(pos, sum)
        case (x: Val.Str, y: Val.Str) => Val.Str(pos, x.str + y.str)
        case (x: Val.Str, _)          => Val.Str(pos, x.str + Materializer.stringify(r))
        case (_, y: Val.Str)          => Val.Str(pos, Materializer.stringify(l) + y.str)
        case (x: Val.Obj, y: Val.Obj) => y.addSuper(pos, x)
        case (x: Val.Arr, y: Val.Arr) => x.concat(pos, y)
        case _                        => failBinOp(l, op, r, pos)
      }

    case Expr.BinaryOp.OP_% =>
      (l, r) match {
        case (x: Val.Num, y: Val.Num) => Val.Num(pos, x.asDouble % y.asDouble)
        case (fmt: Val.Str, _)        => Val.Str(pos, Format.format(fmt.str, r, pos))
        case _                        => failBinOp(l, op, r, pos)
      }

    case Expr.BinaryOp.OP_in =>
      (l, r) match {
        case (key: Val.Str, obj: Val.Obj) => Val.bool(pos, obj.containsKey(key.str))
        case _                            => failBinOp(l, op, r, pos)
      }

    case _ => failBinOp(l, op, r, pos)
  }
}
197415

198416
/**
 * Evaluates an array literal: every element expression is passed through visitAsLazy in the
 * current scope and the resulting entries become the elements of the array value.
 */
def visitArr(e: Arr)(implicit scope: ValScope): Val = {
  val elems = e.value.map(elem => visitAsLazy(elem))
  Val.Arr(e.pos, elems)
}

sjsonnet/src/sjsonnet/ValScope.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ final class ValScope private (val bindings: Array[Eval]) extends AnyVal {
5959
b(bindings.length + 2) = l3
6060
new ValScope(b)
6161
}
62+
63+
/**
 * Returns a scope whose binding array is a copy of this scope's bindings plus one trailing slot
 * that is left unset. The caller is expected to write that slot (index = the parent's binding
 * count) itself and may overwrite it repeatedly instead of allocating a scope per iteration.
 *
 * Because the array is shared across those writes, this is only safe when nothing evaluated
 * against the returned scope keeps a reference to it for later (e.g. a lazy value capturing the
 * scope).
 */
def extendMutable(): ValScope = {
  val parentLen = bindings.length
  val widened = new Array[Eval](parentLen + 1)
  System.arraycopy(bindings, 0, widened, 0, parentLen)
  new ValScope(widened)
}
6272
}
6373

6474
object ValScope {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Regression test: binary operators in comprehensions whose body is `<id> <op> <id>`.
// Every comprehension below has that shape with an innermost `for`, which is the form the
// evaluator's inlined comprehension path handles. Each operator/operand-type combination that
// path supports is exercised at least once; the file evaluates to `true` when all checks pass.
local strs = ["hello", "world"];
local nums = [1, 2, 3];
local arrs = [[1, 2], [3, 4]];

// String concatenation
local str_concat = [a + b for a in strs for b in strs];
// String + non-string (the non-string side is stringified)
local str_num = [a + b for a in strs for b in nums];

// Numeric arithmetic
local num_add = [a + b for a in nums for b in nums];
local num_sub = [a - b for a in [10, 20] for b in [3, 5]];
local num_mul = [a * b for a in [2, 3] for b in [4, 5]];
local num_div = [a / b for a in [10, 20] for b in [2, 5]];
local num_mod = [a % b for a in [10, 7] for b in [3, 4]];

// Comparison operators on numbers
local cmp_lt = [a < b for a in nums for b in nums];
local cmp_gt = [a > b for a in nums for b in nums];
local cmp_le = [a <= b for a in nums for b in nums];
local cmp_ge = [a >= b for a in nums for b in nums];
local cmp_eq = [a == b for a in nums for b in nums];
local cmp_ne = [a != b for a in nums for b in nums];

// Comparison operators on strings and arrays
local str_lt = [a < b for a in strs for b in strs];
local arr_lt = [a < b for a in arrs for b in arrs];

// Bitwise operators
local bw_and = [a & b for a in [3, 5] for b in [6, 7]];
local bw_or = [a | b for a in [3, 5] for b in [6, 7]];
local bw_xor = [a ^ b for a in [3, 5] for b in [6, 7]];
local bw_shl = [a << b for a in [1, 2] for b in [1, 2]];
local bw_shr = [a >> b for a in [8, 16] for b in [1, 2]];

// String formatting
local str_fmt = [a % b for a in ["val=%d", "x=%d"] for b in [42, 99]];

// Array concatenation
local arr_concat = [a + b for a in arrs for b in arrs];

// 'in' operator
local objs = [{a: 1}, {b: 2}];
local in_test = [a in b for a in ["a", "b"] for b in objs];

// Object merge
local obj_merge = [a + b for a in objs for b in objs];

std.assertEqual(str_concat, ["hellohello", "helloworld", "worldhello", "worldworld"]) &&
std.assertEqual(str_num, ["hello1", "hello2", "hello3", "world1", "world2", "world3"]) &&
std.assertEqual(num_add, [2, 3, 4, 3, 4, 5, 4, 5, 6]) &&
std.assertEqual(num_sub, [7, 5, 17, 15]) &&
std.assertEqual(num_mul, [8, 10, 12, 15]) &&
std.assertEqual(num_div, [5, 2, 10, 4]) &&
std.assertEqual(num_mod, [1, 2, 1, 3]) &&
std.assertEqual(cmp_lt, [false, true, true, false, false, true, false, false, false]) &&
std.assertEqual(cmp_gt, [false, false, false, true, false, false, true, true, false]) &&
std.assertEqual(cmp_le, [true, true, true, false, true, true, false, false, true]) &&
std.assertEqual(cmp_ge, [true, false, false, true, true, false, true, true, true]) &&
std.assertEqual(cmp_eq, [true, false, false, false, true, false, false, false, true]) &&
std.assertEqual(cmp_ne, [false, true, true, true, false, true, true, true, false]) &&
std.assertEqual(str_lt, [false, true, false, false]) &&
std.assertEqual(arr_lt, [false, true, false, false]) &&
std.assertEqual(bw_and, [2, 3, 4, 5]) &&
std.assertEqual(bw_or, [7, 7, 7, 7]) &&
std.assertEqual(bw_xor, [5, 4, 3, 2]) &&
std.assertEqual(bw_shl, [2, 4, 4, 8]) &&
std.assertEqual(bw_shr, [4, 2, 8, 4]) &&
std.assertEqual(str_fmt, ["val=42", "val=99", "x=42", "x=99"]) &&
std.assertEqual(arr_concat, [[1, 2, 1, 2], [1, 2, 3, 4], [3, 4, 1, 2], [3, 4, 3, 4]]) &&
std.assertEqual(in_test, [true, false, false, true]) &&
std.assertEqual(obj_merge, [{a: 1}, {a: 1, b: 2}, {a: 1, b: 2}, {b: 2}]) &&
true
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
true

0 commit comments

Comments
 (0)