Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 150 additions & 5 deletions sjsonnet/src/sjsonnet/Evaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,156 @@ class Evaluator(
visitExpr(e.returned)(s)
}

def visitComp(e: Comp)(implicit scope: ValScope): Val =
Val.Arr(
e.pos,
visitComp(e.first :: e.rest.toList, Array(scope)).map(s => visitAsLazy(e.value)(s))
)
def visitComp(e: Comp)(implicit scope: ValScope): Val = {
val results = new collection.mutable.ArrayBuilder.ofRef[Eval]
results.sizeHint(16)
visitCompFused(e.first :: e.rest.toList, scope, e.value, results)
Val.Arr(e.pos, results.result())
}

/**
* Fused scope-building + body evaluation: eliminates intermediate scope array allocation. Instead
* of first collecting all valid scopes into an Array[ValScope] and then mapping over them with
* visitAsLazy, this method directly appends body results as it encounters valid scopes. For
* nested comprehensions like `[x+y for x in arr for y in arr if x==y]`, this avoids allocating
* O(n²) intermediate scopes — only the O(n) matching results are materialized.
*
* For innermost ForSpec with BinaryOp(ValidId,ValidId) body, inlines scope lookups and numeric
* binary-op dispatch to avoid 3× visitExpr overhead per iteration.
*/
private def visitCompFused(
specs: List[CompSpec],
scope: ValScope,
body: Expr,
results: collection.mutable.ArrayBuilder.ofRef[Eval]
): Unit = specs match {
case (spec @ ForSpec(_, name, expr)) :: rest =>
visitExpr(expr)(scope) match {
case a: Val.Arr =>
if (debugStats != null) debugStats.arrayCompIterations += a.length
val lazyArr = a.asLazyArray
if (rest.isEmpty) {
// Innermost loop: try BinaryOp(ValidId,ValidId) fast path
body match {
case binOp: BinaryOp
if binOp.lhs.tag == ExprTags.ValidId
&& binOp.rhs.tag == ExprTags.ValidId =>
// Fast path: reuse mutable scope, inline scope lookups + binary-op dispatch.
// NOTE: Evaluates eagerly (not lazy). Both go-jsonnet and jrsonnet also
// evaluate comprehensions eagerly, so this is compatible. Eagerness is
// required for mutable scope reuse — a lazy thunk would capture the
// mutable scope and see stale bindings from later iterations.
val mutableScope = scope.extendBy(1)
val slot = scope.bindings.length
val bindings = mutableScope.bindings
val lhsIdx = binOp.lhs.asInstanceOf[ValidId].nameIdx
val rhsIdx = binOp.rhs.asInstanceOf[ValidId].nameIdx
val op = binOp.op
val bpos = binOp.pos
var j = 0
while (j < lazyArr.length) {
bindings(slot) = lazyArr(j)
val l = bindings(lhsIdx).value
val r = bindings(rhsIdx).value
(l, r) match {
// Only dispatch to numeric fast path for ops it handles (0-16 except OP_in=11).
// OP_in expects string+object, OP_&&/OP_|| need short-circuit semantics.
case (ln: Val.Num, rn: Val.Num)
if op <= Expr.BinaryOp.OP_| && op != Expr.BinaryOp.OP_in =>
results += evalBinaryOpNumNum(op, ln, rn, bpos)
case _ =>
// Fallback to general evaluator for non-numeric types
results += visitExpr(binOp)(mutableScope)
}
j += 1
}
case _ =>
var j = 0
while (j < lazyArr.length) {
results += visitAsLazy(body)(scope.extendSimple(lazyArr(j)))
j += 1
}
}
} else {
// Outer loop: recurse for remaining specs
var j = 0
while (j < lazyArr.length) {
visitCompFused(rest, scope.extendSimple(lazyArr(j)), body, results)
j += 1
}
}
case r =>
Error.fail(
"In comprehension, can only iterate over array, not " + r.prettyName,
spec
)
}
case (spec @ IfSpec(offset, expr)) :: rest =>
visitExpr(expr)(scope) match {
case Val.True(_) =>
if (rest.isEmpty) results += visitAsLazy(body)(scope)
else visitCompFused(rest, scope, body, results)
case Val.False(_) => // filtered out
case other =>
Error.fail(
"Condition must be boolean, got " + other.prettyName,
spec
)
}
case Nil =>
results += visitAsLazy(body)(scope)
}

