Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 94 additions & 40 deletions sjsonnet/src/sjsonnet/Evaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,17 @@ class Evaluator(
}
}

def visitArr(e: Arr)(implicit scope: ValScope): Val =
Val.Arr(e.pos, e.value.map(visitAsLazy))
def visitArr(e: Arr)(implicit scope: ValScope): Val = {
val src = e.value
val len = src.length
val res = new Array[Eval](len)
var i = 0
while (i < len) {
res(i) = visitAsLazy(src(i))
i += 1
}
Val.Arr(e.pos, res)
}

def visitSelectSuper(e: SelectSuper)(implicit scope: ValScope): Val = {
val sup = scope.bindings(e.selfIdx + 1).asInstanceOf[Val.Obj]
Expand Down Expand Up @@ -755,14 +764,18 @@ class Evaluator(
}

def visitImportBin(e: ImportBin): Val.Arr = {
Val.Arr(
e.pos,
importer
.resolveAndReadOrFail(e.value, e.pos, binaryData = true)
._2
.readRawBytes()
.map(x => Val.Num(e.pos, (x & 0xff).doubleValue))
)
val bytes = importer
.resolveAndReadOrFail(e.value, e.pos, binaryData = true)
._2
.readRawBytes()
val len = bytes.length
val res = new Array[Eval](len)
var i = 0
while (i < len) {
res(i) = Val.Num(e.pos, (bytes(i) & 0xff).doubleValue)
i += 1
}
Val.Arr(e.pos, res)
}

def visitImport(e: Import): Val = {
Expand Down Expand Up @@ -1208,13 +1221,28 @@ class Evaluator(
val k = visitFieldName(fieldName, offset)
if (k != null) {
val fieldKey = k
val v = new Val.Obj.Member(plus, sep) {
def invoke(self: Val.Obj, sup: Val.Obj, fs: FileScope, ev: EvalScope): Val = {
checkStackDepth(rhs.pos, fieldKey)
try visitExpr(rhs)(makeNewScope(self, sup))
finally decrementStackDepth()
}
// ConstMember fast path: for simple field bodies (Val literals or
// parent-scope ValidId that already resolved to a Val), pre-compute
// the value and skip the Member closure + scope extension in invoke.
val constVal: Val = rhs match {
case v: Val => v
case vid: ValidId if vid.nameIdx < scope.length =>
scope.bindings(vid.nameIdx) match {
case v: Val => v
case _ => null
}
case _ => null
}
val v: Val.Obj.Member =
if (constVal ne null) new Val.Obj.ConstMember(plus, sep, constVal)
else
new Val.Obj.Member(plus, sep) {
def invoke(self: Val.Obj, sup: Val.Obj, fs: FileScope, ev: EvalScope): Val = {
checkStackDepth(rhs.pos, fieldKey)
try visitExpr(rhs)(makeNewScope(self, sup))
finally decrementStackDepth()
}
}
if (fieldCount == 0) {
singleKey = k
singleMember = v
Expand Down Expand Up @@ -1301,7 +1329,9 @@ class Evaluator(
val builder = new java.util.LinkedHashMap[String, Val.Obj.Member]
val compScopes = visitComp(e.first :: e.rest, Array(compScope))
if (debugStats != null) debugStats.objectCompIterations += compScopes.length
for (s <- compScopes) {
var ci = 0
while (ci < compScopes.length) {
val s = compScopes(ci)
visitExpr(e.key)(s) match {
case Val.Str(_, k) =>
val previousValue = builder.put(
Expand All @@ -1323,6 +1353,7 @@ class Evaluator(
case Val.Null(_) => // do nothing
case x => fieldNameTypeError(x, e.pos)
}
ci += 1
}
val valueCache = if (sup == null) {
Val.Obj.getEmptyValueCacheForObjWithoutSuper(builder.size())
Expand Down Expand Up @@ -1359,18 +1390,23 @@ class Evaluator(
}
visitComp(rest, newScopes.result())
case (spec @ IfSpec(offset, expr)) :: rest =>
visitComp(
rest,
scopes.filter(visitExpr(expr)(_) match {
case Val.True(_) => true
case Val.False(_) => false
val filtered = collection.mutable.ArrayBuilder.make[ValScope]
filtered.sizeHint(scopes.length)
var i = 0
while (i < scopes.length) {
val s = scopes(i)
visitExpr(expr)(s) match {
case Val.True(_) => filtered += s
case Val.False(_) => // skip
case other =>
Error.fail(
"Condition must be boolean, got " + other.prettyName,
spec
)
})
)
}
i += 1
}
visitComp(rest, filtered.result())
case Nil => scopes
}

Expand All @@ -1380,27 +1416,34 @@ class Evaluator(
case (x: Val.Str, y: Val.Str) => Util.compareStringsByCodepoint(x.str, y.str)
case (x: Val.Bool, y: Val.Bool) => x.asBoolean.compareTo(y.asBoolean)
case (x: Val.Arr, y: Val.Arr) =>
val len = math.min(x.length, y.length)
val xArr = x.asLazyArray
val yArr = y.asLazyArray
val len = math.min(xArr.length, yArr.length)
// Phase 1: skip shared Eval references (e.g. from array concat)
var i = 0
while (i < len && (xArr(i) eq yArr(i))) { i += 1 }
// Phase 2: compare from first mismatch onwards
while (i < len) {
val xi = x.value(i)
val yi = y.value(i)
// Reference equality short-circuit for shared array elements
if (!(xi eq yi)) {
// Inline numeric fast path to avoid polymorphic compare() dispatch
val cmp = xi match {
case xn: Val.Num =>
yi match {
case yn: Val.Num => java.lang.Double.compare(xn.asDouble, yn.asDouble)
case _ => compare(xi, yi)
}
case _ => compare(xi, yi)
val xe = xArr(i)
val ye = yArr(i)
if (!(xe eq ye)) {
val xi = xe.value
val yi = ye.value
if (!(xi eq yi)) {
val cmp = xi match {
case xn: Val.Num =>
yi match {
case yn: Val.Num => java.lang.Double.compare(xn.asDouble, yn.asDouble)
case _ => compare(xi, yi)
}
case _ => compare(xi, yi)
}
if (cmp != 0) return cmp
}
if (cmp != 0) return cmp
}
i += 1
}
Integer.compare(x.length, y.length)
Integer.compare(xArr.length, yArr.length)
case _ => Error.fail("Cannot compare " + x.prettyName + " with " + y.prettyName, x.pos)
}

Expand All @@ -1423,9 +1466,20 @@ class Evaluator(
case y: Val.Arr =>
val xlen = x.length
if (xlen != y.length) return false
val xArr = x.asLazyArray
val yArr = y.asLazyArray
var i = 0
// Phase 1: skip shared Eval references
while (i < xlen && (xArr(i) eq yArr(i))) { i += 1 }
// Phase 2: compare remaining elements
while (i < xlen) {
if (!equal(x.value(i), y.value(i))) return false
val xe = xArr(i)
val ye = yArr(i)
if (!(xe eq ye)) {
val xv = xe.value
val yv = ye.value
if (!(xv eq yv) && !equal(xv, yv)) return false
}
i += 1
}
true
Expand Down
60 changes: 39 additions & 21 deletions sjsonnet/src/sjsonnet/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,24 @@ class Parser(
else Fail.opaque("no duplicate field: " + overlap)
}.flatMapX {
case (exprs, None) =>
val b =
exprs.iterator.filter(_.isInstanceOf[Expr.Bind]).asInstanceOf[Iterator[Expr.Bind]].toArray
// Single-pass classification: classify members into binds, fields, and asserts
// in one loop instead of three separate iterator.filter passes.
val bindsBuilder = new scala.collection.mutable.ArrayBuilder.ofRef[Expr.Bind]
val fieldsBuilder = new scala.collection.mutable.ArrayBuilder.ofRef[Expr.Member.Field]
val assertsBuilder = new scala.collection.mutable.ArrayBuilder.ofRef[Expr.Member.AssertStmt]
var i = 0
while (i < exprs.length) {
exprs(i) match {
case bind: Expr.Bind => bindsBuilder += bind
case field: Expr.Member.Field => fieldsBuilder += field
case assert: Expr.Member.AssertStmt => assertsBuilder += assert
case _ =>
}
i += 1
}
val b = bindsBuilder.result()
val fields = fieldsBuilder.result()

val seen = collection.mutable.Set.empty[String]
var overlap: String = null
b.foreach {
Expand All @@ -813,27 +829,24 @@ class Parser(
Fail.opaque("no duplicate local: " + overlap)
} else {
val binds = if (b.isEmpty) null else b

val fields = exprs.iterator
.filter(_.isInstanceOf[Expr.Member.Field])
.asInstanceOf[Iterator[Expr.Member.Field]]
.toArray
val asserts = {
val a = exprs.iterator
.filter(_.isInstanceOf[Expr.Member.AssertStmt])
.asInstanceOf[Iterator[Expr.Member.AssertStmt]]
.toArray
val a = assertsBuilder.result()
if (a.isEmpty) null else a
}
if (binds == null && asserts == null && fields.forall(_.isStatic))
Pass(Val.staticObject(pos, fields, internedStaticFieldSets, internedStrings))
else Pass(Expr.ObjBody.MemberList(pos, binds, fields, asserts))
}
case (exprs, Some(comps)) =>
val preLocals = exprs
.takeWhile(_.isInstanceOf[Expr.Bind])
.map(_.asInstanceOf[Expr.Bind])
if (preLocals.nonEmpty && exprs.length == preLocals.length) {
// Index-based single-pass: find preLocals (leading Bind exprs)
var preEnd = 0
while (preEnd < exprs.length && exprs(preEnd).isInstanceOf[Expr.Bind]) preEnd += 1
val preLocals = {
val arr = new Array[Expr.Bind](preEnd)
var j = 0; while (j < preEnd) { arr(j) = exprs(j).asInstanceOf[Expr.Bind]; j += 1 }
arr
}
if (preLocals.length > 0 && exprs.length == preLocals.length) {
Fail.opaque("object comprehension must have a field")
} else
exprs(preLocals.length) match {
Expand All @@ -850,10 +863,15 @@ class Parser(
} else {
Expr.Function(offset, args, rhsBody)
}
val postLocals = exprs
.drop(preLocals.length + 1)
.takeWhile(_.isInstanceOf[Expr.Bind])
.map(_.asInstanceOf[Expr.Bind])
val postLocals = {
val start = preEnd + 1
var end = start
while (end < exprs.length && exprs(end).isInstanceOf[Expr.Bind]) end += 1
val arr = new Array[Expr.Bind](end - start)
var j = 0;
while (j < arr.length) { arr(j) = exprs(start + j).asInstanceOf[Expr.Bind]; j += 1 }
arr
}

/*
* Prevent duplicate fields in list comprehension. See: https://github.com/databricks/sjsonnet/issues/99
Expand All @@ -871,11 +889,11 @@ class Parser(
Pass(
Expr.ObjBody.ObjComp(
pos,
preLocals.toArray,
preLocals,
lhs,
rhs,
plus,
postLocals.toArray,
postLocals,
comps._1,
comps._2.toList
)
Expand Down
12 changes: 10 additions & 2 deletions sjsonnet/src/sjsonnet/Val.scala
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,11 @@ object Val {
}
if (simple) {
if (tailstrictMode == TailstrictModeEnabled) {
argsL.foreach(_.value)
var i = 0
while (i < argsL.length) {
argsL(i).value
i += 1
}
}
val newScope = defSiteValScope.extendSimple(argsL)
val result = evalRhs(newScope, ev, funDefFileScope, outerPos)
Expand Down Expand Up @@ -1022,7 +1026,11 @@ object Val {
}
}
if (tailstrictMode == TailstrictModeEnabled) {
argVals.foreach(_.value)
var i = 0
while (i < argVals.length) {
argVals(i).value
i += 1
}
}
val result = evalRhs(newScope, ev, funDefFileScope, outerPos)
if (tailstrictMode == TailstrictModeDisabled) TailCall.resolve(result) else result
Expand Down
19 changes: 17 additions & 2 deletions sjsonnet/src/sjsonnet/stdlib/SetModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,13 @@ object SetModule extends AbstractFunctionModule {
val sortedIndices = if (keyType == classOf[Val.Str]) {
indices.sortBy(i => keys(i).cast[Val.Str].asString)(Util.CodepointStringOrdering)
} else if (keyType == classOf[Val.Num]) {
indices.sortBy(i => keys(i).cast[Val.Num].asDouble)
// Extract doubles into primitive array for unboxed comparison
val dkeys = new Array[Double](keys.length)
var di = 0
while (di < dkeys.length) {
dkeys(di) = keys(di).asInstanceOf[Val.Num].asDouble; di += 1
}
indices.sortWith((a, b) => java.lang.Double.compare(dkeys(a), dkeys(b)) < 0)
} else if (keyType == classOf[Val.Arr]) {
indices.sortBy(i => keys(i).cast[Val.Arr])(ev.compare(_, _))
} else {
Expand All @@ -128,7 +134,16 @@ object SetModule extends AbstractFunctionModule {
if (keyType == classOf[Val.Str]) {
strict.map(_.cast[Val.Str]).sortBy(_.asString)(Util.CodepointStringOrdering)
} else if (keyType == classOf[Val.Num]) {
strict.map(_.cast[Val.Num]).sortBy(_.asDouble)
// Primitive double sort: extract doubles, sort with DualPivotQuicksort,
// reconstruct Val.Num array. Avoids Comparator virtual dispatch + boxing.
val n = strict.length
val doubles = new Array[Double](n)
var di = 0
while (di < n) { doubles(di) = strict(di).asInstanceOf[Val.Num].asDouble; di += 1 }
java.util.Arrays.sort(doubles)
di = 0
while (di < n) { strict(di) = Val.Num(pos, doubles(di)); di += 1 }
strict
} else if (keyType == classOf[Val.Arr]) {
strict.map(_.cast[Val.Arr]).sortBy(identity)(ev.compare(_, _))
} else if (keyType == classOf[Val.Obj]) {
Expand Down
Loading