diff --git a/docs/adr/0001-drop-semantic-mutate-op.md b/docs/adr/0001-drop-semantic-mutate-op.md
new file mode 100644
index 0000000..7dabb51
--- /dev/null
+++ b/docs/adr/0001-drop-semantic-mutate-op.md
@@ -0,0 +1,224 @@
# ADR 0001: Unify calculated measures and post-aggregation `mutate` on a single ibis-expression primitive

- **Status:** Partially implemented — Phases 1+2 landed on `hussain/feat/calc-measure-analyzer`; Phase 3 outstanding.
- **Date:** 2026-05-08 (revised 2026-05-09 to reflect landed work)
- **Deciders:** BSL maintainers
- **Related code (current state):** `src/boring_semantic_layer/calc_analyzer.py` (new), `src/boring_semantic_layer/calc_compiler.py` (new), `src/boring_semantic_layer/nested_compile.py` (new — extracted from deleted `compile_all.py`), `src/boring_semantic_layer/ops.py` (`CalcMeasure`, `_classify_measure`, `_build_aggregation_plan`, `_compile_aggregation`, `_apply_calc_specs`, `SemanticMutateOp` — still present), `src/boring_semantic_layer/expr.py` (`SemanticMutate`, `.mutate()` chained API — still present), `src/boring_semantic_layer/measure_scope.py` (`MeasureScope`/`ColumnScope` thin proxies; curated AST removed), `src/boring_semantic_layer/serialization/extract.py` (resolver-tree calc serialization).

## Context

BSL had two independent mechanisms for deriving a column that depends on already-aggregated values:

1. **Calculated measures** — declared on the model via `with_measures(...)` and classified as `calc` (vs. `base`) by `_classify_measure`. Stored on `SemanticTableOp.calc_measures` and compiled through a hand-rolled pipeline. The expression language was a curated AST: `MeasureRef | AllOf | BinOp | MethodCall | int | float`, validated by `validate_calc_ast`.
2. **`SemanticMutateOp`** — a post-aggregation chain operator built by `SemanticTable.mutate(**post)` and `SemanticAggregate.mutate(**post)`. Runs an arbitrary user lambda over the aggregated result and adds new columns via `ibis.Table.mutate`.
The expression language is *all of ibis*. + +Each system got two things right and two things wrong, and the right things were *orthogonal*: + +| | Calculated measures | `SemanticMutateOp` | +|--------------------------------|----------------------------------------------------------|-----------------------------------------------------------| +| **Placement** (defined where) | ✅ on the model — reusable, catalog-visible | ❌ per-query — anonymous, not in `model.measures` | +| **Expression language** | ❌ curated AST; no `xo.case`, no windows | ✅ full ibis — windows, `xo.case`, arbitrary transforms | +| **Planner integration** | ✅ pre-agg pushdown, `AllOf` lift, structured tags | ❌ opaque — special-cased through `collect_mutates_to_join` | +| **Compilation** | ❌ hand-rolled (`compile_grouped_with_all`, `infer_calc_dtype`) | ✅ ibis compiles it for free | + +Calc measures were right about *placement* and *integration*; `SemanticMutateOp` was right about *expression language* and *compilation*. Maintaining both enshrined the suboptimal tradeoff on each axis. Growing the calc-measure AST node-by-node would have chased mutate's expressivity at the cost of a permanently growing hand-rolled compiler. Defanging mutate would have kept the curated-AST limitation in place forever. + +The right primitive is one that combines mutate's expression language with calc measures' placement and integration: **ibis expressions, declared on the model, classified by analysis rather than by AST tag.** + +## Decision + +**Unify calculated measures and `mutate` on a single primitive: ibis expressions declared on the model, with planner properties (pushability, `AllOf` lift, post-agg-only) recovered by analysis on the ibis tree rather than by curated-AST tagging.** Drop `SemanticMutateOp` as a chain operator. 
Per-query ad-hoc derivations go through the existing `with_measures(...)` method on `SemanticAggregate` (which the Phase 1+2 cutover already wired through the analyzer) — no new method, no parallel registration path. + +The decision is being executed in three phases: + +- **Phase 1 — Analyzer + ibis-native compiler.** Land the structural classifier (`analyze_calc_expr`) and the calc compiler (`IbisCalcScope`, `apply_calc_measures`, `lift_inline_reductions`, `compile_calc_measure`) alongside the curated-AST path. +- **Phase 2 — Hard cutover.** Replace the curated AST with the analyzer; remove `compile_grouped_with_all`, `validate_calc_ast`, and the curated AST classes. Calc measures are stored as `CalcMeasure(expr=callable)` and re-evaluated against `IbisCalcScope` at query time. `with_measures(...)` on `SemanticTable` *and* on `SemanticAggregate` (`expr.py:1589`) already routes through `_classify_measure`, so the same lambda surface that defines model-level measures also covers query-local ones. +- **Phase 3 — Drop `SemanticMutateOp`.** Remove the chain operator, its planner branches, and its serialization. The chained `.mutate(**post)` API either (a) is removed and users migrate to `.with_measures(**post).aggregate(..., *post.keys())`, or (b) survives as a thin alias that desugars to exactly that — one line, no operator node, no new method name. **Recommendation: (b)** — preserves chain ergonomics for existing user code while collapsing the operator graph to one path. + +## Implementation status + +### Phases 1+2 — Landed (branch `hussain/feat/calc-measure-analyzer`) + +Net diff: **~-727 lines** in production code. Test suite: 978 passed, 1 preexisting unrelated xorq failure. + +What's wired: + +- **`calc_analyzer.py`** — `analyze_calc_expr` walks an ibis tree (skipping `Relation` subtrees) and returns `CalcExprAnalysis(pushable, references_AllOf, has_window, post_agg_only, depends_on, inline_aggs)`. 
Single-pass `_scan_tree` recognizes plain `Reduction`, real `WindowFunction`, and the agg-of-agg / empty-window-over-reduction patterns that mean "totals." +- **`calc_compiler.py`** — `IbisCalcScope` (dual-table dispatch over base + virtual aggregated + virtual totals), `evaluate_calc_lambda`, `classify_calc_lambda`, `lift_inline_reductions`, `apply_calc_measures`, `compile_calc_measure`. Topological ordering of calc-of-calc chains via `topological_order_from_deps`. +- **`ops.py`** — `CalcMeasure` is the new storage shape. `_classify_measure` runs the lambda once against `IbisCalcScope`, walks the result, and routes to base or calc. `_build_aggregation_plan` / `_compile_aggregation` replace `compile_grouped_with_all`. The pre-agg path's `_apply_calc_specs` and the deferred-join arm both go through `apply_calc_measures`. +- **`measure_scope.py`** — curated AST classes (`MeasureRef`, `AllOf`, `BinOp`, `MethodCall`, `AggregationExpr`, `_PendingMethodCall`, `DeferredColumn`, `validate_calc_ast`) deleted. `MeasureScope` and `ColumnScope` survive as thin pass-through proxies for `SemanticMutateOp` (still present) and for nested-access helpers. +- **`compile_all.py`** — deleted. Nested-array helpers extracted to `nested_compile.py`. +- **`serialization/extract.py`** — `serialize_calc_measures` walks each `CalcMeasure.expr` via `expr_to_structured` and stores the resolver tree plus `description`, `requires_unnest`, and `depends_on`. `deserialize_calc_measures` rebuilds `CalcMeasure(expr=Deferred(...), depends_on=...)`. Backwards-compat for the old bare-tuple format kept at one site. +- **`utils.py`** — `serialize_resolver` / `deserialize_resolver` handle the `Item` resolver (needed for `t["prefixed.name"]`). 
All FrozenSlotted resolvers built via `object.__new__` go through `_finalize_frozen_slotted` so the rebuilt resolver hashes equal to a freshly-constructed one — fixes a latent bug that would surface as `AttributeError: __precomputed_hash__` when a deserialized resolver was used as a dict key. + +What this gained beyond literal cutover: + +- **Non-sum `t.all(...)` works correctly.** `t.all(measure_ref)` resolves to a `Field(totals_vt, name)`; the compiler builds a real no-group-by totals aggregation by re-running `agg_specs` on the base, applies non-AllOf calc measures to it, cross-joins it with prefixed column names, and rewrites totals references. Non-sum chains (`avg_distance / t.all(avg_distance)`) now match the right answer (overall mean, not sum-of-per-group-means). Pinned by `test_apply_calc_measures_join_with_mean_totals` (mean) and parametrized `test_apply_calc_measures_non_sum_totals` (median, min, max). +- **Inline reductions inside `t.all(...)`.** `t.value.sum() / t.all(t.value.sum())` compiles end-to-end via `lift_inline_reductions`: each unique reduction over the base is named, added to both per-group and totals aggregations, and rewritten in-place (bare → `Field(vt, anon)`, windowed → `Field(totals_vt, anon)`). Pinned by `test_lift_inline_reductions_routes_window_to_totals`. +- **Calc-of-calc.** Topologically ordered inside `apply_calc_measures` — each calc is added to the result via its own `mutate(...)` so subsequent calcs see it as a column. `depends_on` is captured at classification time and survives serialization. +- **Joined models.** `IbisCalcScope` does unique-suffix matching: `t.flight_count` resolves to `flights.flight_count` when there's exactly one such suffix. No need to rewrite stored lambdas for prefixed names. 
+- **Clear errors instead of opaque ibis `IntegrityError`.** `TotalsNotAvailableError` when `t.all(...)` is referenced but no totals can be built; post-rewrite assertion in `compile_calc_measure` listing the unresolved column names. + +### Phase 3 — Outstanding + +`SemanticMutateOp`, the chained `.mutate()` API, and the planner branches that special-case mutate are still in the codebase. Phase 3 work: + +1. **Reduce `.mutate(**post)` to a thin alias for `.with_measures(**post).aggregate(*current, *post.keys())`** on `SemanticAggregate`. No new method name, no new operator node, no new storage shape — `with_measures` already routes through `_classify_measure`, which already routes through the analyzer. The mutate method becomes ~3 lines of desugaring; existing user code keeps working unchanged. +2. **Remove `SemanticMutateOp`** from `ops.py` (currently `ops.py:3307`). +3. **Remove `SemanticMutate`** from `expr.py` (currently `expr.py:1615`). The three `.mutate()` methods on `SemanticTable` (`expr.py:221`), `SemanticAggregate` (`expr.py:1352`), and `SemanticMutate` (`expr.py:1659`) either get the desugaring described in (1) or are deleted. +4. **Remove mutate-aware planner branches** (enumerated below). +5. **Remove `SemanticMutateOp` registrations** from `serialization/extract.py:63,74,140`, `serialization/reconstruct.py:185`, `convert.py:24,400`, `format.py:17,149`, `chart/utils.py:131,137`. Existing tags containing `SemanticMutateOp` will fail to deserialize with an `UnknownTagError` (or equivalent) naming the offending op and pointing users at the `with_measures` equivalent in the deprecation note. No migration tool: tags are re-generated from current model definitions, and users with persisted tags either re-tag or pin the prior BSL version. +6. **Lint / deprecation pass.** Flag remaining `.mutate(` chained off semantic objects in user code if (3) opts to delete rather than alias; emit the `with_measures` equivalent. 
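The item-1 desugaring is small enough to sketch with toy stand-ins. Everything here (the `ToyAggregate` class, its `measures` and `aggs` attributes) is invented for illustration and is not the real BSL API; the point is the shape of the three-line `mutate` body:

```python
# Toy stand-in for SemanticAggregate. The class and attribute names are
# hypothetical; only the body of `mutate` mirrors the proposed alias.
class ToyAggregate:
    def __init__(self, measures, aggs):
        self.measures = dict(measures)  # registered measure lambdas
        self.aggs = tuple(aggs)         # measure names selected by .aggregate()

    def with_measures(self, **defs):
        # In BSL each lambda would be routed through _classify_measure here;
        # the sketch just registers the names.
        return ToyAggregate({**self.measures, **defs}, self.aggs)

    def aggregate(self, *names):
        return ToyAggregate(self.measures, names)

    def mutate(self, **post):
        # The entire Phase 3 alias: register, then re-request with the new names.
        return self.with_measures(**post).aggregate(*self.aggs, *post.keys())


agg = ToyAggregate({"total": None, "cnt": None}, ("total", "cnt"))
out = agg.mutate(avg=lambda t: t.total / t.cnt)
print(out.aggs)  # ('total', 'cnt', 'avg')
```

Because the alias only re-invokes existing registration and aggregation paths, the chained shape keeps working with no new operator node, which is what makes the eventual `SemanticMutateOp` deletion mechanical.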
+ +#### Phase 3 readiness checklist — planner branches that go away + +Each line is a concrete deletion target. Counts are from current `main`-vs-branch state. + +| File:line | What it does | What replaces it | +|---|---|---| +| `ops.py:380` | `_semantic_repr` arm for `SemanticMutateOp` | Deleted with the op | +| `ops.py:2364` | `has_prior_aggregate(SemanticMutateOp)` traversal | Deleted; no node to traverse | +| `ops.py:2374` | `is_post_agg = has_prior_aggregate(self.source)` driving the post-agg branch in `SemanticAggregateOp.to_untagged` | Stays — but the only mutate-recursion case (line 2364) goes away; remaining cases (`SemanticAggregateOp`, `SemanticGroupByOp`) still apply | +| `ops.py:2377–2392` | `collect_mutates_to_join` walks the chain collecting `SemanticMutateOp.post` dicts | Deleted; no chained mutates to collect | +| `ops.py:2433` | `collected_mutates = collect_mutates_to_join(self.source)` | Deleted | +| `ops.py:2439` | `_to_untagged_with_preagg(..., mutates=collected_mutates)` | Drop the `mutates` parameter | +| `ops.py:2553–2558` | `mutated_gb_keys` heuristic — group-by keys that aren't dims/measures/calcs are assumed to be mutate-introduced | Deleted; group-by keys with derivations are calc measures via `with_measures`, so they appear in `merged_calc_measures` | +| `ops.py:2569–2581` | Apply mutate ops to full joined table for dim-bridge use | Deleted | +| `ops.py:2706–2716` | Apply mutated group-by keys to per-table raw tables for grain computation | Deleted | +| `ops.py:2768–2770` | Local-dim handling for mutated group-by keys | Deleted | +| `ops.py:3307–3372` | `SemanticMutateOp` class itself | Deleted | +| `expr.py:32, 221, 222, 1352, 1353, 1615–1698, 1659, 1660, 1675–1698` | `SemanticMutate` class, `.mutate()` methods, imports | Deleted or reduced to alias | +| `convert.py:24, 400–402` | `_convert_semantic_mutate` to-ibis conversion | Deleted | +| `format.py:17, 149–150` | `_format_semantic_mutate` repr | Deleted | +| `chart/utils.py:131, 
137` | Chart introspection skipping `SemanticMutateOp` nodes | Deleted (no nodes to skip) | +| `serialization/extract.py:63, 74, 140` | Registration + lazy stash for `SemanticMutateOp` tag | Deleted | +| `serialization/reconstruct.py:185–190` | Reconstructor for `SemanticMutateOp` | Replaced by a clear `UnknownTagError` naming the op and pointing at the `with_measures` equivalent | + +The intellectually load-bearing piece — proving you can recover pushability/AllOf-lift/post-agg classification by analyzing an ibis tree — is done. Phase 3 is mechanical deletion guided by the table above plus regression tests for each composition that the deletions touch. + +#### Composition gotchas to pin with tests before deletion + +Each row is a chained mutate composition that exists today; the right column says how the equivalent `with_measures` chain behaves. None requires new semantics — they all fall out of where `with_measures` already builds its scope — but each needs a regression test before mutate is removed. 
+ +| Composition | Today's behavior | After (`with_measures` chain) | +|---|---|---| +| `.aggregate(...).mutate(c=lambda t: ...).filter(p)` | Filter sees `c` (mutate runs first by chain order) | `.with_measures(c=...).aggregate(..., "c").filter(p)` — same: filter applies to the aggregated table containing `c` | +| `.aggregate(...).mutate(a=...).filter(p).mutate(b=...)` | `b` sees filtered table including `a` | `.with_measures(a=...).aggregate(..., "a").filter(p).with_measures(b=...)` — `SemanticFilter.with_measures` (`expr.py:1117`) scopes on `self.op().to_untagged()` which is the filtered table, so `b` sees rows surviving `p` | +| `.aggregate(...).mutate(c=...).order_by("c")` | OrderBy operates on `c` | `.with_measures(c=...).aggregate(..., "c").order_by("c")` — same: `c` is a column on the aggregated table | +| `.aggregate(...).mutate(c=...).limit(10)` | Limit applied after `c` is added | Same — limit is post-aggregate either way | +| `.aggregate(...).mutate(c=...).limit(10).mutate(d=lambda t: t.c * 2)` | `d` sees the post-limit table | `.with_measures(c=...).aggregate(..., "c").limit(10).with_measures(d=...)` — but **note**: `SemanticLimit.with_measures` does *not* exist today (`expr.py` only defines it on `SemanticTable`/`SemanticFilter`/`SemanticAggregate`/`SemanticMutate`). Phase 3 either adds `SemanticLimit.with_measures` or rejects this composition. 
Recommended: add it, scoped on `self.op().to_untagged()` for consistency with `SemanticFilter.with_measures` | +| `.mutate(c=...).group_by(...).aggregate(...)` | Mutate runs *before* aggregation; `c` is a dimension-grain column | `.with_measures(c=...).group_by(...).aggregate(..., "c")` — the analyzer classifies pre-agg derivations as `pushable`, so this Just Works on the existing path | +| `.aggregate(...).mutate(c=...)` followed by use as a join input | Mutate result becomes a `SemanticMutateOp` node the join planner had to special-case | After Phase 3, the result is a `SemanticAggregate` (subclass of `SemanticTable`) with `c` in its measures dict. The join planner already handles `SemanticAggregate`, so the `SemanticMutateOp` arm in `collect_mutates_to_join` simply has nothing to collect | +| `.join_one(...).mutate(c=...)` (mutate after join, before aggregate) | Mutate adds a column that participates as a group-by candidate via `mutated_gb_keys` | After Phase 3: define `c` via `with_measures` on either side before the join, or on the join result; either way it lands in `merged_calc_measures` and the planner sees it without the `mutated_gb_keys` heuristic | + +The `mutated_gb_keys` heuristic at `ops.py:2553` deserves a specific call-out: *the only reason it exists* is that `SemanticMutateOp` introduced columns the planner couldn't classify as dims/measures/calcs. Once mutate columns become calc measures, the heuristic has no work to do — every group-by key resolves through `merged_*` lookups directly. This collapses lines 2553–2770 (~60 LOC of conditional handling for mutated keys) without replacement. 
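A minimal before/after sketch of why the heuristic becomes dead code. The function and argument names are invented for illustration; in the real planner the lookups live on the merged dimension/measure maps in `SemanticAggregateOp.to_untagged`:

```python
# Before: a group-by key that is not a known dim/measure/calc is *assumed*
# to have come from a chained mutate (the `mutated_gb_keys` heuristic).
def classify_key_before(key, dims, measures, calcs, mutate_cols):
    for kind, names in (("dim", dims), ("measure", measures), ("calc", calcs)):
        if key in names:
            return kind
    return "mutated" if key in mutate_cols else "unknown"


# After Phase 3: every derived column is registered as a calc measure, so
# classification is a direct lookup; there is no fall-through guess.
def classify_key_after(key, merged_dims, merged_measures, merged_calc_measures):
    for kind, names in (
        ("dim", merged_dims),
        ("measure", merged_measures),
        ("calc", merged_calc_measures),
    ):
        if key in names:
            return kind
    raise KeyError(f"unknown group-by key: {key}")


# A mutate-introduced bucket column becomes an ordinary calc measure.
print(classify_key_before("bucket", {"region"}, {"cnt"}, set(), {"bucket"}))  # mutated
print(classify_key_after("bucket", {"region"}, {"cnt"}, {"bucket"}))          # calc
```

Once every derived column is registered up front, the fall-through guess has nothing left to catch, which is why the conditional handling deletes without replacement.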
+ +## Migration + +| Today | After Phase 3 | +|--------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------| +| `with_measures(avg=lambda t: t.total / t.cnt)` (calc measure today) | Already works — landed in Phase 1+2; analyzer classifies as `pushable`. | +| `with_measures(share=lambda t: t.x / t.all(t.x))` | Already works — landed in Phase 1+2; the analyzer detects the totals pattern, the compiler builds a real totals table. | +| `with_measures(avg=lambda t: t.x.mean(), ratio=lambda t: t.avg / t.all(t.avg))` (non-sum totals) | Already works — landed in Phase 1+2; totals re-aggregation uses the formula, not a windowed sum. | +| `.aggregate("a","b").mutate(c=lambda t: t.a / t.b)` | Define `c` on the model and request it: `.with_measures(c=...).aggregate("a","b","c")`. Or keep the chained shape if Phase 3 ships `.mutate` as an alias. | +| `.aggregate("c").mutate(bucket=lambda t: xo.case().when(t.c>=3,"hi").else_("lo").end())` | `.with_measures(bucket=...).aggregate("c","bucket")` — already permitted in Phase 1+2 (analyzer classifies as `post_agg_only`). | +| `.aggregate("x").mutate(rank=lambda t: t.x.rank(), pct=lambda t: t.x.percent_rank())` | `.with_measures(rank=..., pct=...).aggregate("x","rank","pct")` — already permitted (analyzer classifies as `has_window`). | +| `.aggregate("x").mutate(ma=lambda t: t.x.mean().over(window(order_by="d", preceding=2)))` | `.with_measures(ma=...).aggregate("x","ma")` — already permitted. | +| `.aggregate("x").mutate(adhoc=...)` where `adhoc` truly is per-query | `.with_measures(adhoc=...).aggregate("x","adhoc")` registers a query-local measure on the temporary aggregated model. Same shape, same lambda. 
|
| Existing serialized tags containing `SemanticMutateOp` | Fail to deserialize with a clear error naming the offending op and pointing at the `with_measures` equivalent. Users re-tag from current model definitions or pin the prior BSL version. No migration tool. |

## Consequences

### Positive — realized in Phase 1+2

- **One classification path.** Calc-measure pushability/AllOf-lift/post-agg routing is read off `analyze_calc_expr`. No more parallel "is this a curated AST node" + "what does ibis make of it" reasoning.
- **Full ibis expressivity for calc measures.** Windows, `xo.case`, `xo.ifelse`, `.fillna(...).cast(...)`, struct/array methods — all work declaratively *and* participate in serialization, catalog tooling, and pre-agg analysis.
- **No hand-rolled compiler.** `compile_grouped_with_all` and `infer_calc_dtype` deleted. Type inference falls out of `expr.type()` (with a debug-logged fallback for joined-model edge cases).
- **Serialization simplified.** Calc measures serialize through one mechanism (resolver trees). Two parallel formats collapsed to one.
- **Catalog visibility.** Partially realized: anything declarable on the model now shows up in `model.measures`, covering everything mutate previously hid from the catalog. Phase 3 closes the gap for genuinely per-query derivations.
- **Obviates two would-be follow-up ADRs** (`xo.case` in calc measures, windows in calc measures). Done in one cut.

### Positive — pending Phase 3

- **Planner branches collapse.** `collect_mutates_to_join`, `has_prior_aggregate`'s mutate arm, `mutated_gb_keys`, and `_to_untagged_with_preagg(..., mutates=...)` all go away.
- **`expr.py` shrinks.** `SemanticMutate` (entire class) and `SemanticMutateOp` (entire op) disappear. The `.mutate(**post)` method either survives as a 3-line alias for `with_measures(**post).aggregate(...)` or is deleted entirely.
+- **No new public method.** The user-facing surface stays at `with_measures` — already familiar, already analyzer-routed. No `with_calc` to learn. +- **Serialization simplifies further.** Tags carrying `SemanticMutateOp` go away; the deserializer surfaces a clear error for any old tag that still references it. + +### Behavior changes for users + +The Phase 1+2 cutover is an internal refactor in shape, but it has three semantic edges users will hit and that the ADR commits to as stable surface area. + +1. **Calc lambdas execute twice per query.** Once at definition (classification — `_classify_measure` runs the lambda against `IbisCalcScope` to walk the resulting ibis tree) and once at query time (compilation — `apply_calc_measures` re-runs the lambda against the real aggregated table). Implications: + - **Pure expressions** (`lambda t: t.a / t.b`) — no observable change. + - **Lambdas that read external state** (config, env, globals) — both reads happen; if the values differ between definition and query they observe the *latter*. In practice the classification result is discarded once the lambda is stored, but the side effects are not. + - **Lambdas with side effects** (logging, counters, network calls) — fire twice. **Don't put side effects in calc lambdas.** This was technically also true for `mutate(...)` lambdas (they ran during planning *and* execution if the table was re-executed) but only the analyzer makes the double-execution unconditional and predictable. + - The current ``_classify_measure`` swallows generic exceptions and falls through to base classification (`ops.py:790`), so a lambda that raises during classification still gets called at query time. Don't rely on classification-time exceptions to short-circuit anything. + +2. **`IbisCalcScope` dispatch order is the public contract** users program against in calc lambdas. The order is: + 1. 
**Base column wins on collision.** If `t.foo` matches a column on the base table, it returns the base column — even when `foo` is also registered as a measure. This preserves historical `t.distance.sum()` semantics where `distance` was both a column and a measure name. + 2. **Then known measure (with suffix matching).** `t.flight_count` resolves to the measure named `flight_count`; on a joined model with prefixed names like `flights.flight_count`, the unique-suffix match bridges the short name automatically. + 3. **Then ibis Table methods.** `t.count`, `t.filter`, etc. fall through to the underlying table for parity with ibis usage inside calc bodies. + 4. **Otherwise `UnknownMeasureRefError`** with a `difflib`-derived "did you mean?" suggestion. + + `t.all(x)` follows a separate contract: string measure name → totals reference; string column name → `column.sum().over(window())` with a logger warning saying "use a measure for non-sum semantics"; ibis Reduction → window-wrap; ibis Field on the virtual aggregated table → totals reference. + + **Footgun.** A model defining a measure with the same name as a base column shadows the column inside `t.foo`. This is rarely what users want when defining ratios — `t.distance.mean() / t.all(t.distance.mean())` has clear intent, but if the user wrote `t.distance / t.all(t.distance)` expecting "ratio of mean to total mean," they'd get the column instead. ADR commits to documenting this in user docs and adding a startup-time warning when a measure shadows a column. + +3. **Nested-array measures + `t.all(...)` is unsupported.** `apply_calc_measures` raises `TotalsNotAvailableError` (calc_compiler.py:498) when a calc references `t.all(...)` on a model with nested-array measures. Reason: nested-array measures compile at multiple grains and join — there is no single "totals aggregation" that respects all grains. 
Users hitting this either (a) restructure the calc to reference a flat-grained intermediate measure, or (b) lower it manually via the `to_untagged()` escape hatch. ADR commits to the error rather than a silently-wrong answer; lifting the limitation is future work that requires designing per-grain totals semantics. + +### Negative — accepted + +- **The analyzer is harder than tag-matching.** Detecting `AllOf` was `isinstance(node, AllOf)`; detecting "agg-of-agg" requires walking ibis trees and recognizing the structural pattern. The analyzer is roughly the size of the deleted `validate_calc_ast` plus the AllOf-lift section of the deleted compiler — net code is still smaller (-727 LOC). +- **Calc lambdas now run twice.** Once at construction (for classification) and once at query time (against the real table). For typical models this is negligible — calc lambdas are tiny — but lambdas with side effects would be observed twice. Documented as an intentional tradeoff. +- **`IbisCalcScope` is load-bearing public-ish surface area.** When users write calcs they're effectively programming against the scope's dispatch rules (column-first, then known-measure suffix lookup, then `t.all(...)` totals). Documented. +- **Hard cutover, no compat shim.** Calc measures using the old curated-AST shapes (`MeasureRef("x")`, `AllOf(...)`, `BinOp(...)`) no longer work directly. In practice the user-facing API was always the lambda form (`lambda t: t.x / t.all(t.x)`); only internal tests touched the AST classes. Test suite migrated in the same branch. + +### Negative — pending Phase 3 + +- **Existing tags with `SemanticMutateOp` will not deserialize.** They fail with a clear error naming the offending op. Users re-tag from current model definitions or pin the prior BSL version. No migration tool — the failure mode is loud and the fix (re-tag) is a one-line script call against the model. 
+- **Doc churn.** `query-methods.md`, `bucketing.md`, `sessionized.md`, `windowing.md`, `percentage-total.md`, `reference.md` all need rewrites — but in a *better* direction (declarative on the model rather than chained `.mutate()`). +- **Test churn.** `test_real_world_scenarios.py`, `test_preagg_stress.py`, `test_malloy_inspired.py`, and the mutate-chain tests need updates. + +### Neutral + +- Pre-aggregation correctness (formerly the `mutated_gb_keys` machinery) is no longer the operator's responsibility — it was only the operator's responsibility because `SemanticMutateOp` introduced post-agg derived columns the planner couldn't see into. With analysis-based classification, the planner reads pushability off the calc measure's expression directly; correctness is a property of the analyzer. + +## Alternatives considered + +1. **Keep both, document the split.** Rejected — enshrines a confusing two-axis decision ("is this a measure or a mutate?") that was already a recurring user question. +2. **Drop calc measures, keep `mutate`.** Mutate's expression language is right; its placement (query-local, anonymous, not in `model.measures`) is wrong. Rejected. +3. **Drop `mutate`, keep the curated calc-measure AST** (the previous version of this ADR). Trades the planner cleanup for ergonomic regression on windows and `xo.case`. Rejected — the curated AST has to grow eventually anyway, and growing it is strictly more total work than going to ibis-as-the-language once. +4. **Extend the curated calc-measure AST node-by-node** (add `Case`, `Window`, `When`, …). Each extension requires coordinated changes to `validate_calc_ast`, the compiler, the AllOf-lift pass, and serialization. Strictly more work over time than this unification; rejected. +5. **Defang `SemanticMutateOp`** to a terminal-only post-agg escape hatch with no chained API and no pre-agg participation. Saves the planner cleanup but keeps two systems and the curated-AST limitation forever. Rejected. +6. 
**Keep `SemanticMutateOp` only at the post-aggregation boundary, hide the public `.mutate` API.** Strips the user-facing surface but leaves the planner branches and serialization registrations intact. Rejected. +7. **Soft cutover with a curated-AST compat shim** (the original Phase 1 plan). Rejected during implementation — the AST classes had no external users in practice (only internal tests), so the shim was pure carrying cost. Hard cutover saved the deprecation cycle. + +## Open questions + +### Resolved in Phase 1+2 + +- ~~**Analyzer scope for v1.**~~ Implemented per the recommendation: column-refs-on-one-source → `pushable`; agg-of-agg / empty-window-over-reduction → `references_AllOf`; any window node → `has_window`. Unrecognized inputs warn and fall back to `post_agg_only`. +- ~~**`t.all(...)` API.**~~ Kept the proxy method (`scope.all(...)`); emits a `Field(totals_vt, ...)` reference the analyzer recognizes. No user-visible change. +- ~~**`t.all(...)` over non-sum measures.**~~ Resolved with a real totals table. Per-group means cross-joined with overall mean; totals re-aggregation uses the formula, not a windowed sum. +- ~~**Inline reductions inside `t.all(...)`.**~~ Resolved via `lift_inline_reductions`. +- ~~**Calc-of-calc dependency ordering.**~~ Topological order from `CalcMeasure.depends_on`, captured at classification time and preserved through serialization. +- ~~**Compat-shim duration for the curated AST.**~~ Hard cutover; no shim. + +### Resolved during ADR review + +- ~~**`with_measures` semantics for chained calls after a filter.**~~ Settled. `SemanticFilter.with_measures` (`expr.py:1117`) builds its `MeasureScope` on `self.op().to_untagged()`, which is the *filtered* ibis table; lambdas registered there see the filtered rows. Phase 3's `.mutate(...)` desugaring inherits this — no new semantics to invent. Worth a regression test (`test_with_measures_after_filter_sees_filtered_table`) so the contract is pinned before mutate is removed. 
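The non-sum `t.all(...)` resolution noted above (totals re-aggregation uses the formula, not a windowed sum) reduces to plain arithmetic. A toy check with made-up numbers, no BSL involved:

```python
# Per-group mean divided by the *overall* mean, the semantics of
# avg_distance / t.all(avg_distance). Data is invented for the example.
rows = [("A", 10.0), ("A", 20.0), ("B", 30.0)]

groups: dict[str, list[float]] = {}
for key, value in rows:
    groups.setdefault(key, []).append(value)

per_group_mean = {k: sum(v) / len(v) for k, v in groups.items()}  # A: 15, B: 30

# What a windowed sum over per-group values would give (the old wrong answer):
sum_of_group_means = sum(per_group_mean.values())  # 45.0

# What re-running the formula against the ungrouped base gives (the right answer):
overall_mean = sum(v for _, v in rows) / len(rows)  # 20.0

share = {k: m / overall_mean for k, m in per_group_mean.items()}
print(share)  # {'A': 0.75, 'B': 1.5}
```

The gap between `sum_of_group_means` and `overall_mean` is exactly why the totals table re-aggregates from the base rather than summing the per-group column.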
+ +### Resolved during Phase 3 readiness pass + +- ~~**`.mutate(...)` chain method: alias or delete?**~~ **Alias.** Keep `.mutate(**post)` as a 3-line desugaring to `self.with_measures(**post).aggregate(*current_aggs, *post.keys())` on `SemanticAggregate`. The composition gotcha table above shows every chained-mutate shape lowers cleanly to the equivalent `with_measures` chain, so the alias is unambiguous. Preserves chain ergonomics; no operator/storage-path divergence. The `.mutate()` method on `SemanticTable` (pre-aggregate) lowers to a slightly different alias — `self.with_measures(**post)` — because pre-aggregate mutate is just "register measures on this model"; calling it as `with_measures` is the rename. + +- ~~**Old tag handling.**~~ No migration tool. Tags containing `SemanticMutateOp` raise a clear deserialization error naming the offending op and pointing at the `with_measures` equivalent in the deprecation note. Users either re-tag from their current model definitions or pin the prior BSL version. Rationale: the calc-measure expression language is now full ibis, so re-generating tags from the current model is straightforward; building a tool that introspects arbitrary persisted lambdas would cost more than the user-side re-tag. + +- ~~**Pre-agg correctness coverage audit.**~~ Required before deletion. The pre-agg paths (`_to_untagged_with_preagg`, `_to_untagged_with_deferred_joins`) currently special-case mutate via `mutated_gb_keys` and `collect_mutates_to_join`. The audit consists of: (a) enumerate every mutate-aware branch (the planner-readiness table above is this list); (b) for each branch, identify the equivalent `with_measures` test case in `test_real_world_scenarios.py` / `test_preagg_stress.py`; (c) where coverage is missing, add the test *before* deleting the branch. 
The Phase 1+2 baseline is 978 passing tests (1 preexisting unrelated xorq `read_parquet` failure) — any Phase 3 deletion that doesn't keep that count at 978 or higher blocks the merge. Concrete coverage gaps known today: chained-after-limit (`SemanticLimit.with_measures` doesn't exist yet — Phase 3 must add it), and the `mutated_gb_keys` interaction with cross-table dimension bridges (`ops.py:2566–2581`). + +### Open for Phase 3 + +- **Sequencing inside Phase 3.** Recommended order: (1) ship `SemanticLimit.with_measures` and the regression tests for every composition in the gotcha table — this establishes the baseline equivalence; (2) reduce `.mutate()` to the desugaring alias and re-run the test suite — this proves the alias is observationally equivalent; (3) drop `SemanticMutateOp`, the planner branches, and the serialization registrations (the alias's call site changes from "build mutate op" to "extend the aggregate's measure set"), and update `serialization/reconstruct.py` to raise a clear error for old `SemanticMutateOp` tags — the test suite is the safety net; (4) decide whether to keep or delete the alias one minor version later. Each step is independently revertible; (3) is the load-bearing deletion, and (1)+(2) exist to make it boring. + +- **Classification-result caching.** The analyzer runs once per calc per `_compile_aggregation` call — already memoized within a single query (`classify_calc_lambdas` in `calc_compiler.py:540`). Cross-query caching is out of scope: the classification depends on the model's known-measure set, which varies between models that share calc lambdas (e.g. via copy-paste). Anyone hitting hot-path overhead from re-classification has a bigger problem (their model construction is in the request path); the right fix is to cache the *built* `SemanticAggregateOp` tree, not the classification record. The ADR commits to no cross-query analyzer cache; the classification-cost story is "negligible for typical models, and not the bottleneck for hot paths anyway."
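The `.mutate(...)` desugaring alias settled above can be sketched with simplified stand-ins. The names and internals here (the `aggs` list, the constructor shape) are hypothetical — the real `SemanticAggregate` stores its state differently — but the three-line alias body is the point:

```python
# Simplified stand-in for the desugaring alias -- internals are hypothetical.
class SemanticAggregate:
    def __init__(self, measures, aggs):
        self.measures = dict(measures)  # registered measure lambdas
        self.aggs = list(aggs)          # measure names selected by .aggregate()

    def with_measures(self, **post):
        # Register new measure lambdas on the model; selection is unchanged.
        return SemanticAggregate({**self.measures, **post}, self.aggs)

    def aggregate(self, *names):
        return SemanticAggregate(self.measures, names)

    def mutate(self, **post):
        # The ~3-line alias: register the post-agg lambdas as measures,
        # then re-aggregate the current selection plus the new names.
        return self.with_measures(**post).aggregate(*self.aggs, *post.keys())


agg = SemanticAggregate({"flight_count": lambda t: t.count()}, ["flight_count"])
out = agg.mutate(share=lambda t: t.flight_count / t.all(t.flight_count))
assert out.aggs == ["flight_count", "share"]
assert set(out.measures) == {"flight_count", "share"}
```

Note that the lambdas are never invoked here; the sketch exercises only the registration-and-selection bookkeeping the alias performs. On `SemanticTable` (pre-aggregate), the alias is shorter still — just `return self.with_measures(**post)` — since pre-aggregate mutate only registers measures.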
diff --git a/src/boring_semantic_layer/agents/backends/mcp.py b/src/boring_semantic_layer/agents/backends/mcp.py index e30e0f2..08f6a4f 100644 --- a/src/boring_semantic_layer/agents/backends/mcp.py +++ b/src/boring_semantic_layer/agents/backends/mcp.py @@ -427,7 +427,7 @@ def search_dimension_values( f"Available dimensions: {list(dims.keys())}" ) - from boring_semantic_layer.compile_all import _get_ibis_module + from boring_semantic_layer.nested_compile import get_ibis_module as _get_ibis_module dim = dims[dimension_name] tbl = model.table diff --git a/src/boring_semantic_layer/calc_analyzer.py b/src/boring_semantic_layer/calc_analyzer.py new file mode 100644 index 0000000..bd0ae36 --- /dev/null +++ b/src/boring_semantic_layer/calc_analyzer.py @@ -0,0 +1,358 @@ +"""Analyzer for calc measures expressed as ibis expressions. + +Replaces the curated calc-measure AST classification (`MeasureRef` / +`AllOf` / `BinOp` / `MethodCall` / `AggregationExpr`) with structural +analysis of an ibis expression tree. + +The analyzer walks the ibis tree and returns a :class:`CalcExprAnalysis` +record describing properties relevant to the planner: + +- ``pushable``: every column reference targets one source table and there + is no window or aggregation-in-aggregation; the expression can be + computed pre-aggregation as a base measure. +- ``references_AllOf``: an aggregation node appears as a scalar inside + another aggregation context (the "totals" pattern). The compiler lifts + this to a window aggregation (or a cross-joined totals table). +- ``has_window``: any window node anywhere. Forces post-aggregation. +- ``post_agg_only``: the expression cannot be pushed pre-aggregation. +- ``depends_on``: set of names this expression references on the + *aggregated* virtual scope (i.e. measure references). + +Anything not classifiable falls back to ``post_agg_only=True`` with a +warning, never an error — see the ADR's "v1 analyzer scope" open +question. 
+""" + +from __future__ import annotations + +import warnings +from typing import Any + +from attrs import field, frozen + +from ._xorq import ( + Deferred, + Field, + Node, + operations as ibis_ops, +) +from ._xorq import ibis as ibis_mod + + +@frozen(kw_only=True) +class CalcExprAnalysis: + """Structural classification of a calc-measure ibis expression. + + Produced by :func:`analyze_calc_expr`. The planner reads ``pushable`` + to decide pre-agg pushdown, ``references_AllOf`` to decide whether + to compute totals, ``has_window`` and ``post_agg_only`` to decide + placement, and ``depends_on`` to order calc-measure compilation. + """ + + pushable: bool + references_AllOf: bool + has_window: bool + post_agg_only: bool + depends_on: frozenset[str] = field(factory=frozenset, converter=frozenset) + inline_aggs: frozenset[str] = field(factory=frozenset, converter=frozenset) + + +def _to_node(expr: Any) -> Node | None: + """Best-effort coercion of an arbitrary value to an ibis ``Node``. + + Returns ``None`` for primitives (int, float, str, None) and for + Deferreds that haven't been resolved yet — callers must resolve + Deferreds against an actual table before analysis. 
+ """ + if expr is None: + return None + if isinstance(expr, (int, float, str, bool)): + return None + if isinstance(expr, Deferred): + return None + if hasattr(expr, "op") and callable(expr.op): + try: + return expr.op() + except Exception: + return None + if isinstance(expr, Node): + return expr + return None + + +def _is_reduction(node: Node) -> bool: + """True if ``node`` is an ibis ``Reduction`` (sum/mean/count/...).""" + Reduction = getattr(ibis_ops, "Reduction", None) + if Reduction is not None and isinstance(node, Reduction): + return True + name = type(node).__name__ + return name in ( + "Sum", + "Mean", + "Count", + "CountStar", + "CountDistinct", + "Min", + "Max", + "Variance", + "StandardDev", + "Median", + "Quantile", + "ApproxCountDistinct", + "Mode", + "First", + "Last", + "Arbitrary", + "Any", + "All", + "GroupConcat", + "ArrayCollect", + ) + + +def _is_window(node: Node) -> bool: + """True if ``node`` is any ibis window operation.""" + WindowFunction = getattr(ibis_ops, "WindowFunction", None) + if WindowFunction is not None and isinstance(node, WindowFunction): + return True + name = type(node).__name__ + return "Window" in name + + +def _walk_children(node: Node): + """Yield direct child Nodes of ``node``. Robust to ibis API drift. + + Walks ``__children__`` if present, otherwise ``__args__``. Skips + non-Node leaves (literals, schemas, etc.). Skips ``Relation`` nodes + so the analyzer does not descend into base-table expressions whose + body may itself contain window functions or aggregations unrelated + to the calc measure being classified. 
+ """ + from ._xorq import operations as ibis_ops + + Relation = getattr(ibis_ops.relations, "Relation", None) + + children = getattr(node, "__children__", None) + if children is not None: + for c in children: + if isinstance(c, Node) and not (Relation is not None and isinstance(c, Relation)): + yield c + return + args = getattr(node, "__args__", None) + if args is None: + return + for arg in args: + if isinstance(arg, Node): + if Relation is not None and isinstance(arg, Relation): + continue + yield arg + elif isinstance(arg, tuple): + for inner in arg: + if isinstance(inner, Node): + if Relation is not None and isinstance(inner, Relation): + continue + yield inner + + +def _walk(node: Node): + """Iterate ``node`` and all descendants (preorder, deduped). + + Does not descend into ``Relation`` subtrees — those are the table + references the calc expression sits on top of, not part of its + structural shape. + """ + seen: set[int] = set() + stack = [node] + while stack: + cur = stack.pop() + key = id(cur) + if key in seen: + continue + seen.add(key) + yield cur + stack.extend(_walk_children(cur)) + + +def _collect_field_names(node: Node) -> set[str]: + """Collect all ``Field`` names referenced anywhere under ``node``.""" + return {n.name for n in _walk(node) if isinstance(n, Field)} + + +def _collect_source_tables(node: Node) -> set[int]: + """Identify distinct source-table ops referenced under ``node``. + + Returns a set of ``id()`` for ``Field.rel`` ops. Used to detect + expressions that span multiple tables (post-agg only). + """ + return {id(n.rel) for n in _walk(node) if isinstance(n, Field)} + + +def _is_empty_window(node: Node) -> bool: + """True if ``node`` is a window with no partitioning or ordering. + + The ``t.all(x)`` API emits ``x.sum().over(window())`` — an empty + window — to mean "take the totals over the whole post-agg result." + A partitioned or ordered window is a real window function (moving + average, rank, etc.) 
and is treated separately as ``has_window``. + """ + if not _is_window(node): + return False + group_by = getattr(node, "group_by", ()) + order_by = getattr(node, "order_by", ()) + return not group_by and not order_by + + +def _scan_tree(node: Node) -> tuple[bool, bool, bool]: + """Single-pass tree walk returning ``(has_reduction, has_window, has_totals)``. + + Combining the three checks avoids the O(K) full subtree walks the + original ``_has_totals_pattern`` did once per encountered reduction — + structural classification is a hot path called once per + ``with_measures`` lambda. Definition of ``has_totals``: + + * a ``Reduction`` whose subtree contains another ``Reduction`` (an + aggregation-inside-aggregation), or + * an empty (no group_by, no order_by) ``WindowFunction`` over a + ``Reduction`` — the ``x.sum().over(window())`` shape ``t.all(x)`` + emits today. + """ + has_reduction = False + has_window = False + has_totals = False + seen: set[int] = set() + stack = [(node, 0)] # (node, depth_of_enclosing_reduction) + while stack: + cur, agg_depth = stack.pop() + key = id(cur) + if key in seen: + continue + seen.add(key) + + cur_is_reduction = _is_reduction(cur) + cur_is_window = _is_window(cur) + + if cur_is_reduction: + has_reduction = True + if agg_depth > 0: + has_totals = True + agg_depth += 1 + elif cur_is_window: + has_window = True + if _is_empty_window(cur): + agg_depth += 1 + + for child in _walk_children(cur): + stack.append((child, agg_depth)) + return has_reduction, has_window, has_totals + + +def analyze_calc_expr( + expr: Any, + known_measures: frozenset[str] = frozenset(), + base_table_op: Node | None = None, + totals_vt_op: Node | None = None, +) -> CalcExprAnalysis: + """Classify a calc-measure ibis expression. + + Parameters + ---------- + expr: + An ibis expression, ``Deferred``, or primitive. Deferreds must + be resolved by the caller against the analysis scope before the + walker can inspect them. 
+ known_measures: + Names of measures defined on the model. Field references on the + synthetic post-aggregation virtual table whose names are in + this set are recorded as ``depends_on``. + base_table_op: + Optional. The base table's ibis op. When provided, fields + referencing this exact table are not treated as measure + dependencies — they are inline base columns (used by inline + aggregations like ``t.distance.sum()`` in calc-measure form). + totals_vt_op: + Optional. The totals virtual table's ibis op (parallel to + ``base_table_op`` but representing no-group-by aggregation). + Field references on this table mark the totals pattern; the + compiler later substitutes them with a real totals aggregation + cross-joined into the result. + + Returns + ------- + CalcExprAnalysis + Structural classification. On unrecognized inputs the analyzer + returns ``post_agg_only=True`` with a warning rather than + raising. + """ + node = _to_node(expr) + + if node is None: + # Primitive (int/float/str). Pure constants are pushable + # trivially — they fold into both grouped and ungrouped contexts. 
+ if isinstance(expr, (int, float, bool)): + return CalcExprAnalysis( + pushable=True, + references_AllOf=False, + has_window=False, + post_agg_only=False, + ) + warnings.warn( + f"calc-measure analyzer could not classify {type(expr).__name__}; " + "treating as post-aggregation-only.", + stacklevel=2, + ) + return CalcExprAnalysis( + pushable=False, + references_AllOf=False, + has_window=False, + post_agg_only=True, + ) + + _, has_window, references_AllOf = _scan_tree(node) + + field_names = _collect_field_names(node) + source_tables = _collect_source_tables(node) + + depends_on: set[str] = set() + inline_aggs: set[str] = set() + base_id = id(base_table_op) if base_table_op is not None else None + totals_id = id(totals_vt_op) if totals_vt_op is not None else None + for fld in (n for n in _walk(node) if isinstance(n, Field)): + if totals_id is not None and id(fld.rel) == totals_id: + references_AllOf = True + depends_on.add(fld.name) + elif base_id is not None and id(fld.rel) == base_id: + inline_aggs.add(fld.name) + elif fld.name in known_measures: + depends_on.add(fld.name) + + # Pushability heuristic: single source table, no windows, no + # cross-aggregation patterns, no measure refs (since measures are + # already aggregated and can't push pre-agg). + pushable = ( + not has_window + and not references_AllOf + and not depends_on + and len(source_tables) <= 1 + ) + + post_agg_only = has_window or references_AllOf or bool(depends_on) + + return CalcExprAnalysis( + pushable=pushable, + references_AllOf=references_AllOf, + has_window=has_window, + post_agg_only=post_agg_only, + depends_on=depends_on, + inline_aggs=inline_aggs, + ) + + +def virtual_agg_table( + schema: dict[str, Any], + name: str = "__bsl_virtual_agg__", +): + """Build a synthetic ibis table representing the post-aggregation + schema. Calc-measure lambdas evaluate against this table to produce + an ibis expression the analyzer can walk. 
+ """ + return ibis_mod.table(schema, name=name) diff --git a/src/boring_semantic_layer/calc_compiler.py b/src/boring_semantic_layer/calc_compiler.py new file mode 100644 index 0000000..c086afa --- /dev/null +++ b/src/boring_semantic_layer/calc_compiler.py @@ -0,0 +1,1038 @@ +"""Ibis-native calc-measure compiler. + +Replaces the curated-AST ``compile_grouped_with_all`` / +``infer_calc_dtype`` pipeline with one that accepts user-written ibis +expressions directly. The compiler relies on +:func:`boring_semantic_layer.calc_analyzer.analyze_calc_expr` to +classify each measure structurally; placement (pushable vs. +post-aggregation) is read off the analysis record rather than off +curated-AST node types. + +Architecture +------------ + +A calc-measure lambda is evaluated against an :class:`IbisCalcScope` +that dispatches name lookups to three tables: + +- ``base_tbl`` for raw columns (used by inline aggregations like + ``t.distance.sum()`` inside a calc measure). +- ``virtual_agg_tbl`` for measure references (a synthetic ibis table + whose schema mirrors the post-aggregation result). +- ``totals_virtual_agg_tbl`` for ``t.all(measure_ref)`` — a parallel + synthetic table representing the same measures computed without + group_by. Compile-time substitution swaps it with a real totals + aggregation cross-joined into the result so non-sum measures get + correct overall values. + +When a column name exists on both ``base_tbl`` and ``virtual_agg_tbl``, +the base column wins — historical curated-AST behavior where +``t.distance.sum()`` meant "sum the base column" even when ``distance`` +was registered as a measure. 
+""" + +from __future__ import annotations + +import difflib +import logging +from collections.abc import Iterable +from typing import Any + +from ._xorq import Deferred, Field, Node +from ._xorq import ibis as ibis_mod +from ._xorq import operations as ibis_ops +from .calc_analyzer import ( + CalcExprAnalysis, + _walk, + _walk_children, + analyze_calc_expr, + virtual_agg_table, +) +from .measure_scope import UnknownMeasureRefError + +logger = logging.getLogger(__name__) + + +TOTALS_PREFIX = "__bsl_totals__" +"""Column prefix applied to totals-table columns when cross-joined into +the per-group result. Any column on a result table starting with this +prefix represents the same-named measure computed over the totals +aggregation; calc-measure compilation rewrites +``Field(totals_vt, name)`` references to point at these prefixed +columns.""" + + +_EMPTY_VT_SCHEMA: dict[str, str] = {"__bsl_unused__": "int64"} +"""Placeholder schema used when a virtual aggregated table would +otherwise be empty (e.g. a model with no measures yet). ibis tables must +have at least one column; the sentinel name is unlikely to collide with +real column names and lets analyzer/compiler logic stay uniform.""" + + +class TotalsNotAvailableError(RuntimeError): + """Raised when a calc measure references ``t.all(measure_ref)`` but + no totals table can be constructed in the current compilation context. + + Two situations produce this error: + + * The compilation path lacks the per-base aggregation specs needed + to recompute totals (``apply_calc_measures`` called without + ``agg_specs`` and no ``real_totals_tbl``). + * The aggregation involves nested-array measures, which are + computed at multiple grains and joined; building a totals table + that respects all grains is not yet supported. + """ + + +def _to_op(x): + """Return ``x.op()`` if ``x`` is an ibis expression-like, else ``x``. + + BSL accepts both expressions (``Table``/``Column``/...) and bare + ops in many places. 
Centralizing the duck-type lets call sites + stay focused on the substitution logic. + """ + op = getattr(x, "op", None) + return op() if callable(op) else x + + +def _drop_totals_columns(tbl, totals_prefix: str = TOTALS_PREFIX): + """Project ``tbl`` to columns that do not carry the totals prefix. + + Used after a calc-measure ``mutate`` on a cross-joined + ``real_with_totals`` table so the user-visible result no longer + exposes the synthetic totals columns. + """ + return tbl.select([c for c in tbl.columns if not c.startswith(totals_prefix)]) + + +class IbisCalcScope: + """Dual-table scope passed to calc-measure lambdas. + + ``t.column_name`` returns the base-table column when one exists; + ``t.measure_name`` returns the virtual aggregated column otherwise. + Base columns win on collision so that historical patterns like + ``t.distance.sum()`` (where ``distance`` is also a measure name) + still classify as a base aggregation rather than a post-aggregation + sum. + + ``t.all(measure_ref)`` resolves to a Field on a parallel + ``totals_virtual_agg_tbl`` that mirrors the post-aggregation schema + but represents the same measures computed without group_by. The + compiler later substitutes this synthetic table with a real totals + table built from the base by re-running the aggregation without + group keys, so non-sum measures (mean / quantile / …) get correct + overall values rather than a windowed sum of per-group results. 
+ """ + + __slots__ = ("_base_tbl", "_virtual_agg_tbl", "_totals_virtual_agg_tbl", "_known_measures") + + def __init__( + self, + base_tbl, + virtual_agg_tbl, + known_measures, + totals_virtual_agg_tbl=None, + ): + object.__setattr__(self, "_base_tbl", base_tbl) + object.__setattr__(self, "_virtual_agg_tbl", virtual_agg_tbl) + if totals_virtual_agg_tbl is None: + vt_op = _to_op(virtual_agg_tbl) + schema = ( + dict(vt_op.schema.items()) + if hasattr(vt_op, "schema") + else {n: "float64" for n in known_measures} + ) + if not schema: + schema = dict(_EMPTY_VT_SCHEMA) + totals_virtual_agg_tbl = ibis_mod.table(schema, name="__bsl_virtual_totals__") + object.__setattr__(self, "_totals_virtual_agg_tbl", totals_virtual_agg_tbl) + object.__setattr__(self, "_known_measures", frozenset(known_measures)) + + @property + def tbl(self): + """Backwards-compat: return the base table for callers that + introspect ``scope.tbl`` (e.g. unnest inference).""" + return self._base_tbl + + def _has_column(self, name: str) -> bool: + return hasattr(self._base_tbl, "columns") and name in self._base_tbl.columns + + def _typo_suggestion(self, name: str) -> str | None: + cutoff = 0.80 + candidates: list[tuple[str, str]] = [] + if self._known_measures: + for match in difflib.get_close_matches( + name, list(self._known_measures), n=3, cutoff=cutoff + ): + candidates.append(("measure", match)) + if hasattr(self._base_tbl, "columns"): + for match in difflib.get_close_matches( + name, list(self._base_tbl.columns), n=3, cutoff=cutoff + ): + candidates.append(("column", match)) + if not candidates: + return None + formatted = ", ".join(f"{kind} {match!r}" for kind, match in candidates) + return f"Did you mean: {formatted}?" + + def _resolve_measure_name(self, name: str) -> str | None: + """Resolve ``name`` to a known measure, including suffix matching. + + On a joined model, measure names are prefixed (``flights.flight_count``). 
+ A calc-measure lambda written on the un-joined model still references + them by short name (``t.flight_count``); we transparently bridge by + suffix-matching when a unique match exists. + """ + if name in self._known_measures: + return name + suffix = f".{name}" + matches = tuple(k for k in self._known_measures if k.endswith(suffix)) + if len(matches) == 1: + return matches[0] + return None + + def __getattr__(self, name: str): + if name.startswith("_"): + raise AttributeError(name) + if self._has_column(name): + return self._base_tbl[name] + resolved = self._resolve_measure_name(name) + if resolved is not None: + return self._virtual_agg_tbl[resolved] + # Fall through to ibis Table methods (e.g. `count`). + try: + return getattr(self._base_tbl, name) + except AttributeError: + suggestion = self._typo_suggestion(name) + if suggestion: + raise UnknownMeasureRefError( + f"{name!r} is not a known measure or column. {suggestion}" + ) from None + raise + + def __getitem__(self, name: str): + if self._has_column(name): + return self._base_tbl[name] + resolved = self._resolve_measure_name(name) + if resolved is not None: + return self._virtual_agg_tbl[resolved] + return self._base_tbl[name] + + def all(self, x: Any): + """Resolve a measure reference to its totals-table column. + + ``t.all(measure_name)`` and ``t.all(t.measure_name)`` both + return ``Field(totals_virtual_agg_tbl, measure_name)``. The + compiler builds a real totals table from the base aggregation + (no group_by) and substitutes it in at compile time, so the + result is the measure's overall value computed by the same + formula — not a windowed sum of per-group values. + + Inline reductions (``t.all(t.distance.sum())``) keep the + windowed-reduction shape here; the inline-reduction lift in + :func:`lift_inline_reductions` rewrites them to totals-table + Field references too. 
+ + Raw base columns (``t.all("col")`` where ``col`` is not a + measure) fall back to the legacy ``column.sum().over(window())`` + shape — there is no measure formula to re-apply. + """ + if isinstance(x, str): + resolved = self._resolve_measure_name(x) + if resolved is not None: + return self._totals_virtual_agg_tbl[resolved] + if self._has_column(x): + logger.warning( + "t.all(%r) over a raw column emits column.sum().over(window()); " + "this is correct only for sum semantics. Reference a measure " + "instead (e.g. t.all(t.measure_name)) so the totals re-aggregation " + "uses the measure formula.", + x, + ) + return self._base_tbl[x].sum().over(ibis_mod.window()) + suggestion = self._typo_suggestion(x) + if suggestion: + raise UnknownMeasureRefError( + f"{x!r} is not a known measure or column. {suggestion}" + ) + return self._base_tbl[x].sum().over(ibis_mod.window()) + + # If x is a Field on virtual_agg_tbl (a known measure + # reference), redirect to the parallel totals table so the + # compiler can substitute in a properly re-aggregated value. + if hasattr(x, "op") and callable(x.op): + try: + op = x.op() + if isinstance(op, Field) and id(op.rel) == id(_to_op(self._virtual_agg_tbl)): + return self._totals_virtual_agg_tbl[op.name] + + Reduction = getattr(ibis_ops, "Reduction", None) + if Reduction is not None: + if isinstance(op, Reduction): + return x.over(ibis_mod.window()) + if any(isinstance(n, Reduction) for n in _walk(op)): + return x.over(ibis_mod.window()) + except Exception as exc: + logger.debug("IbisCalcScope.all() reduction-detection swallowed: %s", exc) + + if hasattr(x, "sum"): + return x.sum().over(ibis_mod.window()) + + return x + + +def evaluate_calc_lambda( + fn, + base_tbl, + known_measures: frozenset[str], + virtual_agg_schema: dict[str, Any] | None = None, +): + """Run a calc-measure lambda and return the ibis expression it builds. 
+ + Constructs an :class:`IbisCalcScope` over ``base_tbl``, a synthetic + virtual aggregated table whose schema is derived from + ``virtual_agg_schema``, and a parallel synthetic totals table with + the same schema. The scope is passed to ``fn`` exactly once; the + returned ibis expression encodes the structural shape the analyzer + walks — including any ``Field(totals_vt, ...)`` references emitted + by ``t.all(measure_ref)``. + + Returns ``(expr, vt, totals_vt)``. Callers that only need the + virtual aggregated table can ignore the third element. + """ + if virtual_agg_schema is None: + virtual_agg_schema = {name: "float64" for name in known_measures} + if not virtual_agg_schema: + virtual_agg_schema = dict(_EMPTY_VT_SCHEMA) + + vt = virtual_agg_table(virtual_agg_schema) + totals_vt = ibis_mod.table(dict(virtual_agg_schema), name="__bsl_virtual_totals__") + scope = IbisCalcScope(base_tbl, vt, known_measures, totals_virtual_agg_tbl=totals_vt) + + if hasattr(fn, "_resolver") and hasattr(fn, "resolve"): + return fn.resolve(scope), vt, totals_vt + + if callable(fn): + return fn(scope), vt, totals_vt + + return fn, vt, totals_vt + + +def classify_calc_lambda( + fn, + base_tbl, + known_measures: frozenset[str], + virtual_agg_schema: dict[str, Any] | None = None, +) -> tuple[Any, CalcExprAnalysis]: + """Evaluate the lambda and run :func:`analyze_calc_expr` on the result. + + Returns ``(expr, analysis)`` where ``expr`` is the ibis expression + the lambda built (with references against the virtual aggregated + table) and ``analysis`` is the structural classification. The + caller can then route to base-measure or calc-measure compilation + based on ``analysis.pushable``. 
+ """ + expr, vt, totals_vt = evaluate_calc_lambda( + fn, base_tbl, known_measures, virtual_agg_schema + ) + analysis = analyze_calc_expr( + expr, + known_measures=known_measures, + base_table_op=_to_op(base_tbl), + totals_vt_op=_to_op(totals_vt), + ) + return expr, analysis + + +def compile_calc_measure( + expr, + virtual_agg_tbl, + real_agg_tbl, + totals_virtual_agg_tbl=None, + real_with_totals=None, + totals_prefix: str = TOTALS_PREFIX, +): + """Compile a calc-measure ibis expression against the real agg table. + + Substitutes references to ``virtual_agg_tbl`` with ``real_agg_tbl``. + When the calc references a totals virtual table, also rewrites + each ``Field(totals_vt, name)`` to ``Field(real_with_totals, + f"{totals_prefix}{name}")`` — i.e. the prefixed column produced by + cross-joining the totals aggregation into the per-group result. + + The resulting ibis expression is suitable for use as a column in + ``mutate(name=expr)`` on whichever table holds those references + (``real_agg_tbl`` for non-totals calcs, ``real_with_totals`` for + totals-using calcs). 
+ """ + op = _to_op(expr) + vt_op = _to_op(virtual_agg_tbl) + real_op = _to_op(real_agg_tbl) + subs: dict = {vt_op: real_op} + + totals_vt_op = None + if totals_virtual_agg_tbl is not None and real_with_totals is not None: + totals_vt_op = _to_op(totals_virtual_agg_tbl) + rwt_op = _to_op(real_with_totals) + totals_schema = ( + dict(totals_vt_op.schema.items()) if hasattr(totals_vt_op, "schema") else {} + ) + rwt_columns = ( + real_with_totals.columns if hasattr(real_with_totals, "columns") else () + ) + for col_name in totals_schema: + prefixed = f"{totals_prefix}{col_name}" + target_name = prefixed if prefixed in rwt_columns else col_name + if target_name in rwt_columns: + subs[Field(totals_vt_op, col_name)] = Field(rwt_op, target_name) + + rewritten = op.replace(subs) + + # Verify no Field reference to the totals virtual table survived the + # rewrite; an unsubstituted reference reaches ibis as + # ``IntegrityError: Cannot add ... to projection`` and obscures the + # real cause (schema drift / missing totals column). + if totals_vt_op is not None: + unresolved = sorted( + { + n.name + for n in _walk(rewritten) + if isinstance(n, Field) and id(n.rel) == id(totals_vt_op) + } + ) + if unresolved: + raise TotalsNotAvailableError( + "Calc measure references totals columns that were not " + f"substituted: {unresolved!r}. Expected prefixed columns " + f"({totals_prefix}) on the cross-joined real_with_totals " + f"table but found neither prefixed nor unprefixed match in " + f"columns: {list(real_with_totals.columns)!r}." + ) + + return rewritten.to_expr() + + +def compile_calc_measures( + real_agg_tbl, + calc_exprs: dict[str, tuple[Any, Any]], +): + """Apply post-aggregation calc measures to the aggregated table. + + Convenience wrapper: each entry in ``calc_exprs`` is + ``measure_name → (expr, virtual_agg_tbl)``; we substitute the + virtual table with ``real_agg_tbl`` and add the resulting columns + via ``mutate``. 
Totals-aware compilation lives in the + full :func:`compile_calc_measure` entry point and is invoked from + higher-level orchestration in ``ops.py``. + """ + if not calc_exprs: + return real_agg_tbl + new_cols = { + name: compile_calc_measure(expr, vt, real_agg_tbl) + for name, (expr, vt) in calc_exprs.items() + } + return real_agg_tbl.mutate(**new_cols) + + +def apply_calc_measures( + real_agg_tbl, + base_tbl, + calc_lambdas: dict[str, Any], + known_measures: frozenset[str], + real_totals_tbl=None, + agg_specs: dict[str, Any] | None = None, + totals_prefix: str = TOTALS_PREFIX, +): + """Re-run each calc-measure lambda against the real aggregated table. + + Calc measures are applied one-at-a-time in topological order via + successive ``mutate`` calls so calc-of-calc chains see prior results + as columns. + + Totals handling: when a calc lambda emits ``Field(totals_vt, name)`` + references (the ``t.all(measure_ref)`` shape) we cross-join a + no-group-by totals table into the result with prefixed column names + and rewrite the field references to point at it. ``real_totals_tbl`` + can be supplied directly; otherwise, when ``agg_specs`` is provided + and at least one calc actually references totals, we build the + totals table on first need by re-running ``agg_specs`` on + ``base_tbl`` without group keys. + + .. note:: + When ``real_totals_tbl`` is supplied directly, the caller is + responsible for ensuring it already carries any non-AllOf calc + columns that AllOf calcs depend on. The lazy ``agg_specs`` path + re-applies those calc-of-calc deps automatically; the + pre-built path does not, since rebuilding them would require + re-running the analyzer on a table whose schema may already + diverge from ``base_tbl`` + ``agg_specs``. 
+ + Raises :class:`TotalsNotAvailableError` when a calc references + totals but neither ``real_totals_tbl`` nor ``agg_specs`` lets us + build one — surfaces the missing-totals condition with a clear + message instead of letting the unsubstituted ``Field(totals_vt, + ...)`` reach ibis and fail with ``IntegrityError``. + """ + if not calc_lambdas: + return real_agg_tbl + + ordered = _topological_order(calc_lambdas, base_tbl, known_measures) + + real_with_totals = None + if real_totals_tbl is not None: + real_with_totals = _join_totals(real_agg_tbl, real_totals_tbl, totals_prefix) + + base_op = _to_op(base_tbl) + + for name in ordered: + fn = calc_lambdas[name] + cur_known = known_measures | frozenset(real_agg_tbl.columns) + virtual_schema = { + col: real_agg_tbl[col].type() + for col in real_agg_tbl.columns + if col in cur_known + } + expr, vt, totals_vt = evaluate_calc_lambda( + fn, base_tbl, cur_known, virtual_schema + ) + analysis = analyze_calc_expr( + expr, + known_measures=cur_known, + base_table_op=base_op, + totals_vt_op=_to_op(totals_vt), + ) + + if analysis.references_AllOf: + if real_with_totals is None: + if real_totals_tbl is None: + real_totals_tbl = _build_totals_from_agg_specs( + base_tbl, agg_specs, calc_lambdas, known_measures + ) + if real_totals_tbl is not None: + real_with_totals = _join_totals( + real_agg_tbl, real_totals_tbl, totals_prefix + ) + + if real_with_totals is None: + raise TotalsNotAvailableError( + f"Calc measure {name!r} references t.all(...) but no totals " + "table could be built. Pass `real_totals_tbl` or `agg_specs` " + "to apply_calc_measures, or define the calc on a model " + "without nested-array measures (which compile at multiple " + "grains and don't yet support totals)." 
+ ) + + compiled = compile_calc_measure( + expr, + vt, + real_with_totals, + totals_virtual_agg_tbl=totals_vt, + real_with_totals=real_with_totals, + totals_prefix=totals_prefix, + ) + real_with_totals = real_with_totals.mutate(**{name: compiled}) + real_agg_tbl = _drop_totals_columns(real_with_totals, totals_prefix) + continue + + compiled = compile_calc_measure(expr, vt, real_agg_tbl) + real_agg_tbl = real_agg_tbl.mutate(**{name: compiled}) + if real_with_totals is not None and real_totals_tbl is not None: + real_with_totals = _join_totals(real_agg_tbl, real_totals_tbl, totals_prefix) + + return real_agg_tbl + + +def attach_windowed_totals( + base_tbl, + agg_specs: dict[str, Any], + total_names: Iterable[str], + totals_prefix: str = TOTALS_PREFIX, +) -> tuple[Any, dict[str, Any]]: + """Pre-mutate ``base_tbl`` with windowed totals for the given base measures. + + For each name in ``total_names`` that has an entry in ``agg_specs``, + evaluate the agg-spec callable on ``base_tbl`` to get the measure's + aggregation expression (e.g. ``base.count()`` or + ``base.distance.mean()``), wrap it in ``.over(window())`` to produce + a window function over the entire base, and add the result as a + new column ``f"{totals_prefix}{name}"``. Returns the mutated base + table plus a dict of arbitrary-aggregator specs that callers should + add to their per-group aggregation so the totals propagate as + ordinary columns on the result. + + This expresses "ungrouped aggregate alongside a grouped one" as a + single-pass query: the totals are computed once via window function, + broadcast to every base row, and surface as a per-group column via + ``arbitrary()`` in the aggregation. No cross-join, no shared-ancestor + collapse, compiles to SQL on every backend that supports window + functions. + + Returns + ------- + (new_base_tbl, totals_arbitrary_specs): + - ``new_base_tbl`` carries the original columns plus + ``__bsl_totals__`` for each requested measure. 
+ - ``totals_arbitrary_specs[col]`` is an agg-spec callable that + wraps ``t[col].arbitrary()``. + """ + new_base = base_tbl + arbitrary_specs: dict[str, Any] = {} + for name in total_names: + if name not in agg_specs: + continue + try: + agg_expr = agg_specs[name](new_base) + except Exception as exc: + logger.debug( + "could not evaluate agg_spec for %r when attaching windowed totals: %s", + name, + exc, + ) + continue + try: + windowed = agg_expr.over(ibis_mod.window()) + except Exception as exc: + logger.debug( + "could not wrap %r in window() for windowed totals: %s", + name, + exc, + ) + continue + col = f"{totals_prefix}{name}" + new_base = new_base.mutate(**{col: windowed}) + arbitrary_specs[col] = (lambda t, _c=col: t[_c].arbitrary()) + return new_base, arbitrary_specs + + +class _TotalsResolvingScope: + """Scope that resolves measure references to ``__bsl_totals__`` columns. + + Used by :func:`attach_calc_totals` to evaluate a calc lambda + against the totals columns of a per-group result. Since each + ``__bsl_totals__`` column carries the same value across all + rows (the overall total computed via window function), applying + a calc formula against this scope produces the calc's totals value + on every row. + """ + + __slots__ = ("_tbl", "_totals_prefix") + + def __init__(self, tbl, totals_prefix: str): + object.__setattr__(self, "_tbl", tbl) + object.__setattr__(self, "_totals_prefix", totals_prefix) + + def _resolve(self, name: str): + col = f"{self._totals_prefix}{name}" + if hasattr(self._tbl, "columns") and col in self._tbl.columns: + return self._tbl[col] + # Suffix matching for joined models: ``flights.flight_count`` + # has totals column ``__bsl_totals__flights.flight_count``. 
+ suffix = f".{name}" + for c in getattr(self._tbl, "columns", ()): + if c.startswith(self._totals_prefix) and c[len(self._totals_prefix):].endswith( + suffix + ): + return self._tbl[c] + raise AttributeError(f"No totals column found for measure {name!r}") + + def __getattr__(self, name: str): + if name.startswith("_"): + raise AttributeError(name) + return self._resolve(name) + + def __getitem__(self, name: str): + return self._resolve(name) + + def all(self, x): + # Inside a totals evaluation, ``t.all(t.x)`` is just ``t.x`` — + # we're already computing in the totals scope. Pass the value + # through. + if isinstance(x, str): + return self._resolve(x) + return x + + +def attach_calc_totals( + real_agg_tbl, + calc_specs: dict[str, Any], + classifications: dict[str, CalcExprAnalysis], + totals_prefix: str = TOTALS_PREFIX, +): + """Compute ``__bsl_totals__`` columns for calc-of-calc-AllOf chains. + + When an AllOf-using calc references another calc (rather than a + base measure) — e.g. ``t.all(t.avg_distance)`` where + ``avg_distance`` is itself a calc — we need the totals value of + the referenced calc on the per-group result so substitution can + point at it. ``attach_windowed_totals`` only handles base measures + via ``agg.over(window())``; this function fills the gap by + evaluating each calc's lambda against the totals columns already + attached to ``real_agg_tbl``, in topological order so calc-of-calc + chains see prior totals as inputs. + + The user's calc lambda doesn't change — it's the same formula — + but the scope it runs against returns ``__bsl_totals__`` + columns instead of regular per-group columns. Since each totals + column carries a constant value across rows, applying the formula + yields the corresponding constant calc-totals value. + """ + # Identify calcs whose totals are needed: the direct AllOf targets + # plus any transitive calc dependencies of those. 
+ needed: set[str] = set() + work: list[str] = [] + for cn, c in classifications.items(): + if c.references_AllOf: + for d in c.depends_on: + if d in calc_specs: + needed.add(d) + work.append(d) + while work: + n = work.pop() + if n not in classifications: + continue + for d in classifications[n].depends_on: + if d in calc_specs and d not in needed: + needed.add(d) + work.append(d) + + if not needed: + return real_agg_tbl + + # Topo-order so a calc's deps are computed before the calc itself. + deps_map = { + n: set(classifications[n].depends_on) & needed + for n in needed + if n in classifications + } + ordered = topological_order_from_deps(needed, deps_map) + + for calc_name in ordered: + if calc_name not in calc_specs: + continue + cm = calc_specs[calc_name] + fn = cm.expr if hasattr(cm, "expr") else cm + try: + scope = _TotalsResolvingScope(real_agg_tbl, totals_prefix) + if hasattr(fn, "_resolver") and hasattr(fn, "resolve"): + totals_expr = fn.resolve(scope) + elif callable(fn): + totals_expr = fn(scope) + else: + totals_expr = fn + except Exception as exc: + logger.debug( + "calc-of-calc totals evaluation failed for %r: %s", calc_name, exc + ) + continue + col = f"{totals_prefix}{calc_name}" + real_agg_tbl = real_agg_tbl.mutate(**{col: totals_expr}) + + return real_agg_tbl + + +def _join_totals(real_agg_tbl, real_totals_tbl, totals_prefix: str): + """Legacy cross-join path. Kept for ``apply_calc_measures`` callers + that pass a pre-built ``real_totals_tbl``. + + .. deprecated:: + Prefer :func:`attach_windowed_totals` which avoids the + shared-ancestor cross-join collapse some SQL backends apply + when both sides derive from the same parent relation. This + helper survives only for the ``apply_calc_measures(real_totals_tbl=...)`` + entry point where the totals are produced externally and the + per-group table is already built; the windowed-totals path + requires attaching at base-table time before the per-group + aggregation runs. 
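The prefix-rename-then-cross-join shape this legacy path implements can be sketched in plain Python (hypothetical rows and measure names): because the totals relation has exactly one row, the cross join degenerates to appending the prefixed totals columns onto every per-group row.

```python
TOTALS_PREFIX = "__bsl_totals__"

group_rows = [
    {"origin": "JFK", "flight_count": 8},
    {"origin": "LAX", "flight_count": 2},
]
totals_row = {"flight_count": 10}  # the single no-group-by totals row

# Rename totals columns with the prefix so they can't collide with the
# per-group columns, then "cross-join" them onto every group row.
prefixed = {f"{TOTALS_PREFIX}{k}": v for k, v in totals_row.items()}
joined = [{**r, **prefixed} for r in group_rows]
```

A sketch of the shape only — the real helper does this with `ibis.Table.rename` and `cross_join`, and the windowed-totals path exists precisely to avoid materializing this join.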
+ """ + rename_map = {f"{totals_prefix}{c}": c for c in real_totals_tbl.columns} + totals_renamed = real_totals_tbl.rename(rename_map) + return real_agg_tbl.cross_join(totals_renamed) + + +def classify_calc_lambdas( + calc_lambdas: dict[str, Any], + base_tbl, + known_measures: frozenset[str], +) -> dict[str, CalcExprAnalysis]: + """Run the analyzer once per calc lambda; return ``{name → analysis}``. + + Lets multiple passes (topological order, totals-build filtering, + apply loop) read the same classification record without + re-evaluating each lambda. Lambdas that fail evaluation get an + empty ``CalcExprAnalysis`` (post_agg_only=True, no deps) so the + surrounding orchestration still terminates. + """ + base_op = _to_op(base_tbl) + out: dict[str, CalcExprAnalysis] = {} + for name, fn in calc_lambdas.items(): + try: + virtual_schema = {n: "float64" for n in known_measures} + expr, _vt, totals_vt = evaluate_calc_lambda( + fn, base_tbl, known_measures, virtual_schema + ) + out[name] = analyze_calc_expr( + expr, + known_measures=known_measures, + base_table_op=base_op, + totals_vt_op=_to_op(totals_vt), + ) + except Exception as exc: + logger.debug("calc-measure classification failed for %r: %s", name, exc) + out[name] = CalcExprAnalysis( + pushable=False, + references_AllOf=False, + has_window=False, + post_agg_only=True, + ) + return out + + +def _build_totals_from_agg_specs( + base_tbl, + agg_specs: dict[str, Any] | None, + calc_lambdas: dict[str, Any], + known_measures: frozenset[str], + classifications: dict[str, CalcExprAnalysis] | None = None, +): + """Build a no-group-by totals table when callers passed ``agg_specs``. + + Re-runs each base-aggregation callable on ``base_tbl`` without group + keys, then applies the non-AllOf calc lambdas so calc-of-calc chains + see correctly-recomputed dependencies. Returns ``None`` when there + is no way to construct totals (no ``agg_specs`` supplied or the + specs fail to evaluate against the base). 
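The totals-building step amounts to: evaluate each agg-spec callable against the ungrouped base, then layer the non-AllOf calcs on top so calc-of-calc chains see freshly recomputed inputs. A pure-Python sketch under assumed spec names (`flight_count`, `total_distance`, `avg_distance` are illustrative, not from the codebase):

```python
rows = [{"distance": d} for d in (100, 300, 200)]

# Agg-spec callables evaluated against the ungrouped base -- the
# analogue of re-running `agg_specs` with no group keys.
agg_specs = {
    "flight_count": lambda t: len(t),
    "total_distance": lambda t: sum(r["distance"] for r in t),
}
totals = {name: fn(rows) for name, fn in agg_specs.items()}

# Non-AllOf calcs applied afterwards, so a calc that reads other
# aggregates sees the recomputed no-group-by values.
totals["avg_distance"] = totals["total_distance"] / totals["flight_count"]
```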
+ """ + if not agg_specs: + return None + try: + totals_aggs = {n: f(base_tbl) for n, f in agg_specs.items()} + except Exception as exc: + logger.debug("totals aggregation failed to evaluate: %s", exc) + return None + real_totals = base_tbl.aggregate(**totals_aggs) + + if classifications is None: + classifications = classify_calc_lambdas(calc_lambdas, base_tbl, known_measures) + non_allof = { + name: fn + for name, fn in calc_lambdas.items() + if not classifications.get(name, _EMPTY_ANALYSIS).references_AllOf + } + if non_allof: + real_totals = apply_calc_measures(real_totals, base_tbl, non_allof, known_measures) + return real_totals + + +_EMPTY_ANALYSIS = CalcExprAnalysis( + pushable=False, + references_AllOf=False, + has_window=False, + post_agg_only=True, +) + + +def topological_order_from_deps( + names: list[str] | tuple[str, ...] | dict[str, Any], + deps: dict[str, set[str] | frozenset[str]], +) -> list[str]: + """Topologically order ``names`` using ``deps`` (``name → {dep, ...}``). + + Edges to nodes outside ``names`` are ignored; cycles fall back to + insertion order so a downstream substitution failure surfaces the + real error rather than this helper raising. Shared by ops.py and + apply_calc_measures so calc-of-calc ordering is consistent. 
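The ordering contract — DFS postorder, out-of-scope edges ignored, cycles tolerated rather than raised — can be sketched standalone (illustrative mirror of the helper, not the helper itself):

```python
def topo_order(names, deps):
    """DFS postorder over `names`. Edges to nodes outside `names` are
    ignored; a cycle's back edge is simply skipped, so the function
    always terminates instead of raising."""
    name_set, ordered = set(names), []
    visited, visiting = set(), set()

    def visit(node):
        if node in visited or node in visiting:
            return
        visiting.add(node)
        for dep in deps.get(node, ()):
            if dep in name_set:
                visit(dep)  # deps come out before their dependents
        visiting.discard(node)
        visited.add(node)
        ordered.append(node)

    for n in names:
        visit(n)
    return ordered
```

For example, `topo_order(["share", "total"], {"share": {"total"}})` yields `["total", "share"]`, so a calc-of-calc sees its dependency already materialized as a column.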
+ """ + name_seq = list(names) + name_set = set(name_seq) + + ordered: list[str] = [] + visited: set[str] = set() + visiting: set[str] = set() + + def visit(node: str) -> None: + if node in visited or node in visiting: + return + visiting.add(node) + for dep in deps.get(node, ()): + if dep in name_set: + visit(dep) + visiting.discard(node) + visited.add(node) + ordered.append(node) + + for n in name_seq: + visit(n) + return ordered + + +def _topological_order( + calc_lambdas: dict[str, Any], + base_tbl, + known_measures: frozenset[str], + classifications: dict[str, CalcExprAnalysis] | None = None, +) -> list[str]: + """Order calc lambdas using analyzer-derived dependencies.""" + if classifications is None: + classifications = classify_calc_lambdas(calc_lambdas, base_tbl, known_measures) + deps = { + name: set(classifications.get(name, _EMPTY_ANALYSIS).depends_on) + for name in calc_lambdas + } + return topological_order_from_deps(calc_lambdas, deps) + + +def lift_inline_reductions(expr, virtual_agg_tbl, base_tbl, totals_virtual_agg_tbl=None): + """Lift inline reductions over the base table out of a calc expression. + + The user's calc lambda may contain reductions that read base-table + columns directly, e.g. ``t.distance.mean() / t.all(t.distance.mean())``. + Straight ``mutate`` can't compile these because the reductions are + bound to the unaggregated base relation, not the post-aggregation + result. + + Each unique base-table reduction is named, added to both the per-group + base aggregation and the (no-group-by) totals aggregation, then + rewritten in the calc expression: + + - A reduction at the top level becomes ``Field(vt, anon_name)`` — + a column reference on the per-group result. + - A reduction that is the ``func`` of a ``WindowFunction`` (the + ``t.all(...)`` totals shape) becomes ``Field(totals_vt, anon_name)`` + — a reference to the same reduction computed over the full + filtered base. 
The compiler later substitutes ``totals_vt`` with a + real totals table cross-joined into the result, so non-sum + reductions (mean/quantile/…) get correct overall values. + + Returns ``(rewritten_expr, new_vt, new_totals_vt, lifted)`` where + ``lifted`` maps anonymous names to the original scalar reduction + expression. The caller adds those reductions to both the per-group + aggregation and the totals aggregation. + """ + op = _to_op(expr) + vt_op = _to_op(virtual_agg_tbl) + if totals_virtual_agg_tbl is None: + totals_schema = dict(vt_op.schema.items()) if hasattr(vt_op, "schema") else {} + totals_virtual_agg_tbl = ibis_mod.table( + totals_schema or dict(_EMPTY_VT_SCHEMA), + name="__bsl_virtual_totals__", + ) + totals_vt_op = _to_op(totals_virtual_agg_tbl) + base_op = _to_op(base_tbl) + + Reduction = getattr(ibis_ops, "Reduction", None) + WindowFunction = getattr(ibis_ops, "WindowFunction", None) + + if Reduction is None: + return expr, virtual_agg_tbl, totals_virtual_agg_tbl, {} + + def is_base_reduction(node): + if not isinstance(node, Reduction): + return False + for c in _walk(node): + if isinstance(c, Field) and id(c.rel) == id(base_op): + return True + return False + + base_reductions = [n for n in _walk(op) if is_base_reduction(n)] + if not base_reductions: + return expr, virtual_agg_tbl, totals_virtual_agg_tbl, {} + + name_to_reduction: dict[str, Any] = {} + reduction_to_name: dict[int, str] = {} + counter = 0 + for r in base_reductions: + if id(r) in reduction_to_name: + continue + anon = f"__bsl_inline_{type(r).__name__.lower()}_{counter}" + counter += 1 + name_to_reduction[anon] = r + reduction_to_name[id(r)] = anon + + extended_schema = dict(vt_op.schema.items()) + for anon, r in name_to_reduction.items(): + extended_schema[anon] = r.dtype + new_vt = ibis_mod.table(extended_schema, name=getattr(vt_op, "name", "__bsl_virtual_agg__")) + new_vt_op = new_vt.op() + + totals_extended_schema = dict(totals_vt_op.schema.items()) if hasattr( + totals_vt_op, 
"schema" + ) else {} + for anon, r in name_to_reduction.items(): + totals_extended_schema[anon] = r.dtype + new_totals_vt = ibis_mod.table( + totals_extended_schema or dict(_EMPTY_VT_SCHEMA), + name=getattr(totals_vt_op, "name", "__bsl_virtual_totals__"), + ) + new_totals_vt_op = new_totals_vt.op() + + # Two-pass substitution. The same ``Reduction`` node may appear both + # at top level (where we want ``Field(vt, anon)`` — the per-group + # value) and as a ``WindowFunction.func`` (the ``t.all(...)`` totals + # shape, where we want ``Field(totals_vt, anon)`` — the overall + # value). ``op.replace`` dedupes by equality, so we can't tell those + # apart in one pass: handle WindowFunctions first, then the bare + # Reductions. + if WindowFunction is not None: + window_subs: dict = {} + for n in _walk(op): + if not isinstance(n, WindowFunction): + continue + inner = getattr(n, "func", None) + if inner is None or id(inner) not in reduction_to_name: + continue + anon = reduction_to_name[id(inner)] + window_subs[n] = Field(new_totals_vt_op, anon) + intermediate = op.replace(window_subs) if window_subs else op + else: + intermediate = op + + field_subs = {r: Field(new_vt_op, reduction_to_name[id(r)]) for r in base_reductions} + new_op = intermediate.replace(field_subs) + + lifted_aggs = {anon: r.to_expr() for anon, r in name_to_reduction.items()} + return new_op.to_expr(), new_vt, new_totals_vt, lifted_aggs + + +def rename_measure_refs(expr, virtual_agg_tbl, name_map: dict[str, str]): + """Rename measure references inside a calc-measure ibis expression. + + Used when joining tables: a calc measure declared on a model named + ``flights`` may reference ``flight_count``, but after the join the + aggregated column is ``flights.flight_count``. This function rebuilds + the calc expression so that field references on the virtual aggregated + table map to their prefixed names. + + Parameters + ---------- + expr: + Calc-measure ibis expression built against ``virtual_agg_tbl``. 
+ virtual_agg_tbl: + The synthetic table the expression was built against. + name_map: + Mapping of ``old_name → new_name`` for measure references that + need renaming. Names not in the map are left untouched. + + Returns + ------- + A new ibis expression with prefixed field names. The returned + expression now references a *new* virtual table whose schema includes + the renamed columns, so callers must compile against that new virtual + table (use :func:`build_renamed_virtual_table` to get it). + """ + if not name_map: + return expr, virtual_agg_tbl + + vt_op = virtual_agg_tbl.op() if hasattr(virtual_agg_tbl, "op") else virtual_agg_tbl + old_schema = dict(vt_op.schema.items()) if hasattr(vt_op, "schema") else {} + new_schema = {name_map.get(k, k): v for k, v in old_schema.items()} + new_vt = ibis_mod.table(new_schema, name=getattr(vt_op, "name", "__bsl_virtual_agg__")) + new_vt_op = new_vt.op() + + field_substitutions = {} + for old_name, dtype in old_schema.items(): + new_name = name_map.get(old_name, old_name) + old_field = Field(vt_op, old_name) + new_field = Field(new_vt_op, new_name) + field_substitutions[old_field] = new_field + + op = expr.op() if hasattr(expr, "op") and callable(expr.op) else expr + return op.replace(field_substitutions).to_expr(), new_vt diff --git a/src/boring_semantic_layer/compile_all.py b/src/boring_semantic_layer/compile_all.py deleted file mode 100644 index 59154a1..0000000 --- a/src/boring_semantic_layer/compile_all.py +++ /dev/null @@ -1,564 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable -from functools import reduce -from typing import Any - -import ibis -from attrs import frozen -from toolz import curry, pipe - -from .measure_scope import AllOf, BinOp, MeasureExpr, MeasureRef, MethodCall - - -@curry -def _extract_nested_array(prev_col: str, array_col: str, table): - if prev_col not in table.columns: - return table - prev_struct = table[prev_col] - if not hasattr(prev_struct, array_col): - 
return table - return table.mutate(**{array_col: getattr(prev_struct, array_col)}) - - -@curry -def _do_unnest_array(array_col: str, table): - return table.unnest(array_col) if array_col in table.columns else table - - -def _unnest_nested_arrays(base_tbl, array_path: tuple[str, ...]): - sorted_path = tuple(sorted(array_path)) - - def unnest_step(table, indexed_col): - idx, array_col = indexed_col - if idx == 0: - return _do_unnest_array(array_col, table) - prev_col = sorted_path[idx - 1] - if array_col in table.columns: - return _do_unnest_array(array_col, table) - return pipe(table, _extract_nested_array(prev_col, array_col), _do_unnest_array(array_col)) - - return reduce(unnest_step, enumerate(sorted_path), base_tbl) - - -def _collect_all_refs(expr: MeasureExpr, out: set[str]) -> None: - if isinstance(expr, AllOf): - # Only add if ref is a MeasureRef, not an AggregationExpr - if isinstance(expr.ref, MeasureRef): - out.add(expr.ref.name) - # If it's an AggregationExpr, it will be resolved during compilation - elif isinstance(expr, MethodCall): - _collect_all_refs(expr.receiver, out) - elif isinstance(expr, BinOp): - _collect_all_refs(expr.left, out) - _collect_all_refs(expr.right, out) - - -def _collect_aggregation_exprs(expr: MeasureExpr, out: set) -> None: - """Collect all AggregationExpr from a calc_spec expression.""" - from .measure_scope import AggregationExpr - - if isinstance(expr, AggregationExpr): - out.add(expr) - elif isinstance(expr, BinOp): - _collect_aggregation_exprs(expr.left, out) - _collect_aggregation_exprs(expr.right, out) - elif isinstance(expr, AllOf): - _collect_aggregation_exprs(expr.ref, out) - - -def _make_agg_name(agg_expr) -> str: - """Generate a unique name for an AggregationExpr.""" - try: - from dask.base import tokenize - - token = tokenize(agg_expr) - except Exception: - # Fallback when dask isn't available in the runtime environment. 
- import hashlib - - token = hashlib.sha1(repr(agg_expr).encode()).hexdigest() - - return f"__agg_{agg_expr.column}_{agg_expr.operation}_{token[:12]}" - - -def _replace_aggregation_exprs(expr: MeasureExpr, agg_name_map: dict) -> MeasureExpr: - """Replace AggregationExpr in an expression with MeasureRef to pre-computed aggregations.""" - from .measure_scope import AggregationExpr - - if isinstance(expr, AggregationExpr): - name = agg_name_map.get(expr) - return MeasureRef(name) if name is not None else expr - elif isinstance(expr, BinOp): - new_left = _replace_aggregation_exprs(expr.left, agg_name_map) - new_right = _replace_aggregation_exprs(expr.right, agg_name_map) - return BinOp(op=expr.op, left=new_left, right=new_right) - elif isinstance(expr, AllOf): - new_ref = _replace_aggregation_exprs(expr.ref, agg_name_map) - return AllOf(ref=new_ref) - return expr - - -def _make_agg_fn_from_expr(agg_expr): - """Create an aggregation function from an AggregationExpr.""" - operations = { - "sum": lambda col: col.sum(), - "mean": lambda col: col.mean(), - "avg": lambda col: col.mean(), - "count": lambda col: col.count(), - "min": lambda col: col.min(), - "max": lambda col: col.max(), - } - - def agg_fn(t): - if agg_expr.operation == "count": - result = t.count() - else: - result = operations[agg_expr.operation](t[agg_expr.column]) - - # Apply post_ops (e.g., .coalesce(0)) - for method_name, args, kwargs_tuple in agg_expr.post_ops: - result = getattr(result, method_name)(*args, **dict(kwargs_tuple)) - - return result - - return agg_fn - - -@curry -def _compile_binop(by_tbl, all_tbl, base_tbl, op: str, left: Any, right: Any): - left_val = _compile_formula(left, by_tbl, all_tbl, base_tbl) - right_val = _compile_formula(right, by_tbl, all_tbl, base_tbl) - ops = { - "add": lambda left_val, right_val: left_val + right_val, - "sub": lambda left_val, right_val: left_val - right_val, - "mul": lambda left_val, right_val: left_val * right_val, - "div": lambda left_val, right_val: 
left_val.cast("float64") / right_val.cast("float64"), - } - if op not in ops: - raise ValueError(f"Unknown operator: {op}") - return ops[op](left_val, right_val) - - -def _get_ibis_module(table): - """Detect which ibis module the table is using (regular ibis or xorq's vendored ibis).""" - table_module = type(table).__module__ - if table_module.startswith("xorq.vendor.ibis"): - # Table is from xorq's vendored ibis - from ._xorq import ibis as xorq_ibis - return xorq_ibis - else: - # Table is from regular ibis - return ibis - - -def _compile_formula(expr: MeasureExpr, by_tbl, all_tbl, base_tbl): - """Compile a measure expression to ibis, using functional dispatch for AggregationExpr.""" - from .measure_scope import AggregationExpr - - if isinstance(expr, int | float): - # Use the same ibis module as the table to avoid mixing regular and xorq ibis - ibis_module = _get_ibis_module(by_tbl) - return ibis_module.literal(expr) - if isinstance(expr, MeasureRef): - return by_tbl[expr.name] - if isinstance(expr, AllOf): - # Handle AllOf with AggregationExpr or MeasureRef - if isinstance(expr.ref, MeasureRef): - return all_tbl[expr.ref.name] - elif isinstance(expr.ref, AggregationExpr): - # AllOf with AggregationExpr requires the measure to exist - raise ValueError( - f"Unresolved AggregationExpr in AllOf: {expr.ref}. " - f"Expected a measure computing {expr.ref.column}.{expr.ref.operation}() to exist." - ) - if isinstance(expr, AggregationExpr): - # Handle inline aggregations in calculated measures - # Check if column exists in base table - if expr.column != "*" and expr.column not in base_tbl.columns: - raise ValueError( - f"Unresolved AggregationExpr: {expr}. " - f"Column '{expr.column}' not found in base table. 
" - f"Available columns: {base_tbl.columns}" - ) - - # Apply the aggregation to the base table column - # The result will be an aggregation that can be used in the grouped context - operations = { - "sum": lambda col: col.sum(), - "mean": lambda col: col.mean(), - "avg": lambda col: col.mean(), - "count": lambda col: col.count(), - "min": lambda col: col.min(), - "max": lambda col: col.max(), - } - - op_fn = operations.get(expr.operation) - if op_fn is None: - raise ValueError(f"Unknown aggregation operation: {expr.operation}") - - # Apply the operation to the base table column - if expr.operation == "count": - result = base_tbl.count() - else: - column = base_tbl[expr.column] - result = op_fn(column) - - # Apply post_ops (e.g., .coalesce(0), .abs(), etc.) - for method_name, args, kwargs_tuple in expr.post_ops: - result = getattr(result, method_name)(*args, **dict(kwargs_tuple)) - - return result - if isinstance(expr, MethodCall): - compiled = _compile_formula(expr.receiver, by_tbl, all_tbl, base_tbl) - return getattr(compiled, expr.method)(*expr.args, **dict(expr.kwargs)) - if isinstance(expr, BinOp): - return _compile_binop(by_tbl, all_tbl, base_tbl, expr.op, expr.left, expr.right) - return expr - - -def infer_calc_dtype(calc_expr, base_measure_schema, base_tbl, ibis_module): - """Compile *calc_expr* against a synthetic dummy table to infer its dtype. - - Mirrors the ``AggregationExpr`` rewrite step in - ``compile_grouped_with_all`` so that calc measures containing inline - aggregations (e.g. ``t.value.sum() / t.all(t.value.sum())``) can have - their type resolved. Each inline ``AggregationExpr`` is materialized - against ``base_tbl`` to learn its dtype, added as a synthetic column - on a dummy table, and replaced with a ``MeasureRef`` in the rewritten - expression before compilation. - - Returns the compiled ibis expression. Caller handles failure. 
- """ - inline_aggs = set() - _collect_aggregation_exprs(calc_expr, inline_aggs) - - extended_schema = dict(base_measure_schema) - agg_name_map = {} - for agg_expr in sorted(inline_aggs, key=repr): - name = _make_agg_name(agg_expr) - while name in extended_schema: - name = name + "_" - agg_name_map[agg_expr] = name - agg_fn = _make_agg_fn_from_expr(agg_expr) - extended_schema[name] = agg_fn(base_tbl).type() - - dummy = ibis_module.table(extended_schema, name="__type_inference__") - rewritten = _replace_aggregation_exprs(calc_expr, agg_name_map) - return _compile_formula(rewritten, dummy, dummy, base_tbl) - - -@frozen -class MeasureClassification: - regular_measures: dict[str, tuple[callable, Any]] - nested_measures: dict[tuple[str, ...], dict[str, tuple[callable, Any]]] - - -def _fix_relation_mismatch(result, base_tbl): - """Fix measure results that belong to a different table relation. - - When a measure uses filtering (like `t[t.x > 5].count()`), it creates - a new table relation. This causes ibis integrity errors when trying to - use it in an aggregation. We detect this and wrap it in a scalar subquery. - - Args: - result: The measure result expression - base_tbl: The table we're aggregating on - - Returns: - Fixed expression that belongs to base_tbl's relation - """ - # Check if result is an expression - if not hasattr(result, 'op'): - return result - - result_op = result.op() - base_tbl_op = base_tbl.op() - - # Check if the result's immediate table relation matches base_tbl - def get_immediate_table(op): - """Get the immediate table that an aggregation operates on. - - For aggregations like CountStar, the immediate table is the first - argument in __args__. 
- """ - from ibis.expr.operations.reductions import Reduction - from ibis.expr.operations.relations import Relation - - # For reductions (count, sum, etc), the first arg is usually the table - if isinstance(op, Reduction) and hasattr(op, '__args__') and op.__args__: - first_arg = op.__args__[0] - if isinstance(first_arg, Relation): - return first_arg - - return None - - immediate_table = get_immediate_table(result_op) - - # If the aggregation operates on a different table than base_tbl, - # wrap it in a scalar subquery - if immediate_table is not None and immediate_table != base_tbl_op: - # Convert to scalar subquery - return result.as_scalar() - - return result - - -def make_measure_classification( - base_tbl, - agg_specs: dict[str, callable], -) -> MeasureClassification: - from .nested_access import NestedAccessMarker - - regular = {} - nested = {} - - for name, agg_fn in agg_specs.items(): - result = agg_fn(base_tbl) - - if isinstance(result, NestedAccessMarker): - # Nested measure - group by array path - array_path = result.array_path - if array_path not in nested: - nested[array_path] = {} - nested[array_path][name] = (agg_fn, result) - else: - # Fix relation mismatches (e.g., measures that filter the table) - result = _fix_relation_mismatch(result, base_tbl) - # Regular session-level measure - regular[name] = (agg_fn, result) - - return MeasureClassification( - regular_measures=regular, - nested_measures=nested, - ) - - -@curry -def _build_field_expr(array_path: tuple[str, ...], field_path: tuple[str, ...], unnested_tbl): - # Start from first array column - expr = getattr(unnested_tbl, array_path[0]) - - if not field_path: - return expr - - # Traverse field path - return reduce(lambda e, field: getattr(e, field), field_path, expr) - - -@curry -def _apply_aggregation(marker, expr): - if marker.operation == "count": - # Count operates on table, not expression - return expr.count() if hasattr(expr, "count") else expr - else: - # Other operations on expression - 
agg_method = getattr(expr, marker.operation) - return agg_method() - - -def _build_nested_aggregation(unnested_tbl, marker) -> Any: - if marker.operation == "count": - return unnested_tbl.count() - - # Build field access expression - expr = _build_field_expr(marker.array_path, marker.field_path, unnested_tbl) - - # Apply aggregation - return _apply_aggregation(marker, expr) - - -def _build_level_aggregations( - base_tbl, - array_path: tuple[str, ...], - measures: dict[str, tuple[callable, Any]], -) -> dict[str, Any]: - unnested_tbl = _unnest_nested_arrays(base_tbl, array_path) - - return { - name: _build_nested_aggregation(unnested_tbl, marker) - for name, (agg_fn, marker) in measures.items() - } - - -@curry -def _make_grouped_table(agg_dict: dict[str, Any], by_cols: Iterable[str], table): - group_exprs = [table[c] for c in by_cols] - # xorq requires at least one grouping expression, so handle empty case - return table.group_by(group_exprs).aggregate(**agg_dict) if group_exprs else table.aggregate(**agg_dict) - - -def _build_session_table(base_tbl, by_cols: Iterable[str], regular_measures: dict) -> Any: - if not regular_measures: - return None - - session_aggs = {name: result for name, (_, result) in regular_measures.items()} - return _make_grouped_table(session_aggs, by_cols, base_tbl) - - -def _build_nested_level_table( - base_tbl, - by_cols: Iterable[str], - array_path: tuple[str, ...], - measures: dict[str, tuple[callable, Any]], -): - level_aggs = _build_level_aggregations(base_tbl, array_path, measures) - unnested_tbl = _unnest_nested_arrays(base_tbl, array_path) - return _make_grouped_table(level_aggs, by_cols, unnested_tbl) - - -def _join_tables(by_cols: Iterable[str], tables: list) -> Any: - if len(tables) == 0: - raise ValueError("Cannot join zero tables") - if len(tables) == 1: - return tables[0] - - by_cols_set = set(by_cols) - - def join_step(left, right): - # Build join predicates - predicates = [left[c] == right[c] for c in by_cols] - - # Select only 
non-key columns from right to avoid duplicates - right_cols = [c for c in right.columns if c not in by_cols_set] - right_select = [right[c] for c in right_cols] - - # Join and select - return left.left_join(right, predicates).select([left] + right_select) - - # Left join all tables sequentially - return reduce(join_step, tables[1:], tables[0]) - - -def _find_measure_in_nested( - measure_name: str, - nested_measures: dict[tuple[str, ...], dict[str, tuple[callable, Any]]], -) -> tuple[tuple[str, ...], tuple[callable, Any]] | None: - for array_path, measures in nested_measures.items(): - if measure_name in measures: - return (array_path, measures[measure_name]) - return None - - -def _build_total_aggregation( - base_tbl, - measure_name: str, - classification: MeasureClassification, - agg_specs: dict[str, callable], -) -> Any: - # Check regular measures first - if measure_name in classification.regular_measures: - _, result = classification.regular_measures[measure_name] - return result - - # Check nested measures - found = _find_measure_in_nested(measure_name, classification.nested_measures) - if found: - array_path, (agg_fn, marker) = found - unnested_tbl = _unnest_nested_arrays(base_tbl, array_path) - return _build_nested_aggregation(unnested_tbl, marker) - - # Fallback - evaluate the function - return agg_specs[measure_name](base_tbl) - - -def _build_totals_table( - base_tbl, - needed_totals: set[str], - classification: MeasureClassification, - agg_specs: dict[str, callable], -) -> Any | None: - if not needed_totals: - return None - - totals_aggs = { - name: _build_total_aggregation(base_tbl, name, classification, agg_specs) - for name in needed_totals - } - - return base_tbl.aggregate(**totals_aggs) - - -def compile_grouped_with_all( - base_tbl, - by_cols: Iterable[str], - agg_specs: dict[str, callable], - calc_specs: dict[str, MeasureExpr], - requested_measures: Iterable[str] = None, -): - # Step 0: Extract AggregationExpr from calc_specs and add them as regular 
measures - # This ensures they are computed in the grouped context instead of as scalar subqueries - all_agg_exprs = set() - for calc_expr in calc_specs.values(): - _collect_aggregation_exprs(calc_expr, all_agg_exprs) - - # Create a mapping from AggregationExpr to generated measure name - agg_name_map = {} - extended_agg_specs = dict(agg_specs) - for agg_expr in sorted(all_agg_exprs, key=repr): - name = _make_agg_name(agg_expr) - # Avoid name collisions - while name in extended_agg_specs: - name = name + "_" - agg_name_map[agg_expr] = name - extended_agg_specs[name] = _make_agg_fn_from_expr(agg_expr) - - # Replace AggregationExpr in calc_specs with MeasureRef to the new measures - updated_calc_specs = { - name: _replace_aggregation_exprs(expr, agg_name_map) - for name, expr in calc_specs.items() - } - - # Step 1: Classify measures (using extended agg_specs) - classification = make_measure_classification(base_tbl, extended_agg_specs) - - # Step 2: Build result tables for each level - result_tables = [] - - # Session-level table - session_table = _build_session_table( - base_tbl, - by_cols, - classification.regular_measures, - ) - if session_table is not None: - result_tables.append(session_table) - - # Nested-level tables - for array_path, measures in classification.nested_measures.items(): - level_table = _build_nested_level_table(base_tbl, by_cols, array_path, measures) - result_tables.append(level_table) - - # Step 3: Join tables (or create empty grouped table) - if len(result_tables) == 0: - by_tbl = _make_grouped_table({}, by_cols, base_tbl) - else: - by_tbl = _join_tables(by_cols, result_tables) - - # Step 4: Add totals if needed - needed_totals = set() - for ast in updated_calc_specs.values(): - _collect_all_refs(ast, needed_totals) - - if needed_totals: - all_tbl = _build_totals_table(base_tbl, needed_totals, classification, extended_agg_specs) - out = by_tbl.join(all_tbl, how="cross") - else: - all_tbl = None - out = by_tbl - - # Step 5: Apply calculated 
measures - calc_cols = {name: _compile_formula(ast, out, all_tbl, base_tbl) for name, ast in updated_calc_specs.items()} - out = out.mutate(**calc_cols) - - # Step 6: Select requested columns (exclude internal __agg_ measures) - if requested_measures is not None: - # Preserve order and uniqueness - select_cols = list( - dict.fromkeys( - list(by_cols) + list(requested_measures) + list(updated_calc_specs.keys()), - ), - ) - out = out.select([out[c] for c in select_cols]) - - return out diff --git a/src/boring_semantic_layer/expr.py b/src/boring_semantic_layer/expr.py index db75766..be7dc0b 100644 --- a/src/boring_semantic_layer/expr.py +++ b/src/boring_semantic_layer/expr.py @@ -19,7 +19,7 @@ ) from .chart import chart as create_chart -from .measure_scope import AggregationExpr, MeasureScope +from .measure_scope import MeasureScope from .ops import ( Dimension, Measure, @@ -40,6 +40,7 @@ _is_deferred, _normalize_join_predicate, _normalize_to_name, + make_bare_ref_lambda, ) from .query import compare_periods as build_compare_periods from .query import query as build_query @@ -1216,12 +1217,12 @@ def aggregate( if _is_deferred(item): try: name = _normalize_to_name(item) - aggs[name] = lambda t, n=name: t[n] + aggs[name] = make_bare_ref_lambda(name) except TypeError: # Complex Deferred (e.g. 
_.distance.sum()) — treat as callable aggs[f"_measure_{id(item)}"] = item elif isinstance(item, str): - aggs[item] = lambda t, n=item: t[n] + aggs[item] = make_bare_ref_lambda(item) elif callable(item): aggs[f"_measure_{id(item)}"] = item else: @@ -1230,18 +1231,6 @@ def aggregate( f"got {type(item)}", ) - def wrap_aggregation_expr(expr): - if isinstance(expr, AggregationExpr): - - def wrapped(t): - if expr.operation == "count": - return t.count() - return getattr(t[expr.column], expr.operation)() - - return wrapped - return expr - - aliased = {k: wrap_aggregation_expr(v) for k, v in aliased.items()} aggs.update(aliased) if nest: diff --git a/src/boring_semantic_layer/graph_utils.py b/src/boring_semantic_layer/graph_utils.py index 424d662..5f75caa 100644 --- a/src/boring_semantic_layer/graph_utils.py +++ b/src/boring_semantic_layer/graph_utils.py @@ -335,7 +335,7 @@ def build_dependency_graph( Returns: Dictionary mapping field names to metadata with "deps" and "type" keys """ - from .ops import _collect_measure_refs + from .ops import CalcMeasure graph = {} extended_table = _build_extended_table(base_table, dimensions) @@ -362,9 +362,11 @@ def build_dependency_graph( except Exception: graph[name] = {"deps": {}, "type": "dimension" if name in dimensions else "measure"} - for name, calc_expr in calc_measures.items(): - refs = set() - _collect_measure_refs(calc_expr, refs) + for name, calc in calc_measures.items(): + if isinstance(calc, CalcMeasure): + refs = set(calc.depends_on) + else: + refs = set() graph[name] = {"deps": {ref: "measure" for ref in refs}, "type": "calc_measure"} return graph diff --git a/src/boring_semantic_layer/measure_scope.py b/src/boring_semantic_layer/measure_scope.py index f741586..3449bf7 100644 --- a/src/boring_semantic_layer/measure_scope.py +++ b/src/boring_semantic_layer/measure_scope.py @@ -1,19 +1,28 @@ +"""Scopes for evaluating user-supplied measure / dimension lambdas. 
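The `expr.py` hunk above replaces the inline `lambda t, n=name: t[n]` closures with a shared `make_bare_ref_lambda` factory. Both forms defend against Python's late-binding closures; a minimal sketch of the pitfall and the fix (the factory name comes from the diff, its body here is an assumption):

```python
def make_bare_ref_lambda(name):
    """Return a callable that looks up `name` on whatever table it is given.

    A factory function binds `name` when the factory is *called*, so
    callables built in a loop do not all close over the loop variable's
    final value.
    """
    return lambda t: t[name]


# Late binding: every lambda closes over the same `col` cell, which holds
# the loop's final value by the time any of them runs.
naive = {col: (lambda t: t[col]) for col in ("a", "b")}

# Factory (equivalent to the `n=name` default-argument trick being replaced):
# each callable keeps its own name.
bound = {col: make_bare_ref_lambda(col) for col in ("a", "b")}

row = {"a": 1, "b": 2}
assert [fn(row) for fn in naive.values()] == [2, 2]  # both resolve "b"
assert [fn(row) for fn in bound.values()] == [1, 2]  # each resolves its own key
```

Centralizing the idiom in one factory also keeps the two call sites in `aggregate()` (deferred items and string items) from drifting apart.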
+ +This module is the legacy compatibility surface for ``MeasureScope`` and +``ColumnScope`` — the lookup proxies passed into measure callables. The +curated calc-measure AST (``MeasureRef``, ``AllOf``, ``BinOp`` …) used to +live here too; it has been removed in favor of the analyzer-based path +in :mod:`boring_semantic_layer.calc_compiler`. ``MeasureScope`` is now a +thin pass-through that returns ibis values directly, kept around for the +post-aggregation ``SemanticMutateOp`` path which still constructs a +scope to evaluate ad-hoc mutate lambdas. +""" + from __future__ import annotations import difflib -from collections.abc import Iterable from typing import Any from attrs import field, frozen -from returns.maybe import Maybe, Some -from toolz import curry class UnknownMeasureRefError(AttributeError): - """Raised when a calc-measure lambda references an unknown name. + """Raised when a lambda references an unknown measure or column. - Subclasses :class:`AttributeError` so existing code that ``except``\\ s - on attribute errors continues to work, but ``_classify_measure`` + Subclasses :class:`AttributeError` so existing code that catches + attribute errors continues to work, but the analyzer-based classifier re-raises this specific subclass instead of swallowing it. Surfaces typos at construction time with a "did you mean?" suggestion built from the surrounding measure / column names. @@ -31,9 +40,9 @@ def _has_prefixed_columns(tbl, name: str) -> bool: class _ColumnPrefixProxy: """Proxy for navigating prefixed column names on joined ibis tables. - Supports chained attribute access like ``t.flights.carrier`` which resolves - to ``table["flights.carrier"]`` when the table has columns with the - ``"flights."`` prefix (typical after joins). + Supports chained attribute access like ``t.flights.carrier`` which + resolves to ``table["flights.carrier"]`` when the table has columns + with the ``"flights."`` prefix (typical after joins). 
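The `_ColumnPrefixProxy` described above turns chained attribute access into lookups on flat, dot-prefixed column names. A dict-backed sketch of the same resolution scheme (the real class wraps an ibis table; the data and class name here are invented for illustration):

```python
class PrefixProxy:
    """Resolve `t.flights.carrier`-style chains against flat dotted keys.

    After a join, the table holds columns like "flights.carrier"; chained
    attribute access should resolve to that flat key, deepening the prefix
    when the path has more than one segment.
    """

    def __init__(self, tbl, prefix):
        self._tbl = tbl
        self._prefix = prefix

    def __getattr__(self, name):
        key = f"{self._prefix}.{name}"
        if key in self._tbl:
            return self._tbl[key]
        # Not a column yet, but some column starts with this path:
        # keep navigating (e.g. "flights.dest.city").
        if any(k.startswith(key + ".") for k in self._tbl):
            return PrefixProxy(self._tbl, key)
        raise AttributeError(key)


cols = {"flights.carrier": "AA", "flights.dest.city": "NYC"}
t = PrefixProxy(cols, "flights")
assert t.carrier == "AA"
assert t.dest.city == "NYC"
```

`__getattr__` only fires for names missing from the instance dict, so `_tbl` and `_prefix` (set in `__init__`) never recurse through it.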
""" __slots__ = ("_tbl", "_prefix") @@ -61,265 +70,13 @@ def __getitem__(self, name: str): ) -class _PendingMethodCall: - """Captures a method access on a calc-measure AST node, waiting for ``()``.""" - - __slots__ = ("_receiver", "_method") - - def __init__(self, receiver, method): - object.__setattr__(self, "_receiver", receiver) - object.__setattr__(self, "_method", method) - - def __call__(self, *args, **kwargs): - if args and hasattr(args[0], "columns"): - return self._receiver # table-call passthrough - return MethodCall(self._receiver, self._method, args, tuple(sorted(kwargs.items()))) - - def __getattr__(self, name): - if name.startswith("_"): - raise AttributeError(name) - # Zero-arg call of current method, then chain next method - zero_call = MethodCall(self._receiver, self._method, (), ()) - return _PendingMethodCall(zero_call, name) - - -class _Node: - def _bin(self, op: str, other: Any) -> BinOp: - return BinOp(op, self, other) - - def __add__(self, o: Any): - return self._bin("add", o) - - def __sub__(self, o: Any): - return self._bin("sub", o) - - def __mul__(self, o: Any): - return self._bin("mul", o) - - def __truediv__(self, o: Any): - return self._bin("div", o) - - def __radd__(self, o: Any): - return BinOp("add", o, self) - - def __rsub__(self, o: Any): - return BinOp("sub", o, self) - - def __rmul__(self, o: Any): - return BinOp("mul", o, self) - - def __rtruediv__(self, o: Any): - return BinOp("div", o, self) - - # Method-style arithmetic parity with ibis value expressions, e.g. 
t.x.add(1) - def add(self, other: Any) -> BinOp: - return self + other - - def sub(self, other: Any) -> BinOp: - return self - other - - def mul(self, other: Any) -> BinOp: - return self * other - - def div(self, other: Any) -> BinOp: - return self / other - - def __getattr__(self, name): - if name.startswith("_"): - raise AttributeError(f"'{type(self).__name__}' has no attribute {name!r}") - return _PendingMethodCall(self, name) - - -@frozen -class MeasureRef(_Node): - name: str - - -@frozen -class AllOf(_Node): - ref: MeasureRef - - -@frozen -class BinOp(_Node): - op: str - left: Any - right: Any - - -@frozen -class MethodCall(_Node): - receiver: Any - method: str - args: tuple = () - kwargs: tuple = () # tuple of (key, value) pairs - - -@frozen -class AggregationExpr(_Node): - column: str - operation: str - post_ops: tuple = field(default=(), converter=tuple) - - def __getattr__(self, name: str): - if name.startswith("_"): - raise AttributeError(f"AggregationExpr has no attribute {name!r}") - return AggregationExpr( - column=self.column, operation=self.operation, post_ops=self.post_ops + ((name, (), ()),) - ) - - def __call__(self, *args, **kwargs): - if args and hasattr(args[0], "columns"): - return self - - if not self.post_ops: - raise TypeError("Cannot call AggregationExpr with arguments when no post_ops exist") - - *rest, (method_name, _, _) = self.post_ops - return AggregationExpr( - column=self.column, - operation=self.operation, - post_ops=tuple(rest) + ((method_name, args, tuple(sorted(kwargs.items()))),), - ) - - -MeasureExpr = MeasureRef | AllOf | BinOp | MethodCall | AggregationExpr | float | int - - -def validate_calc_ast(expr: Any, measure_name: str | None = None) -> None: - """Walk a calc-measure AST and raise ``ValueError`` on illegal shapes. - - The AST nodes are unconstrained at construction (Any-typed fields), so - invalid compositions like ``AllOf(BinOp(...))`` parse but later fail - deep inside the compiler with confusing messages. 
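The `_Node` / `BinOp` / `MeasureRef` classes being deleted here are an instance of the classic build-an-AST-via-dunder-methods pattern: operator overloads record the expression instead of evaluating it, and a separate interpreter walks the recorded tree. A stripped-down sketch of that pattern (plain dataclasses rather than the attrs `@frozen` classes in the diff, and only two operators), illustrating both why it works and why it must grow node-by-node:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Ref:
    """A reference to a named measure; operators build tree nodes."""
    name: str

    def __truediv__(self, other):
        return Bin("div", self, other)

    def __mul__(self, other):
        return Bin("mul", self, other)


@dataclass(frozen=True)
class Bin:
    op: str
    left: object
    right: object

    # Reuse the same overloads so trees compose: (a / b) * 100 works.
    __truediv__ = Ref.__truediv__
    __mul__ = Ref.__mul__


def evaluate(node, env):
    """Interpret the recorded tree against an environment of values."""
    if isinstance(node, Ref):
        return env[node.name]
    if isinstance(node, Bin):
        ops = {"div": lambda a, b: a / b, "mul": lambda a, b: a * b}
        l = evaluate(node.left, env) if isinstance(node.left, (Ref, Bin)) else node.left
        r = evaluate(node.right, env) if isinstance(node.right, (Ref, Bin)) else node.right
        return ops[node.op](l, r)
    return node  # literal leaf


expr = Ref("flight_count") / Ref("total_flights") * 100
assert evaluate(expr, {"flight_count": 25, "total_flights": 100}) == 25.0
```

Every ibis feature a user wants (`case`, windows, casts) needs its own node plus interpreter branch, which is exactly the "permanently growing hand-rolled compiler" cost the ADR cites as the reason for removal.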
Run this after - classification to surface the structural problem early, naming the - offending calc measure when known. - - ``AllOf.ref`` must be a ``MeasureRef`` or ``AggregationExpr``. Other - refs (BinOp, MethodCall, nested AllOf) are not supported by either - the direct compile path or the rewrite-then-compile pipeline in - ``compile_grouped_with_all``. - """ - where = f" in calc measure {measure_name!r}" if measure_name else "" - - def walk(node): - if isinstance(node, AllOf): - if not isinstance(node.ref, (MeasureRef, AggregationExpr)): - raise ValueError( - f"Invalid AllOf{where}: ref must be a measure reference or " - f"inline aggregation, got {type(node.ref).__name__}. " - f"Wrap it in a named measure first, e.g. " - f".with_measures(my_measure=...) then use t.all(t.my_measure)." - ) - walk(node.ref) - elif isinstance(node, BinOp): - walk(node.left) - walk(node.right) - elif isinstance(node, MethodCall): - walk(node.receiver) - for arg in node.args: - walk(arg) - - walk(expr) - - -class DeferredColumn: - _AGGREGATIONS = { - "sum": "sum", - "mean": "mean", - "avg": "mean", - "count": "count", - "min": "min", - "max": "max", - } - - def __init__(self, column_name: str, tbl: Any): - self._column_name = column_name - self._tbl = tbl - self._column = tbl[column_name] - - for method_name, operation in self._AGGREGATIONS.items(): - setattr( - self, - method_name, - lambda op=operation: AggregationExpr(column=self._column_name, operation=op), - ) - - def __getattr__(self, name): - return getattr(self._column, name) - - def __add__(self, other): - return self._column + other - - def __radd__(self, other): - return other + self._column - - def __sub__(self, other): - return self._column - other - - def __rsub__(self, other): - return other - self._column - - def __mul__(self, other): - return self._column * other - - def __rmul__(self, other): - return other * self._column - - def __truediv__(self, other): - return self._column / other - - def __rtruediv__(self, 
other): - return other / self._column - - def __eq__(self, other): - return self._column.__eq__(other) - - def __ne__(self, other): - return self._column.__ne__(other) - - def __lt__(self, other): - return self._column.__lt__(other) - - def __le__(self, other): - return self._column.__le__(other) - - def __gt__(self, other): - return self._column.__gt__(other) - - def __ge__(self, other): - return self._column.__ge__(other) - - -@curry -def _resolve_measure_name( - name: str, - known: tuple[str, ...], - known_set: frozenset[str], -) -> Maybe[str]: - if name in known_set: - return Some(name) - # Suffix matching: resolve unprefixed name to prefixed equivalent - suffix = f".{name}" - matches = tuple(k for k in known if k.endswith(suffix)) - if len(matches) == 1: - return Some(matches[0]) - return Maybe.from_optional(None) - - -def _make_known_measures( - measures: Iterable[str], -) -> tuple[tuple[str, ...], frozenset[str]]: - known_tuple = tuple(measures) if not isinstance(measures, tuple) else measures - return (known_tuple, frozenset(known_tuple)) - - def _resolve_column_short_name(tbl, name): - """Resolve a column name against a table, requiring fully qualified names after joins. + """Resolve a column name against a table. Tries direct column access first; falls back to ``getattr(tbl, name)`` - for ibis methods. Raises ``AttributeError`` with a helpful message - suggesting FQDN when the short name matches prefixed columns. + for ibis methods. Raises ``AttributeError`` with a helpful message + suggesting fully qualified names when the short name matches prefixed + columns. """ if hasattr(tbl, "columns") and name in tbl.columns: return tbl[name] @@ -337,12 +94,20 @@ def _resolve_column_short_name(tbl, name): def _resolve_column_item(tbl, name): - """Resolve a column name via bracket access, requiring fully qualified names after joins.""" return tbl[name] @frozen(kw_only=True, slots=True) class MeasureScope: + """Lookup proxy passed to measure / mutate lambdas. 
+ + Compared with :class:`~boring_semantic_layer.calc_compiler.IbisCalcScope`, + this scope is a thin pass-through to the underlying ibis table. It is + still used by the post-aggregation ``SemanticMutateOp`` path (where + ``post_agg=True``) and by callers that want suffix-resolution of + measure names without virtual aggregated tables. + """ + tbl: Any = field(alias="_tbl") known: tuple[str, ...] = field(converter=tuple, alias="_known") known_set: frozenset[str] = field(init=False, alias="_known_set") @@ -351,31 +116,34 @@ class MeasureScope: def __attrs_post_init__(self): object.__setattr__(self, "known_set", frozenset(self.known)) + def _typo_suggestion(self, name: str) -> str | None: + cutoff = 0.80 + candidates: list[tuple[str, str]] = [] + if self.known: + for match in difflib.get_close_matches(name, self.known, n=3, cutoff=cutoff): + candidates.append(("measure", match)) + if hasattr(self.tbl, "columns"): + for match in difflib.get_close_matches( + name, list(self.tbl.columns), n=3, cutoff=cutoff + ): + candidates.append(("column", match)) + if not candidates: + return None + formatted = ", ".join(f"{kind} {match!r}" for kind, match in candidates) + return f"Did you mean: {formatted}?" 
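The relocated `_typo_suggestion` helper above leans on `difflib.get_close_matches` with a 0.80 cutoff. A standalone sketch of the same shape (function signature and example names are invented; per the calibration comment removed later in this diff, `flight_konut` vs `flight_count` scores roughly 0.83, so it clears the cutoff):

```python
import difflib


def typo_suggestion(name, measures, columns, cutoff=0.80):
    """Build a "Did you mean?" hint from nearby measure and column names.

    The 0.80 cutoff catches single-character typos while leaving
    legitimate substring overlaps between distinct names unflagged.
    """
    candidates = [
        ("measure", m)
        for m in difflib.get_close_matches(name, measures, n=3, cutoff=cutoff)
    ]
    candidates += [
        ("column", c)
        for c in difflib.get_close_matches(name, columns, n=3, cutoff=cutoff)
    ]
    if not candidates:
        return None
    formatted = ", ".join(f"{kind} {match!r}" for kind, match in candidates)
    return f"Did you mean: {formatted}?"


assert typo_suggestion("flight_konut", ["flight_count", "net_revenue"], []) == (
    "Did you mean: measure 'flight_count'?"
)
assert typo_suggestion("zzz", ["flight_count"], ["distance"]) is None
```

Raising this from `UnknownMeasureRefError` (an `AttributeError` subclass) keeps existing `except AttributeError` handlers working while surfacing the hint at model-construction time instead of execute time.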
+ def __getattr__(self, name: str): if name.startswith("_"): raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'", ) - if self.post_agg: - if _has_prefixed_columns(self.tbl, name): - return _ColumnPrefixProxy(self.tbl, name) - return _resolve_column_short_name(self.tbl, name) - - maybe_measure = _resolve_measure_name(name, self.known, self.known_set).map(MeasureRef) - if isinstance(maybe_measure, Some): - return maybe_measure.unwrap() - if hasattr(self.tbl, "columns") and name in self.tbl.columns: - return DeferredColumn(name, self.tbl) + return self.tbl[name] - # Support prefix navigation for joined tables (e.g., t.flights.carrier) if _has_prefixed_columns(self.tbl, name): return _ColumnPrefixProxy(self.tbl, name) - # Fall through to ibis (covers Table methods like ``count``, ``filter``). - # If ibis rejects too, surface a typo suggestion rather than the opaque - # ibis AttributeError so the user can see a "did you mean?" hint. try: return _resolve_column_short_name(self.tbl, name) except AttributeError: @@ -386,66 +154,22 @@ def __getattr__(self, name: str): ) from None raise - def _typo_suggestion(self, name: str) -> str | None: - # 0.80 catches single-character typos and case mistakes - # (``flight_konut`` vs ``flight_count`` ≈ 0.83) without flagging - # legitimate substring overlaps (``net_revenue`` vs - # ``total_net_revenue`` ≈ 0.79). Calibrated against real-world - # confusable measure names. - cutoff = 0.80 - candidates: list[tuple[str, str]] = [] - if self.known: - for match in difflib.get_close_matches(name, self.known, n=3, cutoff=cutoff): - candidates.append(("measure", match)) - if hasattr(self.tbl, "columns"): - for match in difflib.get_close_matches( - name, list(self.tbl.columns), n=3, cutoff=cutoff - ): - candidates.append(("column", match)) - if not candidates: - return None - formatted = ", ".join(f"{kind} {match!r}" for kind, match in candidates) - return f"Did you mean: {formatted}?" 
- def __getitem__(self, name: str): - if self.post_agg: - return _resolve_column_item(self.tbl, name) - - maybe_measure = _resolve_measure_name(name, self.known, self.known_set).map(MeasureRef) - if isinstance(maybe_measure, Some): - return maybe_measure.unwrap() return _resolve_column_item(self.tbl, name) def all(self, ref): from ._xorq import ibis as ibis_mod if isinstance(ref, str): - if self.post_agg: - return self.tbl[ref].sum().over(ibis_mod.window()) - - maybe_measure = _resolve_measure_name(ref, self.known, self.known_set).map( - lambda name: AllOf(MeasureRef(name)) - ) - if isinstance(maybe_measure, Some): - return maybe_measure.unwrap() return self.tbl[ref].sum().over(ibis_mod.window()) - if isinstance(ref, MeasureRef): - return AllOf(ref) - - if isinstance(ref, AggregationExpr): - return AllOf(ref) - if hasattr(ref, "__class__") and "ibis" in str(type(ref).__module__): if "Scalar" in type(ref).__name__: return ref.over(ibis_mod.window()) - else: - return ref.sum().over(ibis_mod.window()) + return ref.sum().over(ibis_mod.window()) raise TypeError( - "t.all(...) expects either a measure reference (e.g., t.flight_count), " - "a string measure name (e.g., 'flight_count'), an AggregationExpr, " - "or an ibis expression (e.g., t.distance.sum())", + "t.all(...) 
expects a string column name or an ibis expression", ) @@ -465,7 +189,6 @@ def __getattr__(self, name: str): proxy = create_table_proxy(self.tbl) return getattr(proxy, name) - # Support prefix navigation for joined tables (e.g., t.flights.carrier) if _has_prefixed_columns(self.tbl, name): return _ColumnPrefixProxy(self.tbl, name) @@ -480,16 +203,11 @@ def all(self, ref): if isinstance(ref, str): return self.tbl[ref].sum().over(ibis_mod.window()) - if isinstance(ref, AggregationExpr): - return AllOf(ref) - if hasattr(ref, "__class__") and "ibis" in str(type(ref).__module__): if "Scalar" in type(ref).__name__: return ref.over(ibis_mod.window()) - else: - return ref.sum().over(ibis_mod.window()) + return ref.sum().over(ibis_mod.window()) raise TypeError( - "t.all(...) expects either a string column name (e.g., 'flight_count'), " - "an AggregationExpr, or an ibis expression (e.g., t.distance.sum())", + "t.all(...) expects a string column name or an ibis expression", ) diff --git a/src/boring_semantic_layer/nested_compile.py b/src/boring_semantic_layer/nested_compile.py new file mode 100644 index 0000000..870e57f --- /dev/null +++ b/src/boring_semantic_layer/nested_compile.py @@ -0,0 +1,148 @@ +"""Helpers for compiling nested-array aggregations. + +The semantic layer supports measures that aggregate over nested array +columns (``t.hits.count()``, ``t.hits.value.sum()`` …). At compile time +each array path is unnested in isolation, aggregated at its own grain, +and joined back to the session-level result via the requested group-by +columns. These helpers used to live in ``compile_all.py`` alongside the +curated calc-measure compiler; that compiler is gone (replaced by the +ibis-native :mod:`calc_compiler`) so the nested-array machinery now sits +in its own module. 
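The `nested_compile.py` docstring above describes the multi-grain plan: each nested array path is unnested in isolation, aggregated at its own grain, and joined back to the session-level result on the group-by columns. A dict-based sketch of that aggregate-per-grain-then-join shape (the data and column names are invented; the real helpers operate on ibis tables):

```python
from collections import defaultdict

# One "session" row per record; `hits` stands in for a nested array column.
sessions = [
    {"day": "mon", "visits": 1, "hits": [10, 20]},
    {"day": "mon", "visits": 1, "hits": [5]},
    {"day": "tue", "visits": 1, "hits": [7]},
]

# Session grain: aggregate the flat columns directly (build_session_table).
session_level = defaultdict(int)
for row in sessions:
    session_level[row["day"]] += row["visits"]

# Hit grain: "unnest" the array into one row per element, then aggregate
# at the same group-by keys (unnest_nested_arrays + build_nested_level_table).
hit_level = defaultdict(int)
for row in sessions:
    for hit in row["hits"]:
        hit_level[row["day"]] += hit

# Join the per-grain results back on the shared group-by column, as
# join_tables() does with sequential left joins.
result = {
    day: {"visits": session_level[day], "hit_total": hit_level[day]}
    for day in session_level
}
assert result == {
    "mon": {"visits": 2, "hit_total": 35},
    "tue": {"visits": 1, "hit_total": 7},
}
```

Aggregating each grain separately is what keeps the unnest from inflating session-level measures: `visits` is summed before any row fan-out happens.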
+""" + +from __future__ import annotations + +from collections.abc import Iterable +from functools import reduce +from typing import Any + +import ibis +from toolz import curry, pipe + + +def get_ibis_module(table): + """Return the ibis module that built ``table`` (regular vs xorq-vendored). + + BSL coexists with both flavors of ibis. Picking the right module avoids + cross-flavor literal/struct construction errors. + """ + table_module = type(table).__module__ + if table_module.startswith("xorq.vendor.ibis"): + from ._xorq import ibis as xorq_ibis + + return xorq_ibis + return ibis + + +@curry +def _extract_nested_array(prev_col: str, array_col: str, table): + if prev_col not in table.columns: + return table + prev_struct = table[prev_col] + if not hasattr(prev_struct, array_col): + return table + return table.mutate(**{array_col: getattr(prev_struct, array_col)}) + + +@curry +def _do_unnest_array(array_col: str, table): + return table.unnest(array_col) if array_col in table.columns else table + + +def unnest_nested_arrays(base_tbl, array_path: tuple[str, ...]): + """Apply unnest steps for each level of a nested array path.""" + sorted_path = tuple(sorted(array_path)) + + def unnest_step(table, indexed_col): + idx, array_col = indexed_col + if idx == 0: + return _do_unnest_array(array_col, table) + prev_col = sorted_path[idx - 1] + if array_col in table.columns: + return _do_unnest_array(array_col, table) + return pipe(table, _extract_nested_array(prev_col, array_col), _do_unnest_array(array_col)) + + return reduce(unnest_step, enumerate(sorted_path), base_tbl) + + +@curry +def _build_field_expr(array_path: tuple[str, ...], field_path: tuple[str, ...], unnested_tbl): + expr = getattr(unnested_tbl, array_path[0]) + if not field_path: + return expr + return reduce(lambda e, field: getattr(e, field), field_path, expr) + + +@curry +def _apply_aggregation(marker, expr): + if marker.operation == "count": + return expr.count() if hasattr(expr, "count") else expr + return 
getattr(expr, marker.operation)() + + +def build_nested_aggregation(unnested_tbl, marker) -> Any: + """Compile a single nested-array marker into an ibis aggregation.""" + if marker.operation == "count": + return unnested_tbl.count() + expr = _build_field_expr(marker.array_path, marker.field_path, unnested_tbl) + return _apply_aggregation(marker, expr) + + +def build_level_aggregations( + base_tbl, + array_path: tuple[str, ...], + measures: dict[str, tuple[Any, Any]], +) -> dict[str, Any]: + unnested_tbl = unnest_nested_arrays(base_tbl, array_path) + return { + name: build_nested_aggregation(unnested_tbl, marker) + for name, (_agg_fn, marker) in measures.items() + } + + +@curry +def _make_grouped_table(agg_dict: dict[str, Any], by_cols: Iterable[str], table): + group_exprs = [table[c] for c in by_cols] + return ( + table.group_by(group_exprs).aggregate(**agg_dict) + if group_exprs + else table.aggregate(**agg_dict) + ) + + +def build_session_table(base_tbl, by_cols: Iterable[str], regular_measures: dict) -> Any: + """Aggregate regular (non-nested) measures at the session grain.""" + if not regular_measures: + return None + session_aggs = {name: result for name, (_, result) in regular_measures.items()} + return _make_grouped_table(session_aggs, by_cols, base_tbl) + + +def build_nested_level_table( + base_tbl, + by_cols: Iterable[str], + array_path: tuple[str, ...], + measures: dict[str, tuple[Any, Any]], +): + """Aggregate nested-array measures at the unnested grain.""" + level_aggs = build_level_aggregations(base_tbl, array_path, measures) + unnested_tbl = unnest_nested_arrays(base_tbl, array_path) + return _make_grouped_table(level_aggs, by_cols, unnested_tbl) + + +def join_tables(by_cols: Iterable[str], tables: list) -> Any: + """Left-join a list of pre-aggregated tables on shared group-by columns.""" + if len(tables) == 0: + raise ValueError("Cannot join zero tables") + if len(tables) == 1: + return tables[0] + + by_cols_set = set(by_cols) + + def join_step(left, 
right): + predicates = [left[c] == right[c] for c in by_cols] + right_cols = [c for c in right.columns if c not in by_cols_set] + right_select = [right[c] for c in right_cols] + return left.left_join(right, predicates).select([left] + right_select) + + return reduce(join_step, tables[1:], tables[0]) diff --git a/src/boring_semantic_layer/ops.py b/src/boring_semantic_layer/ops.py index 038bd0c..678d095 100644 --- a/src/boring_semantic_layer/ops.py +++ b/src/boring_semantic_layer/ops.py @@ -45,16 +45,30 @@ def _reductions_for_expr(expr): from toolz import curry from . import projection_utils -from .compile_all import compile_grouped_with_all +from .calc_analyzer import analyze_calc_expr, virtual_agg_table +from .calc_compiler import ( + TOTALS_PREFIX, + IbisCalcScope, + TotalsNotAvailableError, + UnknownMeasureRefError, + _drop_totals_columns, + _join_totals, + _to_op, + apply_calc_measures, + attach_calc_totals, + attach_windowed_totals, + classify_calc_lambdas, + compile_calc_measure as _compile_calc_measure_impl, + compile_calc_measures, + evaluate_calc_lambda, + lift_inline_reductions, + rename_measure_refs, + topological_order_from_deps, +) from .graph_utils import walk_nodes from .measure_scope import ( - AggregationExpr, - AllOf, - BinOp, ColumnScope, - MeasureRef, MeasureScope, - MethodCall, ) from .nested_access import NestedAccessMarker @@ -518,19 +532,6 @@ def _get_merged_fields(all_roots: list, field_type: str) -> dict: ) -def _collect_measure_refs(expr, refs_out: set): - if isinstance(expr, MeasureRef): - refs_out.add(expr.name) - elif isinstance(expr, AllOf): - if isinstance(expr.ref, MeasureRef): - refs_out.add(expr.ref.name) - elif isinstance(expr, BinOp): - _collect_measure_refs(expr.left, refs_out) - _collect_measure_refs(expr.right, refs_out) - elif isinstance(expr, MethodCall): - _collect_measure_refs(expr.receiver, refs_out) - - def _extract_missing_column_name(exc: Exception) -> str | None: """Extract a missing column/attribute name from common 
resolution errors.""" message = str(exc) @@ -680,232 +681,125 @@ def _extract_measure_metadata( return (fn_or_expr, None, (), {}) -_AGG_METHODS = frozenset({ - # Standard reductions - "sum", "mean", "avg", "count", "min", "max", - # Statistical reductions - "var", "std", "median", "quantile", - # Approximate reductions - "approx_count_distinct", "approx_nunique", "approx_median", - # Distinct count - "nunique", - # Categorical reductions - "mode", "first", "last", "arbitrary", - # Boolean reductions - "any", "all", - # Collection reductions - "group_concat", "collect", -}) - - -def _is_calculated_measure(val: Any) -> bool: - # A MethodCall with an aggregation method on a MeasureRef is a base measure: - # the column name matched a known measure name in MeasureScope, but the user - # is really defining a column aggregation (e.g. lambda t: t.flight_count.sum() - # or t.distance.var()). ``_AGG_METHODS`` covers the ibis ``Reduction`` ops a - # user might reasonably reach for; methods outside the set fall through to - # the calc path. - if ( - isinstance(val, MethodCall) - and val.method in _AGG_METHODS - and isinstance(val.receiver, MeasureRef) - ): - return False - return isinstance(val, MeasureRef | AllOf | BinOp | MethodCall | int | float) - - -def _matches_aggregation_pattern(measure_expr, agg_expr, tbl): - if not isinstance(agg_expr, AggregationExpr): - return Success(False) - - @curry - def evaluate_in_scope(tbl, expr): - """Evaluate measure expression in a ColumnScope.""" - scope = ColumnScope(_tbl=tbl) - return ( - expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr - ) - - @curry - def has_matching_operation(agg_expr, result): - """Check if the operation matches the expected aggregation. - - All our supported aggregations (Sum, Mean, Count, Min, Max) are ibis operations. 
- """ - op_name = type(result.op()).__name__.lower() - expected_op = "avg" if agg_expr.operation.lower() == "mean" else agg_expr.operation.lower() - - return expected_op in op_name - - @curry - def has_matching_column(agg_expr, result): - """Check if result's operation references the expected column. - - All supported aggregation operations (Sum, Mean, Count, Min, Max) have: - - args[0]: Field operation with .name attribute - - args[1]: Optional where clause (typically None) - """ - op = result.op() - - if not isinstance(op.args[0], Field): - return False - - return op.args[0].name == agg_expr.column - - def matches_pattern(result): - """Check if result matches both operation and column.""" - return has_matching_operation(agg_expr, result) and has_matching_column(agg_expr, result) - - return safe(lambda: evaluate_in_scope(tbl, measure_expr))().map(matches_pattern) - - -def _find_matching_measure(agg_expr, known_measures: dict, tbl): - """Find a measure that matches the aggregation expression pattern. - - Returns Maybe[str] using functional patterns. - """ - if not isinstance(agg_expr, AggregationExpr): - return Nothing - - @curry - def matches_pattern(agg_expr, tbl, measure_obj): - """Check if measure matches the aggregation pattern. - - All measure_obj values are Measure instances with an expr attribute. 
- """ - result = _matches_aggregation_pattern(measure_obj.expr, agg_expr, tbl) - return result.value_or(False) - - for measure_name, measure_obj in known_measures.items(): - if matches_pattern(agg_expr, tbl, measure_obj): - return Some(measure_name) - - return Nothing - - def _make_base_measure( expr: Any, description: str | None, requires_unnest: tuple, metadata: Mapping[str, Any] | None = None, ) -> Measure: - """Create a base measure with proper callable wrapping using functional patterns.""" - - @curry - def apply_aggregation(operation: str, column): - """Apply aggregation operation to a column using functional dispatch.""" - operations = { - "sum": lambda c: c.sum(), - "mean": lambda c: c.mean(), - "avg": lambda c: c.mean(), - "count": lambda c: c.count(), - "min": lambda c: c.min(), - "max": lambda c: c.max(), - } - - return ( - Maybe.from_optional(operations.get(operation)) - .map(lambda fn: fn(column)) - .value_or( - (_ for _ in ()).throw(ValueError(f"Unknown aggregation operation: {operation}")) - ) - ) - - @curry - def evaluate_expr(expr, scope): - """Evaluate expression in given scope.""" - return ( - expr.resolve(scope) if _is_deferred(expr) else expr(scope) if callable(expr) else expr - ) - - def convert_aggregation_expr(t, agg_expr: AggregationExpr): - """Convert AggregationExpr to ibis expression.""" - if agg_expr.operation == "count": - result = t.count() - else: - result = apply_aggregation(agg_expr.operation, t[agg_expr.column]) - - for method_name, args, kwargs_tuple in agg_expr.post_ops: - result = getattr(result, method_name)(*args, **dict(kwargs_tuple)) - - return result + """Wrap a base-measure callable as a :class:`Measure`. + The lambda is invoked against a :class:`ColumnScope` so that nested + array columns (``t.hits.count()`` over an array) surface as + ``NestedAccessMarker`` values for the nested-aggregation pipeline. + Plain reductions (``t.distance.sum()``) flow through unchanged. 
+ """ raw_expr = expr._fn if isinstance(expr, _CallableWrapper) else expr - if isinstance(expr, AggregationExpr): - - def wrapped_expr(t): - """Convert AggregationExpr to ibis expression.""" - return convert_aggregation_expr(t, expr) - - return Measure( - expr=wrapped_expr, - description=description, - requires_unnest=requires_unnest, - original_expr=raw_expr, - metadata=dict(metadata or {}), - ) - - if callable(expr): - - def wrapped_expr(t): - """Wrapped expression that handles AggregationExpr conversion.""" - scope = ColumnScope(_tbl=t) - result = evaluate_expr(expr, scope) - - if isinstance(result, AggregationExpr): - return convert_aggregation_expr(t, result) - return result - - return Measure( - expr=wrapped_expr, - description=description, - requires_unnest=requires_unnest, - original_expr=raw_expr, - metadata=dict(metadata or {}), - ) + if _is_deferred(expr): + wrapped = lambda t, fn=expr: fn.resolve(ColumnScope(_tbl=t)) + elif callable(expr): + wrapped = lambda t, fn=expr: fn(ColumnScope(_tbl=t)) else: - return Measure( - expr=lambda t, fn=expr: evaluate_expr(fn, ColumnScope(_tbl=t)), - description=description, - requires_unnest=requires_unnest, - original_expr=raw_expr, - metadata=dict(metadata or {}), - ) + wrapped = lambda t, v=expr: v + + return Measure( + expr=wrapped, + description=description, + requires_unnest=requires_unnest, + original_expr=raw_expr, + metadata=dict(metadata or {}), + ) def _classify_measure( fn_or_expr: Any, scope: Any, measure_name: str | None = None ) -> tuple[str, Any]: - """Classify measure as 'calc' or 'base' with appropriate handling.""" - from .measure_scope import UnknownMeasureRefError, validate_calc_ast + """Classify a measure lambda as ``base`` or ``calc``. + Runs the lambda once against an :class:`IbisCalcScope`, then walks + the resulting ibis tree with :func:`analyze_calc_expr`. Pushable + expressions become base measures (the same lambda runs at agg time + against the raw ibis table). 
Post-aggregation expressions become + :class:`CalcMeasure` records that re-evaluate at query time. + + The legacy ``MeasureScope`` is accepted as the scope argument for + backwards compatibility with call sites — only its ``tbl`` and + ``known`` fields are read. + """ expr, description, requires_unnest, metadata = _extract_measure_metadata(fn_or_expr) - # ``_resolve_expr`` may raise for legitimate base-measure shapes - # (e.g. lambdas that touch ibis methods MeasureScope can't reflect), - # so most exceptions are caught and the lambda falls through to - # base classification. ``UnknownMeasureRefError`` is the typo case - # though — surface it loudly instead of letting the lambda fail with - # an opaque error at execute time. + base_tbl = getattr(scope, "tbl", None) + if base_tbl is None: + base_tbl = getattr(scope, "_tbl", None) + known = getattr(scope, "known", None) + if known is None: + known = getattr(scope, "_known", ()) + known_set = frozenset(known) + + # Pure constants fold into both grouped and ungrouped contexts. + if isinstance(expr, (int, float)) and not isinstance(expr, bool): + return ("base", _make_base_measure(expr, description, requires_unnest, metadata)) + + if base_tbl is None: + return ("base", _make_base_measure(expr, description, requires_unnest, metadata)) + + # Build virtual aggregated table schema from already-known measures. + # The dtypes are placeholders — the analyzer cares about structure. + virtual_schema = {name: "float64" for name in known_set} + try: - resolved_value = _resolve_expr(expr, scope) + ibis_expr, vt, totals_vt = evaluate_calc_lambda( + expr, base_tbl, known_set, virtual_schema + ) except UnknownMeasureRefError: raise except Exception: - resolved_value = None - - if resolved_value is not None and _is_calculated_measure(resolved_value): - validate_calc_ast(resolved_value, measure_name) - return ("calc", resolved_value) + # Could not evaluate against the analyzer scope (e.g. 
lambda + # uses backend-specific methods the analyzer scope can't mirror). + # Fall back to base classification — the lambda runs verbatim + # against the raw ibis table at agg time. + if not requires_unnest and callable(expr): + inferred_unnest = _infer_unnest(expr, base_tbl) + requires_unnest = requires_unnest or inferred_unnest + return ("base", _make_base_measure(expr, description, requires_unnest, metadata)) + + base_op = base_tbl.op() if hasattr(base_tbl, "op") and callable(base_tbl.op) else None + totals_op = totals_vt.op() if hasattr(totals_vt, "op") and callable(totals_vt.op) else None + analysis = analyze_calc_expr( + ibis_expr, + known_measures=known_set, + base_table_op=base_op, + totals_vt_op=totals_op, + ) - if not requires_unnest and callable(expr): - # All scopes (MeasureScope, ColumnScope) have tbl attribute - table = scope.tbl - inferred_unnest = _infer_unnest(expr, table) - requires_unnest = requires_unnest or inferred_unnest + if analysis.pushable or analysis.post_agg_only is False: + # ``post_agg_only=False`` without ``pushable`` means no window / + # AllOf / measure deps but the expression touched multiple source + # tables. Routing to base lets the lambda run verbatim at agg + # time; if it really does span tables, ibis will surface the + # error there. Log so the silent fallthrough is visible.
+ if not analysis.pushable: + logger.debug( + "calc-measure %r references multiple source tables but no measures; " + "routing to base classification — ibis will validate at agg time.", + measure_name, + ) + if not requires_unnest and callable(expr): + inferred_unnest = _infer_unnest(expr, base_tbl) + requires_unnest = requires_unnest or inferred_unnest + return ("base", _make_base_measure(expr, description, requires_unnest, metadata)) - return ("base", _make_base_measure(expr, description, requires_unnest, metadata)) + return ( + "calc", + CalcMeasure( + expr=expr, + description=description, + requires_unnest=requires_unnest, + depends_on=analysis.depends_on, + metadata=metadata, + ), + ) def _build_json_definition( @@ -1110,6 +1004,35 @@ def __hash__(self) -> int: return hash((self.description, self.requires_unnest)) +@frozen(kw_only=True, slots=True) +class CalcMeasure: + """Stored representation of a calc (post-aggregation) measure. + + Holds the user's original lambda — the analyzer-classified ibis + expression is recomputed from the lambda at query time against the + actual base table. ``depends_on`` is captured at classification time + so the planner can auto-include base-measure dependencies in + aggregations even when the user did not request them explicitly. + """ + + expr: Any # callable | Deferred + description: str | None = None + requires_unnest: tuple[str, ...] = () + depends_on: frozenset[str] = field(factory=frozenset, converter=frozenset) + metadata: Mapping[str, Any] = field(factory=dict, eq=False, hash=False) + + def to_json(self) -> Mapping[str, Any]: + base = {"description": self.description} + if self.requires_unnest: + base["requires_unnest"] = list(self.requires_unnest) + if self.metadata: + base.update(self.metadata) + return base + + def __hash__(self) -> int: + return hash((self.description, self.requires_unnest, self.depends_on)) + + class SemanticTableOp(Relation): """Relation with semantic metadata (dimensions and measures). 
@@ -1170,27 +1093,24 @@ def values(self) -> FrozenOrderedDict[str, Any]: **{name: enriched[name].op() for name in dims}, **{name: fn(enriched).op() for name, fn in measures.items()}, } - # Resolve calculated measure types via a dummy table with base measure dtypes. - # ``infer_calc_dtype`` mirrors the AggregationExpr rewrite from - # ``compile_grouped_with_all`` so calc measures with inline aggregations - # (e.g. ``AllOf(AggregationExpr)``) round-trip through type inference. + # Calc measures are stored as ``CalcMeasure`` objects holding the + # original lambda. Re-run each one against an ``IbisCalcScope`` + # over ``enriched`` plus a virtual aggregated table whose schema + # mirrors the base measures. Type inference falls out of ibis + # naturally; failures are best-effort. if calc_measures: - from .compile_all import _get_ibis_module, infer_calc_dtype - measure_schema = { name: base_values[name].dtype for name in measures if name in base_values } - ibis_module = _get_ibis_module(enriched) - for name, expr in calc_measures.items(): + known_set = frozenset(measures.keys()) | frozenset(calc_measures.keys()) + for name, calc in calc_measures.items(): + fn = calc.expr if isinstance(calc, CalcMeasure) else calc try: - compiled = infer_calc_dtype( - expr, measure_schema, enriched, ibis_module + expr, _vt, _tvt = evaluate_calc_lambda( + fn, enriched, known_set, measure_schema ) - base_values[name] = compiled.op() + base_values[name] = expr.op() except Exception as e: - # Joined models with dotted column names, calc measures - # whose inline aggregations don't apply to the dummy schema, - # etc. Type info is best-effort; surface for debugging. logger.debug( "calc-measure type inference failed for %r: %s", name, e ) @@ -1512,245 +1432,521 @@ class _AggregationPlan: group_by_cols: tuple[str, ...] 
-def _resolve_aggregation_exprs( - expr: Any, - merged_base_measures: dict, - merged_calc_measures: dict, - tbl: ir.Table, -) -> Any: - @curry - def find_in_calc_measures(expr, calc_measures): - for calc_name, calc_expr in calc_measures.items(): - if isinstance(calc_expr, AggregationExpr) and ( - calc_expr.column == expr.column and calc_expr.operation == expr.operation - ): - return Some(calc_name) - return Nothing - - def resolve_aggregation(agg_expr): - matched = _find_matching_measure(agg_expr, merged_base_measures, tbl) - return matched.map(MeasureRef).value_or( - find_in_calc_measures(agg_expr, merged_calc_measures).map(MeasureRef).value_or(agg_expr) - ) - - if isinstance(expr, AggregationExpr): - return resolve_aggregation(expr) - elif isinstance(expr, MethodCall): - return MethodCall( - receiver=_resolve_aggregation_exprs( - expr.receiver, merged_base_measures, merged_calc_measures, tbl - ), - method=expr.method, - args=expr.args, - kwargs=expr.kwargs, - ) - elif isinstance(expr, BinOp): - return BinOp( - op=expr.op, - left=_resolve_aggregation_exprs( - expr.left, merged_base_measures, merged_calc_measures, tbl - ), - right=_resolve_aggregation_exprs( - expr.right, merged_base_measures, merged_calc_measures, tbl - ), - ) - elif isinstance(expr, AllOf) and isinstance(expr.ref, AggregationExpr): - return AllOf(resolve_aggregation(expr.ref)) - else: - return expr +def _make_agg_callable(measure: Any) -> Callable: + """Wrap a base-measure value into a callable that returns an ibis aggregation. + + ``Measure.expr`` is already wrapped with ``ColumnScope`` inside + :func:`_make_base_measure`, so ``Measure`` instances and raw callables + (e.g. lifted-reduction stubs that close over a pre-built ibis op) are + invoked with the raw ibis table directly. Only ``Deferred`` values + are resolved through ``ColumnScope`` here, since they have no other + way to bind to the table. 
+ """ + if _is_deferred(measure): + return lambda t: measure.resolve(ColumnScope(_tbl=t)) + if isinstance(measure, Measure): + return lambda t: measure(t) + if callable(measure): + return lambda t: measure(t) + return lambda t, v=measure: v -def _create_measure_spec( - name: str, - fn_wrapped: Any, +def _build_aggregation_plan( + aggs: dict, + keys: tuple, scope: Any, is_post_agg: bool, merged_base_measures: dict, merged_calc_measures: dict, tbl: ir.Table, -) -> _MeasureSpec: - fn = _unwrap(fn_wrapped) - val = _resolve_expr(fn, scope) - val = _resolve_aggregation_exprs(val, merged_base_measures, merged_calc_measures, tbl) - - if is_post_agg: - return _MeasureSpec(name=name, kind="agg", value=fn) - - if isinstance(val, MeasureRef): - ref_name = val.name - if ref_name in merged_calc_measures: - calc_expr = merged_calc_measures[ref_name] - resolved = _resolve_aggregation_exprs( - calc_expr, merged_base_measures, merged_calc_measures, tbl - ) - return _MeasureSpec(name=name, kind="calc", value=resolved) - elif ref_name in merged_base_measures: - return _MeasureSpec(name=name, kind="agg", value=merged_base_measures[ref_name]) - else: - return _MeasureSpec(name=name, kind="calc", value=val) +) -> _AggregationPlan: + """Split requested aggregations into base aggs and calc-measure lambdas. - if isinstance(val, AllOf | BinOp | MethodCall | int | float): - return _MeasureSpec(name=name, kind="calc", value=val) + Each entry in ``aggs`` is a callable. We resolve it once against the + measure scope to determine whether it refers to a base measure (yields + a ``Measure``-like callable that produces an ibis aggregation) or a + calc measure (a ``CalcMeasure`` recorded in ``merged_calc_measures`` + or an inline post-aggregation expression). + + Inline ad-hoc lambdas that look like calc expressions (use + ``t.measure_name`` or ``t.all(...)``) are classified on the fly via + :func:`_classify_measure` and routed to ``calc_specs``. 
+ """ + agg_specs: dict[str, Callable] = {} + calc_specs: dict[str, CalcMeasure] = {} + + base_tbl = getattr(scope, "tbl", None) + if base_tbl is None: + base_tbl = getattr(scope, "_tbl", None) + if base_tbl is None: + base_tbl = tbl + known_set = frozenset(merged_base_measures) | frozenset(merged_calc_measures) + + for name, fn_wrapped in aggs.items(): + fn = _unwrap(fn_wrapped) + + if is_post_agg: + # Wrap raw user callables with ColumnScope (via Measure) so a + # re-aggregation lambda like ``t.flights.carrier.nunique()`` + # routes through the NestedAccessMarker pipeline in + # _compile_aggregation. Without the wrap, t.flights returns a + # raw ArrayColumn and struct-field access blows up before the + # marker can be produced. + if ( + callable(fn) + and not _is_deferred(fn) + and not isinstance(fn, Measure) + ): + fn = _make_base_measure(fn, None, (), {}) + agg_specs[name] = _make_agg_callable(fn) + continue - return _MeasureSpec(name=name, kind="agg", value=fn) + # Recognize bare-name lambdas (``lambda t, n=name: t[n]``) that + # the SemanticAggregate.aggregate API generates for measure + # lookups by name. These should resolve to the named measure, + # suffix-matching prefixed names on joined models. + ref_name = _detect_bare_name_lambda(fn) + if ref_name is not None: + resolved = _resolve_short_name(ref_name, merged_base_measures, merged_calc_measures) + if resolved is not None: + if resolved in merged_base_measures: + agg_specs[name] = _make_agg_callable(merged_base_measures[resolved]) + continue + if resolved in merged_calc_measures: + calc_specs[name] = merged_calc_measures[resolved] + continue + # Otherwise classify the inline lambda on the fly. + kind, value = _classify_measure(fn, scope, name) + if kind == "calc": + calc_specs[name] = value + else: + agg_specs[name] = _make_agg_callable(value) + + # Auto-include base-measure dependencies referenced by calc measures + # so the aggregation produces the columns the calc lambdas read. 
+ # Walk transitively so calc-of-calc chains pull all needed bases. + if calc_specs: + def _resolve_dep(ref: str) -> str | None: + """Resolve a dependency name against base/calc measures. + + On joined models, calc measures captured ``depends_on`` with + short names (``flight_count``); the merged dictionaries hold + prefixed names (``flights.flight_count``). Suffix-match when + the exact name is missing. + """ + if ref in merged_base_measures or ref in merged_calc_measures: + return ref + suffix = f".{ref}" + base_matches = [k for k in merged_base_measures if k.endswith(suffix)] + calc_matches = [k for k in merged_calc_measures if k.endswith(suffix)] + matches = base_matches + calc_matches + if len(matches) == 1: + return matches[0] + return None -def _make_agg_callable(measure: Any) -> Callable: - if _is_deferred(measure): - return lambda t: measure.resolve(ColumnScope(_tbl=t)) - elif callable(measure): - return lambda t: measure(ColumnScope(_tbl=t)) - else: - return lambda t: measure(t) + worklist = list(calc_specs.values()) + seen_calcs: set[str] = set(calc_specs.keys()) + while worklist: + cm = worklist.pop() + for ref in cm.depends_on: + resolved_ref = _resolve_dep(ref) + if resolved_ref is None or resolved_ref in agg_specs: + continue + if resolved_ref in merged_base_measures: + agg_specs[resolved_ref] = _make_agg_callable( + merged_base_measures[resolved_ref] + ) + elif resolved_ref in merged_calc_measures and resolved_ref not in seen_calcs: + dep_cm = merged_calc_measures[resolved_ref] + calc_specs[resolved_ref] = dep_cm + seen_calcs.add(resolved_ref) + if isinstance(dep_cm, CalcMeasure): + worklist.append(dep_cm) + return _AggregationPlan( + agg_specs=FrozenDict(agg_specs), + calc_specs=FrozenDict(calc_specs), + requested_measures=tuple(aggs.keys()), + group_by_cols=tuple(keys), + ) -def _collect_all_measure_refs(calc_exprs) -> frozenset[str]: - all_refs = set() - for expr in calc_exprs: - _collect_measure_refs(expr, all_refs) - return frozenset(all_refs) 
+def _compile_aggregation( + base_tbl, + by_cols: list[str], + agg_specs: dict[str, Callable], + calc_specs: dict[str, CalcMeasure], + known_measures: frozenset[str], + requested_measures: list[str] | None = None, +): + """Run base aggregations on ``base_tbl``, then apply calc measures. + + Replaces the legacy ``compile_grouped_with_all`` pipeline. Calc + measures are recomputed at query time by re-running their lambda + against an :class:`IbisCalcScope` over ``base_tbl`` plus a virtual + aggregated table that mirrors the real result schema. Nested-array + aggregations surface as :class:`NestedAccessMarker` values and are + routed through :func:`_compile_aggregation_with_nested`. + """ + # --- Pre-process calc specs --------------------------------------- + # Run the analyzer once per calc, then route inline reductions + # through the lift pass. ``lifted_calc_specs[name]`` carries the + # rewritten expression and the virtual tables it references; + # ``classifications[name]`` carries the structural analysis. + # ``None`` lift means the lambda blew up — we'll re-evaluate from + # scratch in the apply loop. + # + # Build the virtual schema with *real* dtypes derived from the base + # aggregations. Using a placeholder dtype (``float64`` for + # everything) lets ibis silently elide ``column.cast(float64)`` as a + # no-op during ``evaluate_calc_lambda``; after the substitution + # ``Field(virtual_agg) → Field(real_agg)`` the Cast is gone but the + # real column is int64, so ``int / int * 100`` returns 0. Probing + # ``agg_specs[n](base_tbl).type()`` gives the analyzer the same + # dtype the user's calc will see at compile time. 
+ base_op = _to_op(base_tbl) + virtual_schema_real: dict[str, Any] = {} + for n in known_measures: + if n in agg_specs: + try: + virtual_schema_real[n] = agg_specs[n](base_tbl).type() + except Exception as exc: + logger.debug( + "could not probe dtype for measure %r; falling back to float64: %s", + n, + exc, + ) + virtual_schema_real[n] = "float64" + else: + virtual_schema_real[n] = "float64" + + lifted_calc_specs: dict[str, tuple[Any, Any, Any] | None] = {} + classifications: dict[str, Any] = {} + preproc_errors: dict[str, Exception] = {} + needs_totals = False + if calc_specs: + for name, cm in calc_specs.items(): + try: + virtual_schema = dict(virtual_schema_real) + expr, vt, totals_vt = evaluate_calc_lambda( + cm.expr, base_tbl, known_measures, virtual_schema + ) + new_expr, new_vt, new_totals_vt, lifted = lift_inline_reductions( + expr, vt, base_tbl, totals_virtual_agg_tbl=totals_vt + ) + analysis = analyze_calc_expr( + new_expr, + known_measures=known_measures, + base_table_op=base_op, + totals_vt_op=_to_op(new_totals_vt), + ) + lifted_calc_specs[name] = (new_expr, new_vt, new_totals_vt) + classifications[name] = analysis + if analysis.references_AllOf: + needs_totals = True + for anon_name, reduction_expr in lifted.items(): + if anon_name not in agg_specs: + agg_specs[anon_name] = lambda t, r=reduction_expr: r + except Exception as exc: + logger.debug( + "calc-measure lift/classify failed for %r; will re-evaluate " + "at apply time: %s", + name, + exc, + ) + lifted_calc_specs[name] = None + preproc_errors[name] = exc -def _expand_calc_measure_refs( - expr: Any, - merged_base_measures: dict, - merged_calc_measures: dict, - tbl: ir.Table, - cache: dict[str, Any] | None = None, - path: tuple[str, ...] 
= (), -) -> Any: - """Inline calc-measure references transitively for multi-layer formulas.""" - cache = {} if cache is None else cache - - def _lift_to_allof(value: Any) -> Any: - """Lift an expanded expression into totals-space via AllOf on refs.""" - if isinstance(value, MeasureRef): - return AllOf(value) - if isinstance(value, BinOp): - return BinOp( - op=value.op, - left=_lift_to_allof(value.left), - right=_lift_to_allof(value.right), - ) - if isinstance(value, MethodCall): - return MethodCall( - receiver=_lift_to_allof(value.receiver), - method=value.method, - args=value.args, - kwargs=value.kwargs, + nested_marker_specs: dict[str, Any] = {} + regular_specs: dict[str, Callable] = {} + for name, fn in agg_specs.items(): + try: + probe = fn(base_tbl) + except Exception: + regular_specs[name] = fn + continue + if isinstance(probe, NestedAccessMarker): + nested_marker_specs[name] = probe + else: + regular_specs[name] = fn + + # --- Attach windowed totals to base ------------------------------ + # When any calc references ``t.all(measure_ref)``, compute that + # measure's formula as a window function over the entire base + # *before* group_by, then carry it through the per-group aggregation + # via ``arbitrary()``. This expresses "ungrouped aggregate alongside + # grouped one" as a single-pass query — no cross-join, no + # shared-ancestor collapse, compiles to SQL on every backend + # supporting window functions. Skipped on the nested-array path: + # totals across multiple grains aren't well-defined; we surface a + # clear error in the apply loop. + # + # Calc-of-calc-AllOf (an AllOf-using calc that references a calc, + # not a base measure — e.g. ``t.all(t.avg_distance)`` where + # ``avg_distance`` is itself a calc) is handled in two passes: + # first transitively expand to the base measures; attach window + # totals for those; then post-aggregation derive the calc's totals + # value via :func:`attach_calc_totals`. 
+ totals_arbitrary_specs: dict[str, Callable] = {} + if needs_totals and regular_specs and not nested_marker_specs: + totals_for_base: set[str] = set() + # Transitive expansion: for AllOf-using calcs, follow calc deps + # through ``classifications`` until we land on base measures. + work: list[str] = [] + for cn, c in classifications.items(): + if c.references_AllOf: + for d in c.depends_on: + if d in regular_specs: + totals_for_base.add(d) + elif d in calc_specs: + work.append(d) + seen: set[str] = set() + while work: + calc_dep = work.pop() + if calc_dep in seen: + continue + seen.add(calc_dep) + cls = classifications.get(calc_dep) + if cls is None: + continue + for d in cls.depends_on: + if d in regular_specs: + totals_for_base.add(d) + elif d in calc_specs: + work.append(d) + + if totals_for_base: + base_tbl, totals_arbitrary_specs = attach_windowed_totals( + base_tbl, regular_specs, totals_for_base, TOTALS_PREFIX ) - return value - - if isinstance(expr, MeasureRef): - ref_name = expr.name - if ref_name not in merged_calc_measures: - return expr - if ref_name in cache: - return cache[ref_name] - if ref_name in path: - cycle = " -> ".join((*path, ref_name)) - raise ValueError(f"Circular calculated measure dependency detected: {cycle}") - - resolved = _resolve_aggregation_exprs( - merged_calc_measures[ref_name], merged_base_measures, merged_calc_measures, tbl + + if not nested_marker_specs: + if by_cols or regular_specs or totals_arbitrary_specs: + agg_exprs = {n: f(base_tbl) for n, f in regular_specs.items()} + for tn, tf in totals_arbitrary_specs.items(): + agg_exprs[tn] = tf(base_tbl) + if by_cols: + real_agg_tbl = base_tbl.group_by([base_tbl[c] for c in by_cols]).aggregate( + **agg_exprs + ) + else: + real_agg_tbl = base_tbl.aggregate(**agg_exprs) + else: + real_agg_tbl = base_tbl.aggregate() + else: + real_agg_tbl = _compile_aggregation_with_nested( + base_tbl, by_cols, regular_specs, nested_marker_specs ) - expanded = _expand_calc_measure_refs( - resolved, 
- merged_base_measures, - merged_calc_measures, - tbl, - cache, - (*path, ref_name), + + # --- Derive calc-of-calc totals ---------------------------------- + # If any AllOf-using calc references another calc (transitively), + # the windowed-totals pass attached only the base totals. Now that + # ``real_agg_tbl`` has those base totals as columns, we evaluate + # each needed calc lambda against the totals columns to derive the + # calc's totals value (constant across rows). + if calc_specs and totals_arbitrary_specs: + real_agg_tbl = attach_calc_totals( + real_agg_tbl, calc_specs, classifications, TOTALS_PREFIX ) - cache[ref_name] = expanded - return expanded - if isinstance(expr, MethodCall): - return MethodCall( - receiver=_expand_calc_measure_refs( - expr.receiver, merged_base_measures, merged_calc_measures, tbl, cache, path - ), - method=expr.method, - args=expr.args, - kwargs=expr.kwargs, + # --- Apply calc measures ----------------------------------------- + if calc_specs: + # ``real_agg_tbl`` already carries ``__bsl_totals__`` + # columns when totals were attached above. Calc compilation + # rewrites ``Field(totals_vt, name) → Field(real_agg, "__bsl_totals__")`` + # directly; no separate cross-joined table is needed. + real_with_totals = real_agg_tbl if totals_arbitrary_specs else None + cur_known = known_measures | frozenset(calc_specs.keys()) + + ordered = _topological_calc_order(calc_specs, base_tbl, known_measures) + for name in ordered: + spec = lifted_calc_specs.get(name) + if spec is None: + # Lift failed at preprocessing; re-evaluate AND re-lift so + # inline base reductions (``t.distance.sum() / t.all(...)``) + # don't reach _compile_calc_measure_impl as bare base + # reductions ibis can't compile through mutate. 
+ fn = calc_specs[name].expr + virtual_schema = { + col: real_agg_tbl[col].type() + for col in real_agg_tbl.columns + if col in cur_known + } + expr0, vt0, totals_vt0 = evaluate_calc_lambda( + fn, base_tbl, cur_known, virtual_schema + ) + rewritten_expr, rewritten_vt, rewritten_totals_vt, lifted = ( + lift_inline_reductions( + expr0, vt0, base_tbl, totals_virtual_agg_tbl=totals_vt0 + ) + ) + if lifted: + # The lift produced anonymous base reductions that + # would need to be added to the per-group aggregation, + # but that has already been built. Surface the original + # preprocessing failure rather than letting unbound + # Field references reach ibis. + orig = preproc_errors.get(name) + raise RuntimeError( + f"Calc measure {name!r} contains inline base reductions " + "that could not be lifted at preprocessing time." + ) from orig + analysis = analyze_calc_expr( + rewritten_expr, + known_measures=known_measures, + base_table_op=base_op, + totals_vt_op=_to_op(rewritten_totals_vt), + ) + else: + rewritten_expr, rewritten_vt, rewritten_totals_vt = spec + analysis = classifications[name] + + if analysis.references_AllOf: + if real_with_totals is None: + raise TotalsNotAvailableError( + f"Calc measure {name!r} references t.all(...) but no totals " + "columns were attached. This typically means the model contains " + "nested-array measures (which compile at multiple grains and " + "don't yet support totals), or the AllOf reference targets a " + "calc measure rather than a base measure (calc-of-calc-totals " + "is not yet supported via the windowed-totals path)." 
+ ) + compiled = _compile_calc_measure_impl( + rewritten_expr, + rewritten_vt, + real_agg_tbl, + totals_virtual_agg_tbl=rewritten_totals_vt, + real_with_totals=real_agg_tbl, + ) + real_agg_tbl = real_agg_tbl.mutate(**{name: compiled}) + real_with_totals = real_agg_tbl + else: + compiled = _compile_calc_measure_impl( + rewritten_expr, rewritten_vt, real_agg_tbl + ) + real_agg_tbl = real_agg_tbl.mutate(**{name: compiled}) + if real_with_totals is not None: + real_with_totals = real_agg_tbl + + # Drop the synthetic ``__bsl_totals__`` columns so the + # result schema only carries user-requested measures. + if calc_specs: + real_agg_tbl = _drop_totals_columns(real_agg_tbl, TOTALS_PREFIX) + + if requested_measures is not None: + select_cols = list( + dict.fromkeys( + list(by_cols) + list(requested_measures) + list(calc_specs.keys()) + ) ) + available = frozenset(real_agg_tbl.columns) + select_cols = [c for c in select_cols if c in available] + if select_cols: + real_agg_tbl = real_agg_tbl.select([real_agg_tbl[c] for c in select_cols]) - if isinstance(expr, BinOp): - return BinOp( - op=expr.op, - left=_expand_calc_measure_refs( - expr.left, merged_base_measures, merged_calc_measures, tbl, cache, path - ), - right=_expand_calc_measure_refs( - expr.right, merged_base_measures, merged_calc_measures, tbl, cache, path - ), + return real_agg_tbl + + +def _compile_aggregation_with_nested( + base_tbl, + by_cols: list[str], + regular_specs: dict[str, Callable], + nested_specs: dict[str, Any], +): + """Compile aggregations when nested-array measures are present. + + Each array path is unnested in isolation, aggregated at its own + grain, and joined back to the session-level table on ``by_cols``. + The new calc-compiler path layers on top of the resulting joined + table via :func:`apply_calc_measures`. 
+ """ + from .nested_compile import ( + build_nested_level_table, + build_session_table, + join_tables, + ) + + nested_by_path: dict[tuple[str, ...], dict[str, tuple]] = {} + for name, marker in nested_specs.items(): + nested_by_path.setdefault(marker.array_path, {})[name] = ( + regular_specs.get(name) or (lambda t, m=marker: m), + marker, ) - if isinstance(expr, AllOf): - if isinstance(expr.ref, MeasureRef): - expanded_ref = _expand_calc_measure_refs( - expr.ref, merged_base_measures, merged_calc_measures, tbl, cache, path - ) - if isinstance(expanded_ref, MeasureRef): - return AllOf(expanded_ref) - return _lift_to_allof(expanded_ref) - return expr + result_tables: list = [] + if regular_specs: + regular_results = {n: (f, f(base_tbl)) for n, f in regular_specs.items()} + session_table = build_session_table(base_tbl, by_cols, regular_results) + if session_table is not None: + result_tables.append(session_table) - return expr + for array_path, measures in nested_by_path.items(): + level_table = build_nested_level_table(base_tbl, by_cols, array_path, measures) + result_tables.append(level_table) + if not result_tables: + if by_cols: + return base_tbl.group_by([base_tbl[c] for c in by_cols]).aggregate() + return base_tbl.aggregate() -def _build_aggregation_plan( - aggs: dict, - keys: tuple, - scope: Any, - is_post_agg: bool, + return join_tables(by_cols, result_tables) + + +def _resolve_short_name( + name: str, merged_base_measures: dict, merged_calc_measures: dict, - tbl: ir.Table, -) -> _AggregationPlan: - specs = [ - _create_measure_spec( - name, fn, scope, is_post_agg, merged_base_measures, merged_calc_measures, tbl - ) - for name, fn in aggs.items() - ] - - agg_specs_list = [s for s in specs if s.kind == "agg"] - calc_specs_list = [s for s in specs if s.kind == "calc"] - - agg_specs = FrozenDict({s.name: _make_agg_callable(s.value) for s in agg_specs_list}) - calc_specs = FrozenDict({s.name: s.value for s in calc_specs_list}) - - calc_cache: dict[str, Any] = {} - 
expanded_calc_specs = FrozenDict( - { - name: _expand_calc_measure_refs( - expr, - merged_base_measures, - merged_calc_measures, - tbl, - cache=calc_cache, - path=(name,), - ) - for name, expr in calc_specs.items() - } - ) +) -> str | None: + """Match ``name`` against merged measure dicts, allowing suffix lookup.""" + if name in merged_base_measures or name in merged_calc_measures: + return name + suffix = f".{name}" + matches = [k for k in merged_base_measures if k.endswith(suffix)] + matches += [k for k in merged_calc_measures if k.endswith(suffix)] + if len(matches) == 1: + return matches[0] + return None - referenced = _collect_all_measure_refs(expanded_calc_specs.values()) - additional_aggs = { - ref: _make_agg_callable(merged_base_measures[ref]) - for ref in referenced - if ref not in agg_specs and ref in merged_base_measures - } - final_agg_specs = FrozenDict({**agg_specs, **additional_aggs}) +def _topological_calc_order( + calc_specs: dict[str, CalcMeasure], + base_tbl, + known_measures: frozenset[str], +) -> list[str]: + """Order calc measures by ``CalcMeasure.depends_on`` so deps compile first.""" + deps = {name: set(cm.depends_on) for name, cm in calc_specs.items()} + return topological_order_from_deps(calc_specs, deps) - return _AggregationPlan( - agg_specs=final_agg_specs, - calc_specs=expanded_calc_specs, - requested_measures=tuple(aggs.keys()), - group_by_cols=tuple(keys), - ) + +def _detect_bare_name_lambda(fn: Any) -> str | None: + """Return the captured name when ``fn`` was generated by ``make_bare_ref_lambda``. + + Read the ``_bsl_bare_ref`` sentinel attribute set at the API site — + sniffing ``__defaults__`` was unreliable because user lambdas with + arbitrary string defaults (e.g. ``lambda t, c=col, op=op: getattr(...) + ``) collide with the trivial ``lambda t, n=name: t[n]`` shape and + silently misroute as bare references. 
+ """ + if not callable(fn): + return None + name = getattr(fn, "_bsl_bare_ref", None) + if isinstance(name, str): + return name + return None + + +def make_bare_ref_lambda(name: str): + """Build a ``lambda t: t[name]`` tagged for fast-path measure lookup. + + Use this anywhere the BSL surface needs to construct a measure-name + passthrough callable: it sets ``_bsl_bare_ref`` so + :func:`_detect_bare_name_lambda` can route the call straight to the + referenced base or calc measure without re-running the analyzer. + """ + fn = lambda t, _n=name: t[_n] # noqa: E731 + fn._bsl_bare_ref = name + return fn # --------------------------------------------------------------------------- @@ -2403,16 +2599,14 @@ def collect_mutates_to_join(node): tbl=tbl, ) - if plan.calc_specs or plan.group_by_cols: - return compile_grouped_with_all( - tbl, - list(plan.group_by_cols), - dict(plan.agg_specs), - dict(plan.calc_specs), - requested_measures=list(plan.requested_measures), - ) - else: - return tbl.aggregate({name: fn(tbl) for name, fn in plan.agg_specs.items()}) + return _compile_aggregation( + tbl, + list(plan.group_by_cols), + dict(plan.agg_specs), + dict(plan.calc_specs), + known_measures=frozenset(merged_base_measures) | frozenset(merged_calc_measures), + requested_measures=list(plan.requested_measures), + ) def _to_untagged_with_preagg( self, @@ -2937,9 +3131,14 @@ def strip_deferred(node): # Handle calculated measures if plan.calc_specs: - from .compile_all import compile_calc_measures - - result = compile_calc_measures(result, plan.calc_specs) + calc_lambdas = { + name: cm.expr if isinstance(cm, CalcMeasure) else cm + for name, cm in plan.calc_specs.items() + } + known = frozenset(merged_base_measures) | frozenset(merged_calc_measures) + result = apply_calc_measures( + result, core_tbl, calc_lambdas, known, agg_specs=dict(plan.agg_specs) + ) # --- 3. 
LEFT JOIN deferred dimension tables --- for d in deferrable: @@ -3014,7 +3213,7 @@ def _join_preagg_with_dim_bridge( ``decomposed_means`` and ``reagg_ops`` are tuples of (key, value) pairs. """ - from .compile_all import _join_tables + from .nested_compile import join_tables as _join_tables reagg_map = dict(reagg_ops) # Include decomposed auxiliary columns in measure names @@ -3098,7 +3297,7 @@ def _build_minimal_dim_bridge( ``decomposed_means`` and ``reagg_ops`` are tuples of (key, value) pairs. """ - from .compile_all import _join_tables + from .nested_compile import join_tables as _join_tables reagg_map = dict(reagg_ops) aux_cols = frozenset(c for _, (sc, cc) in decomposed_means for c in (sc, cc)) @@ -3161,25 +3360,31 @@ def _bridge_one_preagg(pt): @staticmethod def _apply_calc_specs(result, plan, tbl): - """Apply calculated measure specs (ratios, percent-of-total, etc.).""" - from .compile_all import _collect_all_refs, _compile_formula - - needed_totals: set[str] = set() - for ast in plan.calc_specs.values(): - _collect_all_refs(ast, needed_totals) - - if needed_totals: - totals_aggs = {ref: result[ref].sum() for ref in needed_totals if ref in result.columns} - all_tbl = result.aggregate(**totals_aggs) if totals_aggs else None - else: - all_tbl = None - - out = result.cross_join(all_tbl) if all_tbl is not None else result - calc_cols = { - name: _compile_formula(ast, out, all_tbl, tbl if tbl is not None else out) - for name, ast in plan.calc_specs.items() + """Apply calculated measure specs to the pre-aggregated result. + + Each calc spec is a :class:`CalcMeasure` whose lambda is + re-evaluated against the post-aggregation result via the + ibis-native compiler. ``t.all(measure_ref)`` patterns trigger a + no-group-by totals aggregation that gets cross-joined into the + result so non-sum measures (mean/quantile/…) get correct overall + values; ``apply_calc_measures`` builds the totals lazily on + first use. 
+ """ + calc_lambdas = { + name: cm.expr if isinstance(cm, CalcMeasure) else cm + for name, cm in plan.calc_specs.items() } - return out.mutate(**calc_cols) + if not calc_lambdas: + return result + base_for_calc = tbl if tbl is not None else result + known = frozenset(plan.agg_specs.keys()) | frozenset(plan.calc_specs.keys()) + return apply_calc_measures( + result, + base_for_calc, + calc_lambdas, + known, + agg_specs=dict(plan.agg_specs), + ) class SemanticMutateOp(Relation): @@ -4707,43 +4912,6 @@ def _walk_join_spine(n): return depth_map -def _update_measure_refs_in_calc(expr, prefix_map: dict[str, str]): - """ - Recursively update MeasureRef names in a calculated measure expression. - - Args: - expr: A MeasureExpr (MeasureRef, AllOf, BinOp, MethodCall, or literal) - prefix_map: Mapping from old name to new prefixed name - - Returns: - Updated expression with prefixed MeasureRef names - """ - if isinstance(expr, MeasureRef): - # Update the measure reference name if it's in the map - new_name = prefix_map.get(expr.name, expr.name) - return MeasureRef(new_name) - elif isinstance(expr, AllOf): - # Update the inner MeasureRef - updated_ref = _update_measure_refs_in_calc(expr.ref, prefix_map) - return AllOf(updated_ref) - elif isinstance(expr, MethodCall): - updated_receiver = _update_measure_refs_in_calc(expr.receiver, prefix_map) - return MethodCall( - receiver=updated_receiver, - method=expr.method, - args=expr.args, - kwargs=expr.kwargs, - ) - elif isinstance(expr, BinOp): - # Recursively update left and right - updated_left = _update_measure_refs_in_calc(expr.left, prefix_map) - updated_right = _update_measure_refs_in_calc(expr.right, prefix_map) - return BinOp(op=expr.op, left=updated_left, right=updated_right) - else: - # Literal number or other - return as-is - return expr - - def _extract_join_key_column_names(source: Relation) -> set[str]: """ Extract column names that ibis will merge (coalesce) during joins. 
@@ -4956,21 +5124,13 @@ def _merge_fields_with_prefixing( merged_fields = {} - is_calc_measures = False is_dimensions = False if all_roots: sample_fields = field_accessor(all_roots[0]) if sample_fields: - from .measure_scope import AllOf, BinOp, MeasureRef, MethodCall - first_val = next(iter(sample_fields.values()), None) - is_calc_measures = isinstance( - first_val, - MeasureRef | AllOf | BinOp | MethodCall | int | float, - ) is_dimensions = isinstance(first_val, Dimension) - # For dimensions, build a column rename map to handle Ibis join conflicts column_rename_map = {} if is_dimensions: column_rename_map = _build_column_rename_map(all_roots, field_accessor, source) @@ -4979,36 +5139,17 @@ def _merge_fields_with_prefixing( root_name = root.name fields_dict = field_accessor(root) - if is_calc_measures and root_name: - base_map = ( - {k: f"{root_name}.{k}" for k in root.get_measures()} - if hasattr(root, "get_measures") - else {} - ) - calc_map = ( - {k: f"{root_name}.{k}" for k in root.get_calculated_measures()} - if hasattr(root, "get_calculated_measures") - else {} - ) - prefix_map = {**base_map, **calc_map} - for field_name, field_value in fields_dict.items(): if root_name: - # Always use prefixed name with . 
separator prefixed_name = f"{root_name}.{field_name}" - # If it's a calculated measure, update internal MeasureRefs - if is_calc_measures: - field_value = _update_measure_refs_in_calc(field_value, prefix_map) - # If it's a dimension that needs column renaming, wrap the callable - elif is_dimensions and prefixed_name in column_rename_map: + if is_dimensions and prefixed_name in column_rename_map: field_value = _wrap_dimension_for_renamed_column( field_value, column_rename_map[prefixed_name] ) merged_fields[prefixed_name] = field_value else: - # Fallback to original name if no root name merged_fields[field_name] = field_value return FrozenDict(merged_fields) diff --git a/src/boring_semantic_layer/serialization/extract.py b/src/boring_semantic_layer/serialization/extract.py index 52c9f49..7934e7c 100644 --- a/src/boring_semantic_layer/serialization/extract.py +++ b/src/boring_semantic_layer/serialization/extract.py @@ -304,76 +304,85 @@ def do_serialize(): def serialize_calc_measures(calc_measures: Mapping[str, Any]) -> Result[dict, Exception]: + """Serialize calc measures (``CalcMeasure`` objects) by resolver-tree. + + Each calc measure stores the original user lambda. We run it once + against a fresh ``Deferred`` variable to capture the structural shape + (calls to ``.all(...)``, attribute access, arithmetic ...) and + serialize the resulting resolver tree. 
+ """ + from ..utils import expr_to_structured + @safe def do_serialize(): - from ..measure_scope import AggregationExpr, AllOf, BinOp, MeasureRef, MethodCall - - def _serialize_calc_expr(expr): - if isinstance(expr, MeasureRef): - return ("measure_ref", expr.name) - if isinstance(expr, AggregationExpr): - return ("agg_expr", expr.column, expr.operation, expr.post_ops) - if isinstance(expr, AllOf): - return ("all_of", _serialize_calc_expr(expr.ref)) - if isinstance(expr, MethodCall): - return ( - "method_call", - _serialize_calc_expr(expr.receiver), - expr.method, - tuple(expr.args), - tuple(expr.kwargs), - ) - if isinstance(expr, BinOp): - return ( - "calc_binop", - expr.op, - _serialize_calc_expr(expr.left), - _serialize_calc_expr(expr.right), - ) - if isinstance(expr, int | float): - return ("num", expr) - return None - - result = {} - for name, expr in calc_measures.items(): - serialized = _serialize_calc_expr(expr) - if serialized is not None: - result[name] = serialized + result: dict[str, Any] = {} + for name, calc in calc_measures.items(): + fn = getattr(calc, "expr", calc) + struct_result = expr_to_structured(fn) + entry: dict[str, Any] = {} + match struct_result: + case Success(): + entry["expr_struct"] = struct_result.unwrap() + case _: + continue + description = getattr(calc, "description", None) + if description is not None: + entry["description"] = description + requires_unnest = getattr(calc, "requires_unnest", ()) + if requires_unnest: + entry["requires_unnest"] = list(requires_unnest) + depends_on = getattr(calc, "depends_on", None) + if depends_on: + entry["depends_on"] = sorted(depends_on) + result[name] = entry return result return do_serialize() def deserialize_calc_measures(calc_data: Mapping[str, Any]) -> dict[str, Any]: - from ..measure_scope import AggregationExpr, AllOf, BinOp, MeasureRef, MethodCall + """Reconstruct calc measures from their serialized resolver trees. 
- from .freeze import list_to_tuple + Returns a dict mapping ``name → CalcMeasure``. Each entry's expression + is a Deferred whose resolver mirrors the original lambda's structural + shape; at query time the planner runs it against an + ``IbisCalcScope`` exactly like a user-supplied lambda. + """ + from ..ops import CalcMeasure + from ..utils import structured_to_expr - def _deserialize_calc_expr(data): - if isinstance(data, int | float): - return data - tag = data[0] - if tag == "measure_ref": - return MeasureRef(data[1]) - if tag == "agg_expr": - return AggregationExpr( - column=data[1], - operation=data[2], - post_ops=list_to_tuple(data[3]) if data[3] else (), - ) - if tag == "all_of": - return AllOf(_deserialize_calc_expr(data[1])) - if tag == "method_call": - return MethodCall( - receiver=_deserialize_calc_expr(data[1]), - method=data[2], - args=tuple(data[3]) if data[3] else (), - kwargs=tuple(data[4]) if data[4] else (), - ) - if tag == "calc_binop": - return BinOp(data[1], _deserialize_calc_expr(data[2]), _deserialize_calc_expr(data[3])) - if tag == "num": - return data[1] - raise ValueError(f"Unknown calc measure tag: {tag}") + from .freeze import list_to_tuple - return {name: _deserialize_calc_expr(expr) for name, expr in calc_data.items()} + out: dict[str, Any] = {} + for name, data in calc_data.items(): + if isinstance(data, dict): + entry = data + struct = entry.get("expr_struct") + description = entry.get("description") + requires_unnest = tuple(entry.get("requires_unnest", ()) or ()) + depends_on = frozenset(entry.get("depends_on", ()) or ()) + else: + # Backwards-compat: old curated-AST tags arrive as bare tuples. + struct = data + description = None + requires_unnest = () + depends_on = frozenset() + + if struct is None: + continue + # ``thaw`` converts the resolver tuple into a list of lists; the + # resolver deserializer expects nested tuples, so convert back. 
+ struct = list_to_tuple(struct) + result = structured_to_expr(struct) + match result: + case Success(): + expr = result.unwrap() + case _: + continue + out[name] = CalcMeasure( + expr=expr, + description=description, + requires_unnest=requires_unnest, + depends_on=depends_on, + ) + return out diff --git a/src/boring_semantic_layer/server/api.py b/src/boring_semantic_layer/server/api.py index 820f623..a62b783 100644 --- a/src/boring_semantic_layer/server/api.py +++ b/src/boring_semantic_layer/server/api.py @@ -162,7 +162,7 @@ def _search_dimension_values_response( ), ) - from boring_semantic_layer.compile_all import _get_ibis_module + from boring_semantic_layer.nested_compile import get_ibis_module as _get_ibis_module dim = dims[dimension_name] tbl = model.table diff --git a/src/boring_semantic_layer/tests/test_calc_analyzer.py b/src/boring_semantic_layer/tests/test_calc_analyzer.py new file mode 100644 index 0000000..49ea936 --- /dev/null +++ b/src/boring_semantic_layer/tests/test_calc_analyzer.py @@ -0,0 +1,115 @@ +"""Tests for the ibis-tree analyzer that replaces the curated calc-measure +AST classifier. + +The analyzer walks an ibis expression tree and returns a +:class:`CalcExprAnalysis` record. These tests exercise each branch of +the classification — pushable base, post-agg measure refs, totals +pattern (``t.all``-style empty window over reduction), real windowed +expressions (moving avg, rank), and inline aggregations — to lock in +the structural shapes the planner reads downstream. 
+""" + +from __future__ import annotations + +import pytest + +xorq = pytest.importorskip("xorq", reason="xorq not installed") + +from boring_semantic_layer._xorq import ibis as xibis # noqa: E402 +from boring_semantic_layer.calc_analyzer import ( # noqa: E402 + CalcExprAnalysis, + analyze_calc_expr, + virtual_agg_table, +) + + +def _vt(): + return virtual_agg_table( + {"flight_count": "int64", "total_distance": "float64", "date": "date"} + ) + + +def _base(): + return xibis.table( + {"distance": "float64", "passengers": "int64", "carrier": "string"}, + "flights", + ) + + +def test_literal_is_pushable(): + r = analyze_calc_expr(42) + assert r.pushable is True + assert r.post_agg_only is False + assert r.depends_on == frozenset() + + +def test_plain_reduction_on_base_is_pushable(): + base = _base() + r = analyze_calc_expr(base.distance.sum(), base_table_op=base.op()) + assert r.pushable is True + assert r.post_agg_only is False + assert r.inline_aggs == frozenset({"distance"}) + + +def test_arith_of_aggs_on_same_base_is_pushable(): + base = _base() + r = analyze_calc_expr( + base.distance.sum() / base.passengers.sum(), base_table_op=base.op() + ) + assert r.pushable is True + assert r.inline_aggs == frozenset({"distance", "passengers"}) + + +def test_measure_ratio_is_post_agg_only(): + vt = _vt() + r = analyze_calc_expr( + vt.flight_count / vt.total_distance, + known_measures=frozenset({"flight_count", "total_distance"}), + ) + assert r.post_agg_only is True + assert r.pushable is False + assert r.depends_on == frozenset({"flight_count", "total_distance"}) + assert r.references_AllOf is False + assert r.has_window is False + + +def test_empty_window_over_reduction_is_totals_reference(): + """``t.all(x)`` shape: x.sum().over(empty window) reads as references_AllOf.""" + vt = _vt() + r = analyze_calc_expr( + vt.flight_count / vt.flight_count.sum().over(xibis.window()), + known_measures=frozenset({"flight_count"}), + ) + assert r.references_AllOf is True + assert 
r.has_window is True + assert r.post_agg_only is True + assert "flight_count" in r.depends_on + + +def test_ordered_window_is_window_not_totals(): + """Moving average / rank-style windows must not be classified as totals.""" + vt = _vt() + r = analyze_calc_expr( + vt.flight_count.mean().over(xibis.window(order_by="date", preceding=2)), + known_measures=frozenset({"flight_count", "date"}), + ) + assert r.has_window is True + assert r.references_AllOf is False + assert r.post_agg_only is True + + +def test_unknown_input_warns_and_falls_back(): + class Weird: + pass + + with pytest.warns(UserWarning, match="post-aggregation-only"): + r = analyze_calc_expr(Weird()) + assert r.post_agg_only is True + assert r.pushable is False + + +def test_returns_frozen_dataclass(): + r = analyze_calc_expr(1) + assert isinstance(r, CalcExprAnalysis) + with pytest.raises(Exception): + r.pushable = False # frozen diff --git a/src/boring_semantic_layer/tests/test_calc_compiler.py b/src/boring_semantic_layer/tests/test_calc_compiler.py new file mode 100644 index 0000000..e953d3b --- /dev/null +++ b/src/boring_semantic_layer/tests/test_calc_compiler.py @@ -0,0 +1,664 @@ +"""Tests for the ibis-native calc-measure compiler. + +Exercises :class:`IbisCalcScope` dispatch, lambda evaluation, structural +classification via the analyzer, and compile-time substitution of the +virtual aggregated table with the real one. 
+""" + +from __future__ import annotations + +import pandas as pd +import pytest + +xorq = pytest.importorskip("xorq", reason="xorq not installed") + +import xorq.api as xo # noqa: E402 + +from boring_semantic_layer._xorq import ibis as xibis # noqa: E402 +from boring_semantic_layer.calc_compiler import ( # noqa: E402 + IbisCalcScope, + classify_calc_lambda, + compile_calc_measure, + compile_calc_measures, + evaluate_calc_lambda, +) + + +@pytest.fixture(scope="module") +def base_tbl(): + con = xo.duckdb.connect() + df = pd.DataFrame( + { + "carrier": ["AA", "AA", "UA", "UA", "DL"], + "distance": [100, 200, 150, 250, 300], + "passengers": [10, 20, 30, 40, 50], + } + ) + return con.create_table("flights", df) + + +def test_scope_dispatches_measure_to_virtual_table(base_tbl): + vt = xibis.table({"flight_count": "int64"}, "__virt__") + scope = IbisCalcScope(base_tbl, vt, frozenset({"flight_count"})) + expr = scope.flight_count + # Field's relation should be the virtual table + assert expr.op().rel == vt.op() + + +def test_scope_dispatches_column_to_base_table(base_tbl): + vt = xibis.table({"flight_count": "int64"}, "__virt__") + scope = IbisCalcScope(base_tbl, vt, frozenset({"flight_count"})) + expr = scope.distance + assert expr.op().rel == base_tbl.op() + + +def test_all_with_string_measure_name(base_tbl): + vt = xibis.table({"flight_count": "int64"}, "__virt__") + scope = IbisCalcScope(base_tbl, vt, frozenset({"flight_count"})) + expr = scope.all("flight_count") + # ``t.all(measure_name)`` resolves to a Field on the parallel + # totals virtual table; the compiler later substitutes it with + # the real no-group-by aggregation. 
+ op = expr.op() + assert type(op).__name__ == "Field" + assert op.name == "flight_count" + assert id(op.rel) == id(scope._totals_virtual_agg_tbl.op()) + + +def test_classify_pct_calc_measure(base_tbl): + """``flight_count / t.all(flight_count)`` is a post-agg measure with + ``references_AllOf`` set.""" + fn = lambda t: t.flight_count / t.all(t.flight_count) + _, analysis = classify_calc_lambda(fn, base_tbl, frozenset({"flight_count"})) + assert analysis.post_agg_only is True + assert analysis.pushable is False + assert analysis.references_AllOf is True + assert "flight_count" in analysis.depends_on + + +def test_classify_inline_agg_pushable(base_tbl): + """``t.distance.sum() / t.passengers.sum()`` references only the base + table — the analyzer reports it as pushable (a base measure).""" + fn = lambda t: t.distance.sum() / t.passengers.sum() + _, analysis = classify_calc_lambda(fn, base_tbl, frozenset()) + assert analysis.pushable is True + assert analysis.post_agg_only is False + assert analysis.inline_aggs == frozenset({"distance", "passengers"}) + + +def test_compile_substitutes_virtual_for_real(base_tbl): + """``compile_calc_measure`` rewrites Field(vt) → Field(real_agg).""" + # A pure measure-ref calc: avg_dist references two known measures + # on the virtual table. ``compile_calc_measure`` should rebind both + # Fields to the actual aggregated table. + fn = lambda t: t.total_distance / t.flight_count + known = frozenset({"total_distance", "flight_count"}) + expr, vt, _totals_vt = evaluate_calc_lambda(fn, base_tbl, known) + + real_agg = base_tbl.group_by("carrier").aggregate( + total_distance=base_tbl.distance.sum(), + flight_count=base_tbl.count(), + ) + compiled = compile_calc_measure(expr, vt, real_agg) + + # Walk the compiled op tree; every Field reference should land on + # real_agg, none on the synthetic vt. 
+ real_op = real_agg.op() + vt_op = vt.op() + rels: list[int] = [] + seen: set[int] = set() + stack: list = [compiled.op()] + while stack: + cur = stack.pop() + if id(cur) in seen: + continue + seen.add(id(cur)) + rel = getattr(cur, "rel", None) + if rel is not None: + rels.append(id(rel)) + for child in getattr(cur, "__args__", ()) or (): + if hasattr(child, "__args__") or hasattr(child, "rel"): + stack.append(child) + assert id(real_op) in rels, "compiled expression should reference real_agg" + assert id(vt_op) not in rels, "compiled expression must not reference virtual vt" + + # End-to-end: the table should execute and produce the expected ratio. + final = real_agg.mutate(avg_dist=compiled).execute().sort_values("carrier") + # AA: 300/2 = 150; UA: 400/2 = 200; DL: 300/1 = 300 + by_carrier = dict(zip(final["carrier"], final["avg_dist"], strict=True)) + assert pytest.approx(by_carrier["AA"]) == 150.0 + assert pytest.approx(by_carrier["UA"]) == 200.0 + assert pytest.approx(by_carrier["DL"]) == 300.0 + + +def test_apply_calc_measures_raises_when_totals_unavailable(base_tbl): + """Clear error when t.all(...) is referenced but no totals can be built.""" + from boring_semantic_layer.calc_compiler import ( + TotalsNotAvailableError, + apply_calc_measures, + ) + + fn = lambda t: t.flight_count / t.all(t.flight_count) + real_agg = base_tbl.group_by("carrier").aggregate( + flight_count=base_tbl.count(), + ) + # No agg_specs and no real_totals_tbl — the totals build can't run. 
+ with pytest.raises(TotalsNotAvailableError, match="t.all"): + apply_calc_measures( + real_agg, + base_tbl, + {"pct": fn}, + frozenset({"flight_count"}), + ) + + +def test_compile_pct_calc_measure_end_to_end(base_tbl): + """``apply_calc_measures`` builds totals on demand for ``t.all``.""" + from boring_semantic_layer.calc_compiler import apply_calc_measures + + fn = lambda t: t.flight_count / t.all(t.flight_count) + real_agg = base_tbl.group_by("carrier").aggregate( + flight_count=base_tbl.count(), + ) + final = apply_calc_measures( + real_agg, + base_tbl, + {"pct": fn}, + frozenset({"flight_count"}), + agg_specs={"flight_count": lambda t: t.count()}, + ) + df = final.execute().sort_values("carrier").reset_index(drop=True) + assert "pct" in df.columns + # Sum of per-group counts ÷ total count = 1.0 for sum-style measures. + assert pytest.approx(df["pct"].sum()) == 1.0 + + +def test_compile_no_calcs_passes_through(base_tbl): + real_agg = base_tbl.group_by("carrier").aggregate( + flight_count=base_tbl.count(), + ) + out = compile_calc_measures(real_agg, {}) + assert out is real_agg + + +def test_compile_multiple_calc_measures(base_tbl): + """Two independent calcs apply together: one references totals, one doesn't.""" + from boring_semantic_layer.calc_compiler import apply_calc_measures + + pct = lambda t: t.flight_count / t.all(t.flight_count) + avg_dist = lambda t: t.total_distance / t.flight_count + + real_agg = base_tbl.group_by("carrier").aggregate( + flight_count=base_tbl.count(), + total_distance=base_tbl.distance.sum(), + ) + final = apply_calc_measures( + real_agg, + base_tbl, + {"pct": pct, "avg_dist": avg_dist}, + frozenset({"flight_count", "total_distance"}), + agg_specs={ + "flight_count": lambda t: t.count(), + "total_distance": lambda t: t.distance.sum(), + }, + ) + df = final.execute().sort_values("carrier").reset_index(drop=True) + assert "pct" in df.columns + assert "avg_dist" in df.columns + assert pytest.approx(df["pct"].sum()) == 1.0 + + +def 
test_multiple_allof_calcs_share_one_totals_per_measure():
+    """Two t.all-referencing calcs share their totals computations.
+
+    Each base measure that's referenced by ``t.all(...)`` gets exactly
+    one windowed-totals column added to the base via
+    ``measure.over(window())``. Multiple calcs referencing the same
+    measure share the same totals column — no duplicate window
+    computation, no quadratic growth.
+
+    The new compilation strategy uses windowed totals carried through
+    the per-group aggregation rather than cross-joined totals tables,
+    so the rendered SQL has *zero* ``CROSS JOIN`` operations and one
+    ``OVER (...)`` window per AllOf-referenced measure. Locking the
+    "one totals per measure" property guards against an O(n²) regression.
+    """
+    from boring_semantic_layer import to_semantic_table
+
+    con = xo.duckdb.connect()
+    df = pd.DataFrame(
+        {
+            "carrier": ["AA", "AA", "UA", "UA"],
+            "distance": [100, 200, 300, 400],
+            "passengers": [10, 20, 30, 40],
+        }
+    )
+    tbl = con.create_table("flights_share_totals", df)
+    st = (
+        to_semantic_table(tbl, "flights_share_totals")
+        .with_measures(
+            total_distance=lambda t: t.distance.sum(),
+            total_passengers=lambda t: t.passengers.sum(),
+        )
+        .with_measures(
+            pct_distance=lambda t: t.total_distance / t.all(t.total_distance),
+            pct_passengers=lambda t: t.total_passengers / t.all(t.total_passengers),
+        )
+    )
+    sql = st.group_by("carrier").aggregate("pct_distance", "pct_passengers").compile()
+    sql_upper = sql.upper()
+    # No cross joins under the new strategy.
+    assert sql_upper.count("CROSS JOIN") == 0
+    # One windowed-totals column per AllOf-referenced base measure (two
+    # here); its alias may appear in the SQL both where it is defined and
+    # where it is read, so assert presence rather than an exact count.
+    assert sql_upper.count("__BSL_TOTALS__TOTAL_DISTANCE") >= 1
+    assert sql_upper.count("__BSL_TOTALS__TOTAL_PASSENGERS") >= 1
+    # Output schema only has the user-requested columns.
+ result = st.group_by("carrier").aggregate("pct_distance", "pct_passengers") + assert set(result.columns) == {"carrier", "pct_distance", "pct_passengers"} + + +def test_apply_calc_measures_join_with_mean_totals(): + """Joined model: t.all over a non-sum measure recomputes totals from base.""" + from boring_semantic_layer import to_semantic_table + + con = xo.duckdb.connect() + flights = pd.DataFrame( + { + "carrier_code": ["AA", "AA", "UA", "UA"], + "distance": [100, 200, 300, 400], + } + ) + carriers = pd.DataFrame( + {"code": ["AA", "UA"], "carrier_name": ["American", "United"]} + ) + f_tbl = con.create_table("join_flights", flights) + c_tbl = con.create_table("join_carriers", carriers) + + flights_st = to_semantic_table(f_tbl, "flights").with_measures( + avg_distance=lambda t: t.distance.mean(), + ) + carriers_st = to_semantic_table(c_tbl, "carriers").with_dimensions( + carrier_name=lambda t: t.carrier_name, + ) + joined = flights_st.join_one( + carriers_st, + on=lambda left, right: left.carrier_code == right.code, + ).with_measures( + ratio=lambda t: t.avg_distance / t.all(t.avg_distance), + ) + + df = ( + joined.group_by("carrier_name") + .aggregate("avg_distance", "ratio") + .execute() + .sort_values("carrier_name") + .reset_index(drop=True) + ) + + # AA mean=150, UA mean=350; overall mean=250 (NOT 150+350=500). + by_name = dict(zip(df["carrier_name"], df["ratio"], strict=True)) + assert pytest.approx(by_name["American"]) == 150 / 250 + assert pytest.approx(by_name["United"]) == 350 / 250 + + +@pytest.mark.parametrize( + "reducer,expected_total,per_group", + [ + # Median: pooled rows [100, 200, 300, 400] → 250. + ("median", 250.0, {"AA": 150.0, "UA": 350.0}), + # Min: AA=100, UA=300, overall=100. + ("min", 100.0, {"AA": 100.0, "UA": 300.0}), + # Max: AA=200, UA=400, overall=400. 
+ ("max", 400.0, {"AA": 200.0, "UA": 400.0}), + ], +) +def test_apply_calc_measures_non_sum_totals(reducer, expected_total, per_group): + """``t.all`` over min/max/median recomputes totals via the formula, not a windowed sum. + + Locks the same v1-bug fix as the mean case for the rest of the + common non-sum reductions: per-group sums-of-medians or sum-of-mins + would be obviously wrong. + """ + from boring_semantic_layer import to_semantic_table + + con = xo.duckdb.connect() + df = pd.DataFrame( + { + "carrier": ["AA", "AA", "UA", "UA"], + "distance": [100, 200, 300, 400], + } + ) + tbl = con.create_table(f"flights_nonsum_{reducer}", df) + + st = ( + to_semantic_table(tbl, f"flights_nonsum_{reducer}") + .with_measures(**{f"d_{reducer}": lambda t, op=reducer: getattr(t.distance, op)()}) + .with_measures( + ratio=lambda t, op=reducer: getattr(t, f"d_{op}") + / t.all(getattr(t, f"d_{op}")), + ) + ) + df_out = ( + st.group_by("carrier") + .aggregate(f"d_{reducer}", "ratio") + .execute() + .sort_values("carrier") + .reset_index(drop=True) + ) + by_carrier = dict(zip(df_out["carrier"], df_out[f"d_{reducer}"], strict=True)) + assert pytest.approx(by_carrier["AA"]) == per_group["AA"] + assert pytest.approx(by_carrier["UA"]) == per_group["UA"] + by_ratio = dict(zip(df_out["carrier"], df_out["ratio"], strict=True)) + assert pytest.approx(by_ratio["AA"]) == per_group["AA"] / expected_total + assert pytest.approx(by_ratio["UA"]) == per_group["UA"] / expected_total + + +def test_cast_to_float_survives_int_measure_substitution(): + """``int_measure.cast('float64') / int_measure_total * 100`` returns nonzero. + + Regression test for the bug where the preprocess step in + ``_compile_aggregation`` populated the virtual aggregated table's + schema with placeholder ``float64`` dtypes for every measure. + User casts like ``t.flight_count.cast('float64')`` were elided as + no-ops by ibis (the column was already float64 in the synthetic + schema). 
After substitution to the real aggregated table — where + ``flight_count`` is int64 (from ``CountStar``) — the Cast was gone, + so ``int / int * 100`` returned 0 for ratios less than 1. + + The fix uses the *real* dtype derived from ``agg_specs[name](base_tbl).type()`` + so the cast is preserved when substituted. This test pins the + behavior end-to-end with a count-style integer measure and a + ``cast('float64')``-using calc. + """ + from boring_semantic_layer import to_semantic_table + + con = xo.duckdb.connect() + df = pd.DataFrame( + { + "carrier": ["AA"] * 30 + ["UA"] * 70, + "value": list(range(100)), + } + ) + tbl = con.create_table("flights_cast_regression", df) + + st = ( + to_semantic_table(tbl, "flights_cast_regression") + .with_measures( + flight_count=lambda t: t.count(), # int64 + ) + .with_measures( + share_pct=( + lambda t: t.flight_count.cast("float64") / t.all(t.flight_count) * 100 + ), + ) + ) + result = ( + st.group_by("carrier") + .aggregate("flight_count", "share_pct") + .execute() + .sort_values("carrier") + .reset_index(drop=True) + ) + by_carrier = dict(zip(result["carrier"], result["share_pct"], strict=True)) + # AA = 30/100 = 30%, UA = 70/100 = 70%; sum = 100% (sanity) + assert pytest.approx(by_carrier["AA"]) == 30.0 + assert pytest.approx(by_carrier["UA"]) == 70.0 + assert pytest.approx(result["share_pct"].sum()) == 100.0 + + +def test_joined_model_totals_via_windowed_aggregation(): + """``t.all(measure)`` on a joined model returns correct per-group shares. + + Regression: previously, ``t.all(measure)`` on a joined model + compiled to two aggregations of a shared parent relation + (``Aggregate(JoinChain)`` for per-group + ``Aggregate(JoinChain)`` + without group_by, both referencing the same JoinChain) and + cross-joined them. SQL backends that fold shared ancestors + collapsed the cross-join to zero rows, so the user got an empty + result instead of per-group totals. 
+ + The fix attaches a windowed aggregation to the base before + group_by — ``measure.over(window()) AS __bsl_totals__`` — + and carries it through the per-group aggregation via + ``arbitrary()``. There is no cross-join, so the shared-ancestor + optimization can't kick in. + """ + from boring_semantic_layer import to_semantic_table + + con = xo.duckdb.connect() + flights = pd.DataFrame( + { + "carrier_code": ["AA", "AA", "AA", "UA", "UA", "DL"], + "distance": [100, 200, 300, 400, 500, 600], + } + ) + carriers = pd.DataFrame( + {"code": ["AA", "UA", "DL"], "carrier_name": ["American", "United", "Delta"]} + ) + f_tbl = con.create_table("joined_totals_flights", flights) + c_tbl = con.create_table("joined_totals_carriers", carriers) + + flights_st = ( + to_semantic_table(f_tbl, "flights") + .with_measures( + flight_count=lambda t: t.count(), + ) + .with_measures( + share=lambda t: t.flight_count.cast("float64") / t.all(t.flight_count) * 100, + ) + ) + carriers_st = to_semantic_table(c_tbl, "carriers").with_dimensions( + carrier_name=lambda t: t.carrier_name, + ) + joined = flights_st.join_one( + carriers_st, + on=lambda left, right: left.carrier_code == right.code, + ) + + df = ( + joined.group_by("carriers.carrier_name") + .aggregate("flights.flight_count", "flights.share") + .execute() + .sort_values("carriers.carrier_name") + .reset_index(drop=True) + ) + # Three carriers — non-empty result is the regression-blocking property. + assert len(df) == 3 + # Per-carrier counts: AA=3, UA=2, DL=1; total=6. + by_name = dict( + zip(df["carriers.carrier_name"], df["flights.flight_count"], strict=True) + ) + assert by_name == {"American": 3, "United": 2, "Delta": 1} + # Shares sum to exactly 100% (windowed totals are constant across rows + # — no per-group / totals snapshot drift). 
+    assert pytest.approx(df["flights.share"].sum()) == 100.0
+    by_share = dict(zip(df["carriers.carrier_name"], df["flights.share"], strict=True))
+    assert pytest.approx(by_share["American"]) == 50.0
+    # Keep ``approx`` on one side only: comparing two ``pytest.approx``
+    # objects is fragile (a mismatch raises TypeError instead of a clean
+    # assertion failure).
+    assert by_share["United"] == pytest.approx(100.0 * 2 / 6)
+    assert by_share["Delta"] == pytest.approx(100.0 * 1 / 6)
+
+
+def test_joined_model_totals_does_not_emit_cross_join():
+    """The joined-model totals path compiles to zero ``CROSS JOIN`` operations.
+
+    The new strategy carries totals through the per-group aggregation
+    via window functions, not via a cross-joined totals table. Locking
+    this in SQL prevents a regression to the shared-ancestor cross-join
+    pattern that would silently return zero rows on some backends.
+    """
+    from boring_semantic_layer import to_semantic_table
+
+    con = xo.duckdb.connect()
+    flights = pd.DataFrame(
+        {"carrier_code": ["AA", "AA", "UA"], "distance": [100, 200, 300]}
+    )
+    carriers = pd.DataFrame({"code": ["AA", "UA"], "name": ["American", "United"]})
+    f_tbl = con.create_table("nocrossjoin_flights", flights)
+    c_tbl = con.create_table("nocrossjoin_carriers", carriers)
+
+    flights_st = (
+        to_semantic_table(f_tbl, "flights")
+        .with_measures(flight_count=lambda t: t.count())
+        .with_measures(share=lambda t: t.flight_count / t.all(t.flight_count))
+    )
+    carriers_st = to_semantic_table(c_tbl, "carriers").with_dimensions(
+        name=lambda t: t.name,
+    )
+    joined = flights_st.join_one(
+        carriers_st, on=lambda left, right: left.carrier_code == right.code
+    )
+
+    sql = joined.group_by("carriers.name").aggregate("flights.share").compile()
+    assert sql.upper().count("CROSS JOIN") == 0
+    # And there's at least one OVER() window (the totals computation).
+    assert "OVER" in sql.upper()
+
+
+def test_attach_windowed_totals_helper():
+    """``attach_windowed_totals`` adds ``__bsl_totals__`` columns.
+
+    Direct unit test: each base measure gets a window-aggregated
+    column on the base table, plus an arbitrary() spec for the
+    per-group aggregation.
+    """
+    from boring_semantic_layer.calc_compiler import (
+        TOTALS_PREFIX,
+        attach_windowed_totals,
+    )
+
+    con = xo.duckdb.connect()
+    df = pd.DataFrame(
+        {"g": ["a", "a", "b", "b"], "v": [10, 20, 30, 40]}
+    )
+    base = con.create_table("attach_helper", df)
+
+    agg_specs = {
+        "v_sum": lambda t: t.v.sum(),
+        "v_mean": lambda t: t.v.mean(),
+    }
+    new_base, totals_arbitrary_specs = attach_windowed_totals(
+        base, agg_specs, ["v_sum", "v_mean"]
+    )
+    # Two new columns added to the base.
+    assert f"{TOTALS_PREFIX}v_sum" in new_base.columns
+    assert f"{TOTALS_PREFIX}v_mean" in new_base.columns
+    # Each row has the same totals value — verify by executing.
+    materialized = new_base.execute()
+    assert (materialized[f"{TOTALS_PREFIX}v_sum"] == 100).all()
+    assert (materialized[f"{TOTALS_PREFIX}v_mean"] == 25.0).all()
+    # The arbitrary specs are keyed by exactly the totals columns.
+    assert set(totals_arbitrary_specs) == {
+        f"{TOTALS_PREFIX}v_sum",
+        f"{TOTALS_PREFIX}v_mean",
+    }
+
+
+def test_attach_calc_totals_handles_calc_of_calc():
+    """Calc-of-calc-AllOf chains derive their totals via the totals scope.
+
+    End-to-end check of the ``attach_calc_totals`` path: with
+    ``__bsl_totals__`` columns attached for the base measures and a calc
+    ``avg_distance = total_distance / total_flights`` whose totals are
+    needed (because ``ratio`` references ``t.all(t.avg_distance)``), the
+    compiler adds ``__bsl_totals__avg_distance`` computed from
+    ``__bsl_totals__total_distance / __bsl_totals__total_flights``.
+ """ + from boring_semantic_layer import to_semantic_table + + con = xo.duckdb.connect() + df = pd.DataFrame( + { + "carrier": ["AA", "AA", "UA", "UA"], + "distance": [100, 200, 300, 400], + } + ) + tbl = con.create_table("calc_of_calc_totals", df) + + st = ( + to_semantic_table(tbl, "calc_of_calc_totals") + .with_measures( + total_distance=lambda t: t.distance.sum(), + total_flights=lambda t: t.count(), + ) + .with_measures( + avg_distance=lambda t: t.total_distance / t.total_flights, + ) + .with_measures( + ratio=lambda t: t.avg_distance / t.all(t.avg_distance), + ) + ) + + df_out = ( + st.group_by("carrier") + .aggregate("avg_distance", "ratio") + .execute() + .sort_values("carrier") + .reset_index(drop=True) + ) + by_carrier = dict( + zip(df_out["carrier"], df_out["avg_distance"], strict=True) + ) + # AA mean=150, UA mean=350; overall mean=250 (NOT sum-of-means=500). + assert pytest.approx(by_carrier["AA"]) == 150.0 + assert pytest.approx(by_carrier["UA"]) == 350.0 + by_ratio = dict(zip(df_out["carrier"], df_out["ratio"], strict=True)) + assert pytest.approx(by_ratio["AA"]) == 150 / 250 + assert pytest.approx(by_ratio["UA"]) == 350 / 250 + + +def test_lift_inline_reductions_routes_window_to_totals(): + """The two-pass substitution gives top-level reductions vt refs and + ``t.all(...)``-style windowed reductions totals_vt refs. + + Locks the contract documented in :func:`lift_inline_reductions`: + the same ``Reduction`` node may appear both at top level (per-group + value, want ``Field(vt, anon)``) and as a ``WindowFunction.func`` + (totals value, want ``Field(totals_vt, anon)``). Bind the reduction + to a single Python object so the duplicate-id case (which + ``op.replace`` would dedupe by equality) is exercised end-to-end. 
+ """ + from boring_semantic_layer.calc_compiler import lift_inline_reductions + + base = xibis.table( + {"distance": "float64", "passengers": "int64"}, + "flights_lift", + ) + vt = xibis.table({"__bsl_unused__": "int64"}, "__vt__") + + shared = base.distance.sum() + expr = shared / shared.over(xibis.window()) + + rewritten, new_vt, new_totals_vt, lifted = lift_inline_reductions(expr, vt, base) + + # A single shared reduction should produce exactly one anonymous lift — + # locking the dedup-by-id behavior at the top of the function. + assert len(lifted) == 1 + anon_name = next(iter(lifted)) + + assert anon_name in dict(new_vt.op().schema.items()) + assert anon_name in dict(new_totals_vt.op().schema.items()) + + new_vt_id = id(new_vt.op()) + new_totals_id = id(new_totals_vt.op()) + rewritten_op = rewritten.op() if hasattr(rewritten, "op") else rewritten + + fields_seen: list[tuple[str, int]] = [] + seen: set[int] = set() + stack: list = [rewritten_op] + while stack: + cur = stack.pop() + if id(cur) in seen: + continue + seen.add(id(cur)) + if hasattr(cur, "name") and hasattr(cur, "rel"): + fields_seen.append((cur.name, id(cur.rel))) + for child in getattr(cur, "__args__", ()) or (): + if hasattr(child, "__args__") or hasattr(child, "rel"): + stack.append(child) + + rels_for_anon = {r for n, r in fields_seen if n == anon_name} + assert new_vt_id in rels_for_anon, "expected Field(new_vt, anon) for the bare reduction" + assert new_totals_id in rels_for_anon, ( + "expected Field(new_totals_vt, anon) for the windowed reduction" + ) diff --git a/src/boring_semantic_layer/tests/test_deferred_api.py b/src/boring_semantic_layer/tests/test_deferred_api.py index 1a828e6..61cf816 100644 --- a/src/boring_semantic_layer/tests/test_deferred_api.py +++ b/src/boring_semantic_layer/tests/test_deferred_api.py @@ -1530,7 +1530,14 @@ def test_group_by_level1_aggregate_level0(self, deep_model): assert df["shops.total_revenue"].iloc[0] == 2000 def 
test_all_with_aggregation_expr_post_ops(): - """Test t.all() with inline AggregationExpr that includes post-ops.""" + """``t.all()`` over an inline aggregation with post-ops works end-to-end. + + The analyzer-based compiler now lifts inline reductions into anonymous + base measures so patterns like ``t.value.sum().coalesce(0) / + t.all(t.value.sum().coalesce(0))`` compile correctly: each unique + inline reduction becomes a column on the post-aggregation table, and + the ``t.all(...)`` call wraps that column in a windowed sum. + """ con = ibis.duckdb.connect(":memory:") events = pd.DataFrame( { diff --git a/src/boring_semantic_layer/tests/test_measure_reference_styles.py b/src/boring_semantic_layer/tests/test_measure_reference_styles.py index 93c0c6b..fd80207 100644 --- a/src/boring_semantic_layer/tests/test_measure_reference_styles.py +++ b/src/boring_semantic_layer/tests/test_measure_reference_styles.py @@ -233,7 +233,16 @@ def test_inline_measure_with_different_reference_styles(): def test_all_of_multilayer_calc_measure(): - """t.all() should work when pointing at a calculated measure chain.""" + """``t.all()`` over a calc-of-calc chain re-aggregates from the base. + + The compiler builds a totals table by re-running the same aggregation + without group_by, applies non-AllOf calc measures to it, and + cross-joins it into the per-group result so ``t.all(measure_ref)`` + references the totals column directly. For non-sum-style chains + (e.g. ``avg_distance + 1``) this matches the curated-AST behavior: + the ``+ 1`` participates in the totals computation exactly once, + not once per group. 
+ """ con = ibis.duckdb.connect(":memory:") flights = pd.DataFrame( { @@ -258,11 +267,55 @@ def test_all_of_multilayer_calc_measure(): df = flights_st.group_by("carrier").aggregate("pct_of_total").execute() + # AA avg_distance = 300/2 = 150, +1 = 151 + # UA avg_distance = 700/2 = 350, +1 = 351 + # Totals (re-aggregated from base): avg_distance = 1000/4 = 250, +1 = 251 assert len(df) == 2 assert "pct_of_total" in df.columns assert pytest.approx(sorted(df.pct_of_total.tolist())) == sorted([151 / 251, 351 / 251]) +def test_all_of_non_sum_measure_uses_totals_table(): + """``t.all(mean_measure)`` re-aggregates from base, not sum-of-means. + + The bug: emitting ``column.sum().over(window())`` for ``t.all(...)`` + is correct for sum-style measures (sum of per-group sums = overall + sum) but wrong for ``mean`` (sum of per-group means != overall mean). + The fix builds a totals table by re-aggregating from base without + group_by, then references the totals column. + """ + con = ibis.duckdb.connect(":memory:") + flights = pd.DataFrame( + { + "carrier": ["AA", "AA", "UA", "UA"], + "distance": [100, 200, 300, 400], + } + ) + f_tbl = con.create_table("flights_avg", flights) + + flights_st = ( + to_semantic_table(f_tbl, "flights_avg") + .with_measures(avg_distance=lambda t: t.distance.mean()) + .with_measures( + ratio=lambda t: t.avg_distance / t.all(t.avg_distance), + ) + ) + + df = ( + flights_st.group_by("carrier") + .aggregate("avg_distance", "ratio") + .execute() + .sort_values("carrier") + .reset_index(drop=True) + ) + + # AA avg = 150, UA avg = 350, overall avg = 250 (NOT 150+350=500). 
+ aa = df[df.carrier == "AA"].iloc[0] + ua = df[df.carrier == "UA"].iloc[0] + assert pytest.approx(aa.ratio) == 150 / 250 + assert pytest.approx(ua.ratio) == 350 / 250 + + # --- Tests for .values / .schema / .columns with calc measures --- @@ -475,57 +528,35 @@ def test_method_call_fillna_on_calc_measure(): def test_method_call_serialization_roundtrip(): - """MethodCall should survive serialize/deserialize roundtrip.""" - from boring_semantic_layer.measure_scope import BinOp, MeasureRef, MethodCall - from boring_semantic_layer.serialization import ( - deserialize_calc_measures, - serialize_calc_measures, - ) - - # Build: (total_distance / flight_count).round(2) - expr = MethodCall( - receiver=BinOp("div", MeasureRef("total_distance"), MeasureRef("flight_count")), - method="round", - args=(2,), - kwargs=(), - ) - - calc_measures = {"avg_distance": expr} - serialized = serialize_calc_measures(calc_measures).unwrap() - deserialized = deserialize_calc_measures(serialized) + """A method-call calc measure survives serialize/deserialize roundtrip. - result = deserialized["avg_distance"] - assert isinstance(result, MethodCall) - assert result.method == "round" - assert result.args == (2,) - assert isinstance(result.receiver, BinOp) - assert result.receiver.op == "div" + Replaces the legacy curated-AST direct construction with the + behavioral round-trip through ``to_tagged`` / ``from_tagged``. 
+ """ + from boring_semantic_layer import to_semantic_table + from boring_semantic_layer.serialization import from_tagged, to_tagged + con = ibis.duckdb.connect(":memory:") + df = pd.DataFrame({"carrier": ["AA", "AA", "UA"], "distance": [100.0, 200.0, 300.0]}) + tbl = con.create_table("flights_ms", df) -def test_validate_calc_ast_rejects_allof_binop(): - """AllOf wrapping a BinOp should fail at construction with a clear error.""" - from boring_semantic_layer.measure_scope import ( - AllOf, - BinOp, - MeasureRef, - validate_calc_ast, + st = to_semantic_table(tbl, "flights_ms").with_measures( + total_distance=lambda t: t.distance.sum(), + flight_count=lambda t: t.count(), + avg_distance=lambda t: (t.total_distance / t.flight_count).round(2), ) - - bad = AllOf(ref=BinOp("add", MeasureRef("a"), MeasureRef("b"))) - with pytest.raises(ValueError, match="Invalid AllOf.*BinOp"): - validate_calc_ast(bad, measure_name="ratio") - - -def test_validate_calc_ast_accepts_allof_aggregation_expr(): - """AllOf(AggregationExpr) is valid — handled by the rewrite pipeline.""" - from boring_semantic_layer.measure_scope import ( - AggregationExpr, - AllOf, - validate_calc_ast, + reconstructed = from_tagged(to_tagged(st)) + df_orig = st.group_by("carrier").aggregate("avg_distance").execute().sort_values("carrier") + df_round = ( + reconstructed.group_by("carrier") + .aggregate("avg_distance") + .execute() + .sort_values("carrier") + ) + pd.testing.assert_frame_equal( + df_orig.reset_index(drop=True), + df_round.reset_index(drop=True), ) - - ok = AllOf(ref=AggregationExpr(column="value", operation="sum")) - validate_calc_ast(ok, measure_name="pct") # no raise def test_calc_dtype_inference_with_inline_aggregation(): @@ -633,19 +664,20 @@ def test_typo_in_t_all_raises(): def test_substring_measure_name_does_not_trigger_typo(): """Names that are substrings of other measures should not trip the typo - detector. 
``net_revenue`` referenced from a measure that also defines - ``total_net_revenue`` is legitimate, not a typo (similarity ≈ 0.79). + detector. Asking for a known measure name returns its column on the + virtual aggregated table without firing the typo path. """ - from boring_semantic_layer.measure_scope import MeasureScope - - # Probe MeasureScope directly: with measures named [net_revenue, - # total_net_revenue], looking up an unrelated name 'net_revenue' - # via getattr should NOT fire the typo path. + from boring_semantic_layer.calc_compiler import IbisCalcScope + from boring_semantic_layer.calc_analyzer import virtual_agg_table import ibis as i tbl = i.table({"col": "int64"}, name="t") - scope = MeasureScope(_tbl=tbl, _known=("net_revenue", "total_net_revenue")) - # Asking for 'net_revenue' is fine — it's a known measure. + vt = virtual_agg_table({"net_revenue": "float64", "total_net_revenue": "float64"}) + scope = IbisCalcScope( + base_tbl=tbl, + virtual_agg_tbl=vt, + known_measures=frozenset({"net_revenue", "total_net_revenue"}), + ) assert scope.net_revenue is not None # Asking for 'col' is fine — it's a column. assert scope.col is not None diff --git a/src/boring_semantic_layer/tests/test_nested_access.py b/src/boring_semantic_layer/tests/test_nested_access.py index fab56aa..f6af896 100644 --- a/src/boring_semantic_layer/tests/test_nested_access.py +++ b/src/boring_semantic_layer/tests/test_nested_access.py @@ -344,5 +344,55 @@ def test_nested_nunique(): assert result[result.session_id == 2]["unique_tags"].iloc[0] == 2 +def test_nest_then_regroup_unnests_struct_field_access(): + """Re-grouping an aggregate that contains an array-of-struct column + auto-unnests on struct field access in a follow-up aggregate. 
+ + Regression for the docs ``query_nest_step2`` pattern: a measure + lambda like ``lambda t: t.flights.carrier.nunique()`` against a + previously-aggregated table used to fail with + ``'ArrayColumn' object has no attribute 'carrier'`` because the + post-aggregation lambda ran against the raw ibis table without the + ``ColumnScope`` wrapper that produces ``NestedAccessMarker`` values. + """ + con = ibis.duckdb.connect(":memory:") + data = pd.DataFrame( + [ + {"origin": "NYC", "carrier": "AA", "distance": 100}, + {"origin": "NYC", "carrier": "DL", "distance": 200}, + {"origin": "NYC", "carrier": "AA", "distance": 100}, + {"origin": "LAX", "carrier": "UA", "distance": 300}, + ] + ) + tbl = con.create_table("flights", data) + flights_st = to_semantic_table(tbl) + + nested = ( + flights_st.group_by("origin") + .aggregate( + flight_count=lambda t: t.count(), + nest={"flights": lambda t: t.group_by(["carrier", "distance"])}, + ) + ) + + result = ( + nested.group_by("origin") + .aggregate( + total_flights=lambda t: t.flight_count.sum(), + unique_carriers=lambda t: t.flights.carrier.nunique(), + avg_distance=lambda t: t.flights.distance.mean(), + ) + .execute() + .set_index("origin") + ) + + assert result.loc["NYC", "total_flights"] == 3 + assert result.loc["NYC", "unique_carriers"] == 2 + assert result.loc["NYC", "avg_distance"] == pytest.approx((100 + 200 + 100) / 3) + assert result.loc["LAX", "total_flights"] == 1 + assert result.loc["LAX", "unique_carriers"] == 1 + assert result.loc["LAX", "avg_distance"] == pytest.approx(300.0) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/src/boring_semantic_layer/tests/test_xorq_string_serialization.py b/src/boring_semantic_layer/tests/test_xorq_string_serialization.py index 4542721..337a121 100644 --- a/src/boring_semantic_layer/tests/test_xorq_string_serialization.py +++ b/src/boring_semantic_layer/tests/test_xorq_string_serialization.py @@ -487,6 +487,38 @@ def test_serialize_resolver_case_expr(): 
assert resolved.execute() == 1 +def test_serialize_resolver_item_subscript_roundtrips_and_hashes(): + """``_["flights.flight_count"]`` round-trips and the rebuilt resolver is hashable. + + The deserializer rebuilds resolvers via ``object.__new__`` + + ``__setattr__`` to bypass FrozenSlotted validation; missing + ``__precomputed_hash__`` would surface as ``AttributeError`` the + first time a deserialized resolver is hashed (e.g. when used as a + dict key inside ibis op replacement). Calc measures that look up + prefixed names via subscript depend on this round-trip. + """ + from boring_semantic_layer.utils import deserialize_resolver, serialize_resolver + + from xorq.vendor.ibis import _ + from xorq.vendor.ibis.common.deferred import Deferred + + d = _["flights.flight_count"] + data = serialize_resolver(d._resolver) + assert data[0] == "item" + + r = deserialize_resolver(data) + d2 = Deferred(r) + assert repr(d2) == repr(d) + + # Hashing the rebuilt resolver must not raise. ``hash(r)`` exercises + # the path that surfaces missing ``__precomputed_hash__``. + hash(r) + # An equal resolver built fresh (via ``__init__``) should hash to + # the same value — proving the precomputed hash matches normal + # construction, not just any arbitrary value. 
+ assert hash(r) == hash(d._resolver) + + def test_serialize_resolver_ifelse(): """xo.ifelse(_.distance < 200, 1, 0).sum() round-trips.""" import xorq.api as xo diff --git a/src/boring_semantic_layer/utils.py b/src/boring_semantic_layer/utils.py index b83bba5..1c67f66 100644 --- a/src/boring_semantic_layer/utils.py +++ b/src/boring_semantic_layer/utils.py @@ -294,6 +294,7 @@ def serialize_resolver(resolver) -> tuple: Attr, BinaryOperator, Call, + Item, Just, JustUnhashable, Mapping as MappingResolver, @@ -333,6 +334,9 @@ def serialize_resolver(resolver) -> tuple: if isinstance(resolver, Attr): return ("attr", serialize_resolver(resolver.obj), serialize_resolver(resolver.name)) + if isinstance(resolver, Item): + return ("item", serialize_resolver(resolver.obj), serialize_resolver(resolver.name)) + if isinstance(resolver, Call): func_tuple = serialize_resolver(resolver.func) args_tuple = tuple(serialize_resolver(a) for a in resolver.args) @@ -398,12 +402,29 @@ def _resolve_qualname(module_obj, qualname: str): return obj +def _finalize_frozen_slotted(obj, *fields) -> None: + """Set ``__precomputed_hash__`` on a FrozenSlotted built via ``object.__new__``. + + xorq's vendored ibis FrozenSlotted base implements ``__hash__`` by + returning a precomputed value that ``__init__`` would normally set + via ``hash((cls, tuple(field_values)))``. When we bypass + ``__init__`` to skip validation during deserialization we must + mirror that exactly — note the inner ``tuple(...)`` wrap, which is + significant: ``hash((cls, *fields))`` produces a different value. + Without this the rebuilt resolver raises ``AttributeError`` the + first time it is hashed (e.g. as a key in ``op.replace`` + substitutions). 
+ """ + object.__setattr__(obj, "__precomputed_hash__", hash((type(obj), tuple(fields)))) + + def deserialize_resolver(data: tuple): """Reconstruct a Resolver tree from a nested-tuple representation.""" from ._xorq import ( Attr, BinaryOperator, Call, + Item, Just, Mapping as MappingResolver, Sequence, @@ -434,8 +455,18 @@ def deserialize_resolver(data: tuple): attr = object.__new__(Attr) object.__setattr__(attr, "obj", obj_resolver) object.__setattr__(attr, "name", name_resolver) + _finalize_frozen_slotted(attr, obj_resolver, name_resolver) return attr + case ("item", obj_data, name_data): + obj_resolver = deserialize_resolver(obj_data) + name_resolver = deserialize_resolver(name_data) + item = object.__new__(Item) + object.__setattr__(item, "obj", obj_resolver) + object.__setattr__(item, "name", name_resolver) + _finalize_frozen_slotted(item, obj_resolver, name_resolver) + return item + case ("call", func_data, args_data, kwargs_data): func_resolver = deserialize_resolver(func_data) args_resolvers = tuple(deserialize_resolver(a) for a in args_data) @@ -447,6 +478,7 @@ def deserialize_resolver(data: tuple): object.__setattr__(call, "func", func_resolver) object.__setattr__(call, "args", args_resolvers) object.__setattr__(call, "kwargs", kwargs_resolvers) + _finalize_frozen_slotted(call, func_resolver, args_resolvers, kwargs_resolvers) return call case ("binop", op_name, left_data, right_data): @@ -459,6 +491,7 @@ def deserialize_resolver(data: tuple): object.__setattr__(binop, "func", func) object.__setattr__(binop, "left", left) object.__setattr__(binop, "right", right) + _finalize_frozen_slotted(binop, func, left, right) return binop case ("unop", op_name, arg_data): @@ -469,6 +502,7 @@ def deserialize_resolver(data: tuple): unop = object.__new__(UnaryOperator) object.__setattr__(unop, "func", func) object.__setattr__(unop, "arg", arg) + _finalize_frozen_slotted(unop, func, arg) return unop case ("seq", type_name, items_data): @@ -477,6 +511,7 @@ def 
deserialize_resolver(data: tuple): seq = object.__new__(Sequence) object.__setattr__(seq, "typ", typ) object.__setattr__(seq, "values", values) + _finalize_frozen_slotted(seq, typ, values) return seq case ("map", type_name, items_data): @@ -488,6 +523,7 @@ def deserialize_resolver(data: tuple): mapping = object.__new__(MappingResolver) object.__setattr__(mapping, "typ", typ) object.__setattr__(mapping, "values", values) + _finalize_frozen_slotted(mapping, typ, values) return mapping case _:
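
The hash-layout point in the `_finalize_frozen_slotted` docstring above — that the inner `tuple(...)` wrap is significant — can be sanity-checked with plain Python tuples. This is a standalone sketch, not BSL or xorq code; `Cls` and the field strings are illustrative stand-ins for a resolver class and its field values:

```python
# (cls, tuple(fields)) and (cls, *fields) are structurally different tuples
# (a 2-tuple whose second element is a tuple vs. a flat 3-tuple), so they
# hash differently. Rebuilding a FrozenSlotted with the splatted layout
# would make the deserialized resolver hash-unequal to one built normally.

class Cls:  # stand-in for a resolver class such as Attr or Call
    pass

fields = ("obj_resolver", "name_resolver")  # illustrative field values

wrapped = hash((Cls, tuple(fields)))   # the layout __init__ computes
splatted = hash((Cls, *fields))        # the tempting-but-wrong layout

assert wrapped != splatted
# The wrapped layout reproduces the hash of an equal, freshly built tuple —
# i.e. it matches normal construction, which is what the tests above pin.
assert wrapped == hash((Cls, ("obj_resolver", "name_resolver")))
```

This is why `_finalize_frozen_slotted` wraps its `*fields` back into a tuple before hashing rather than splicing them into the outer tuple.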