/**
* Fast-path binary op evaluation for Num×Num operands within comprehension inner loops. Handles
* the most common operations without visitExpr dispatch overhead.
*/
@inline private def evalBinaryOpNumNum(op: Int, ln: Val.Num, rn: Val.Num, pos: Position): Val = {
val ld = ln.asDouble
val rd = rn.asDouble
(op: @switch) match {
case Expr.BinaryOp.OP_+ => Val.Num(pos, ld + rd)
case Expr.BinaryOp.OP_- =>
val r = ld - rd
if (r.isInfinite) Error.fail("overflow", pos)
Val.Num(pos, r)
case Expr.BinaryOp.OP_* =>
val r = ld * rd
if (r.isInfinite) Error.fail("overflow", pos)
Val.Num(pos, r)
case Expr.BinaryOp.OP_/ =>
if (rd == 0) Error.fail("division by zero", pos)
val r = ld / rd
if (r.isInfinite) Error.fail("overflow", pos)
Val.Num(pos, r)
case Expr.BinaryOp.OP_% => Val.Num(pos, ld % rd)
case Expr.BinaryOp.OP_< => Val.bool(pos, ld < rd)
case Expr.BinaryOp.OP_> => Val.bool(pos, ld > rd)
case Expr.BinaryOp.OP_<= => Val.bool(pos, ld <= rd)
case Expr.BinaryOp.OP_>= => Val.bool(pos, ld >= rd)
case Expr.BinaryOp.OP_== => Val.bool(pos, ld == rd)
case Expr.BinaryOp.OP_!= => Val.bool(pos, ld != rd)
case Expr.BinaryOp.OP_<< =>
val ll = ld.toSafeLong(pos); val rr = rd.toSafeLong(pos)
if (rr < 0) Error.fail("shift by negative exponent", pos)
if (rr >= 1 && math.abs(ll) >= (1L << (63 - rr)))
Error.fail("numeric value outside safe integer range for bitwise operation", pos)
Val.Num(pos, (ll << rr).toDouble)
case Expr.BinaryOp.OP_>> =>
val ll = ld.toSafeLong(pos); val rr = rd.toSafeLong(pos)
if (rr < 0) Error.fail("shift by negative exponent", pos)
Val.Num(pos, (ll >> rr).toDouble)
case Expr.BinaryOp.OP_& =>
Val.Num(pos, (ld.toSafeLong(pos) & rd.toSafeLong(pos)).toDouble)
case Expr.BinaryOp.OP_^ =>
Val.Num(pos, (ld.toSafeLong(pos) ^ rd.toSafeLong(pos)).toDouble)
case Expr.BinaryOp.OP_| =>
Val.Num(pos, (ld.toSafeLong(pos) | rd.toSafeLong(pos)).toDouble)
case _ =>
// Should be unreachable: caller filters to ops 0-16 except OP_in
throw new AssertionError(s"Unexpected numeric binary op: $op")
}
}

def visitArr(e: Arr)(implicit scope: ValScope): Val =
Val.Arr(e.pos, e.value.map(visitAsLazy))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Regression test: all binary operators in comprehensions with ValidId operands
local strs = ["hello", "world"];
local nums = [1, 2, 3];
local arrs = [[1, 2], [3, 4]];

// String concatenation
local str_concat = [a + b for a in strs for b in strs];

// Numeric arithmetic
local num_add = [a + b for a in nums for b in nums];
local num_sub = [a - b for a in [10, 20] for b in [3, 5]];
local num_mul = [a * b for a in [2, 3] for b in [4, 5]];
local num_div = [a / b for a in [10, 20] for b in [2, 5]];
local num_mod = [a % b for a in [10, 7] for b in [3, 4]];

// Comparison operators
local cmp_lt = [a < b for a in nums for b in nums];
local cmp_eq = [a == b for a in nums for b in nums];
local cmp_ne = [a != b for a in nums for b in nums];

// Bitwise operators
local bw_and = [a & b for a in [3, 5] for b in [6, 7]];
local bw_or = [a | b for a in [3, 5] for b in [6, 7]];
local bw_xor = [a ^ b for a in [3, 5] for b in [6, 7]];
local bw_shl = [a << b for a in [1, 2] for b in [1, 2]];
local bw_shr = [a >> b for a in [8, 16] for b in [1, 2]];

// String formatting
local str_fmt = [a % b for a in ["val=%d", "x=%d"] for b in [42, 99]];

// Array concatenation
local arr_concat = [a + b for a in arrs for b in arrs];

// 'in' operator
local objs = [{a: 1}, {b: 2}];
local in_test = [a in b for a in ["a", "b"] for b in objs];

std.assertEqual(str_concat, ["hellohello", "helloworld", "worldhello", "worldworld"]) &&
std.assertEqual(num_add, [2, 3, 4, 3, 4, 5, 4, 5, 6]) &&
std.assertEqual(num_sub, [7, 5, 17, 15]) &&
std.assertEqual(num_mul, [8, 10, 12, 15]) &&
std.assertEqual(num_div, [5, 2, 10, 4]) &&
std.assertEqual(num_mod, [1, 2, 1, 3]) &&
std.assertEqual(cmp_lt, [false, true, true, false, false, true, false, false, false]) &&
std.assertEqual(cmp_eq, [true, false, false, false, true, false, false, false, true]) &&
std.assertEqual(cmp_ne, [false, true, true, true, false, true, true, true, false]) &&
std.assertEqual(bw_and, [2, 3, 4, 5]) &&
std.assertEqual(bw_or, [7, 7, 7, 7]) &&
std.assertEqual(bw_xor, [5, 4, 3, 2]) &&
std.assertEqual(bw_shl, [2, 4, 4, 8]) &&
std.assertEqual(bw_shr, [4, 2, 8, 4]) &&
std.assertEqual(str_fmt, ["val=42", "val=99", "x=42", "x=99"]) &&
std.assertEqual(arr_concat, [[1, 2, 1, 2], [1, 2, 3, 4], [3, 4, 1, 2], [3, 4, 3, 4]]) &&
std.assertEqual(in_test, [true, false, false, true]) &&
true
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
true