[Feat] Complete BasisAttr support in IntTupleBuilder (#574)#605
[Feat] Complete BasisAttr support in IntTupleBuilder (#574)#605jhinpan wants to merge 6 commits into
Conversation
Extend IntTupleBuilder<IntTupleAttr> to accept scaled-basis (BasisAttr / CuTe E<I>) stride leaves where the algebra is well-defined, and add a Python construction surface for them. - div(Basis, Int): divide the coefficient, keep modes (reuse intSafeDiv). - eq / ne: compare basis leaves by (coefficient, modes); a basis monomial never equals a plain integer leaf. - Remaining ops (mod, lt/le/gt/ge, min/max, shapeDiv, logical*, swizzle) stay integer-only -- basis is either ill-posed there or structurally unreachable -- but now carry precise assert messages instead of bare leaf-int asserts. - Python: fx.E(mode, *, value=1) and fx.make_basis_stride(value, modes), wired through the int-tuple builder via a __fly_basis__ marker. div(Basis, Basis) is reachable for rank>=3 identity layouts via complement; it has no quotient mode, so it is rejected with a named assert rather than miscomputing a stride. Located op-layer diagnostics are out of scope here (IntTupleBuilder carries no Location) and belong to the sibling issue #583. Tests: tests/mlir/LayoutAlgebra/basis.mlir (div + identity logical_divide), equal folding in tests/mlir/Transforms/layout_lowering.mlir, and Python surface + pipeline tests in tests/unit/test_layout_algebra.py. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds first-class support for scaled-basis (CuTe E<I>) stride leaves end-to-end (Python authoring → IR printing → layout algebra/lowering), with tests and docs demonstrating identity layouts and equality folding.
Changes:
- Added Python APIs to construct scaled-basis stride leaves (
fx.E) and flat basis stride tuples (fx.make_basis_stride). - Extended Python→C++ IntTuple attribute building to accept basis leaves via duck-typing.
- Updated IntTuple utilities + MLIR/Python tests/docs to cover basis stride algebra and lowering behavior.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/unit/test_layout_algebra.py | Adds Python-side IR and lowering tests for fx.E/basis identity size. |
| tests/mlir/Transforms/layout_lowering.mlir | Adds FileCheck tests for fly.equal folding with basis leaves. |
| tests/mlir/LayoutAlgebra/basis.mlir | New MLIR tests for basis leaves through int_tuple_div and logical_divide. |
| python/flydsl/expr/primitive.py | Exposes E + make_basis_stride to construct basis strides from Python. |
| lib/Dialect/Fly/Utils/IntTupleUtils.cpp | Extends div/eq/ne semantics to handle basis leaves; improves asserts. |
| lib/Bindings/Python/FlyExtension.cpp | Adds basis-leaf handling to IntTupleAttrBuilder using __fly_basis__ marker. |
| docs/layout_system_guide.md | Documents constructing basis-strided layouts with fx.E. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # fx.E(0), fx.E(1) produce a (1E0, 1E1) stride, identical to make_identity_layout. | ||
| assert "!fly.layout<(4,8):(1E0,1E1)>" in ir | ||
| # make_basis_stride(1, (0, 1)) yields the same flat basis stride. | ||
| assert "!fly.int_tuple<(1E0,1E1)>" in ir | ||
| # value and multi-mode forms: fx.E(0, 1, value=2) -> 2E0E1. | ||
| assert "!fly.int_tuple<(2E0E1)>" in ir |
There was a problem hiding this comment.
Done in 7eab7d4 — check() now matches against "".join(ir.split()), so the assertions no longer depend on the printer's spacing around commas/colons.
| } else if (nb::hasattr(args, "__fly_basis__")) { | ||
| // Scaled-basis stride leaf (e.g. fx.E(0) -> 1E0), duck-typed on __fly_basis__. | ||
| int32_t value = nb::cast<int32_t>(args.attr("value")); | ||
| SmallVector<int32_t> modes; | ||
| for (auto mode : nb::cast<nb::list>(args.attr("modes"))) { | ||
| modes.push_back(nb::cast<int32_t>(mode)); | ||
| } | ||
| return IntTupleAttr::get(BasisAttr::get(IntAttr::getStatic(ctx, value), modes)); |
There was a problem hiding this comment.
Both fixed in 7eab7d4. The branch now gates on a truthy __fly_basis__ via PyObject_IsTrue, so an object that sets it to False is no longer mistaken for a basis leaf — and a truthy non-bool marker still works (a plain nb::cast<bool> throws on anything that isn't the Py_True/Py_False singleton). Modes are read by iterating the object directly, so any sequence/iterable (tuple, list, …) is accepted rather than only nb::list.
| def E(*modes, value=1): | ||
| """Build a scaled-basis stride leaf. | ||
|
|
||
| Examples: | ||
| E(0) -> 1E0 | ||
| E(1) -> 1E1 | ||
| E(0, value=2) -> 2E0 | ||
| E(0, 1) -> 1E0E1 (a single leaf along modes 0 and 1) | ||
| """ | ||
| if not modes: | ||
| raise ValueError("E requires at least one mode") | ||
| return _BasisElem(value, list(modes)) |
There was a problem hiding this comment.
Done in 7eab7d4. E now validates in Python via _as_int32: value and every mode must be integer-like (Python int or a NumPy integer scalar, via operator.index) and within int32 range, raising a clear TypeError/ValueError instead of deferring to the nanobind cast. Covered by the new test_basis_E_rejects_invalid_modes.
| assert(lhs.isLeaf() && rhs.isLeaf()); | ||
| if (lhs.isLeafInt() && rhs.isLeafInt()) { | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() == rhs.getLeafAsInt()); | ||
| } | ||
| if (lhs.isLeafBasis() && rhs.isLeafBasis()) { | ||
| return IntTupleAttr::get(lhs.getLeafAsBasis() == rhs.getLeafAsBasis()); | ||
| } | ||
| // A scalar leaf and a basis monomial never coincide. | ||
| return materializeConstantLeaf(0); |
There was a problem hiding this comment.
Switched both the eq and ne mismatch branches to IntTupleAttr::getLeafStatic(ctx, 0/1) in 7eab7d4 — an explicit static leaf, consistent with the sibling branches and free of the builder method. It's the same attribute materializeConstantLeaf(0/1) produced, just constructed directly.
| """ | ||
| if not modes: | ||
| raise ValueError("E requires at least one mode") | ||
| return _BasisElem(value, list(modes)) |
There was a problem hiding this comment.
Confirmed bug: E accepts negative modes, but the BasisAttr assembly parser cannot round-trip them. fx.make_stride(fx.E(-1)) prints !fly.int_tuple<(1E-1)>, and feeding that IR back to fly-opt fails to parse at E-1 (expected ')'). Please validate modes here as non-negative int32 values before constructing _BasisElem (and mirror that in the binding if the duck-typed __fly_basis__ hook is meant to be public).
There was a problem hiding this comment.
Fixed in 7eab7d4. E now rejects negative modes (_as_int32(..., nonneg=True)), and the public __fly_basis__ binding hook mirrors it — a negative mode there throws a clear error before constructing the BasisAttr. value stays sign-unrestricted since -2E0/0E0 round-trip fine; only the mode index breaks the assembly format. Regression added in test_basis_E_rejects_invalid_modes.
| assert(lhs.isLeafInt() && rhs.isLeafInt()); | ||
| return IntTupleAttr::get(lhs.getLeafAsInt() / rhs.getLeafAsInt()); | ||
| // A basis divisor has no quotient mode (reachable for rank>=3 identity layouts | ||
| // via complement), so reject it rather than miscompute a stride. |
There was a problem hiding this comment.
The rank>=3 path described here does not currently reach this named assert. A direct user-level repro is fx.logical_divide(fx.make_identity_layout((4, 8, 2)), fx.make_layout((2, 4, 2), fx.make_stride(1, 2, 4))); it SIGFPEs during LogicalDivideOp::inferReturnTypes inside compositionImpl before producing the promised div(Basis, Basis) rejection. If rank-3 identity divide is intentionally unsupported, please add a regression for this case and guard it before the algebra can hit the integer %/division-by-zero path.
There was a problem hiding this comment.
Good catch — dug in, and it's a bit different from what my comment implied, so I split it (7eab7d4):
-
The
div()comment was imprecise.div(Basis, Basis)is reachable, but viacomplement()of a basis-strided identity layout of rank ≥ 2 —complement(make_identity_layout((4,8)))and(4,8,2)both hit the named assert. It is not reached throughlogical_divide. Reworded the comment to say so. -
The SIGFPE is a separate, pre-existing bug and isn't basis-specific.
(2,4,2):(1,2,4)is a non-tiling (overlapping) divisor: its complement has a 0-extent mode, andcompositionImplthen does0 % 0. The same divisor SIGFPEs on a non-coalescible integer layout(4,8,2):(1,100,7)too — the identity case only reaches it because basis strides never coalesce to rank-1, whereas a compact integer layout like(1,4,32)coalesces away and dodges the loop. A rank-3 identity divide with a valid tiler, e.g.(2,4,2):(1,2,8), works fine. -
Guarded it.
compositionImplnow assertsnewShapeVal != 0with a named message before the%, so a non-tiling divisor aborts with "divisor is not a tiling layout" instead of a SIGFPE. Added the rank-3 valid-divisor case tobasis.mliras a positive regression.
The recoverable op-layer diagnostic (so a bad tiler surfaces as an expected-error rather than an assert) needs a Location, which IntTupleBuilder doesn't carry — that's the #583 op-gate work, consistent with this PR's deferral of located diagnostics.
) Resolve the inline review comments: - E()/make_basis_stride: validate value and modes as int32 in Python (operator.index, so NumPy integer scalars are accepted) and reject negative modes -- the IntTuple assembly format cannot round-trip a negative E<mode> (1E-1 fails to re-parse). Mirror the non-negative mode check in the public __fly_basis__ binding hook. - FlyExtension binding: gate the basis branch on a *truthy* __fly_basis__ (PyObject_IsTrue, not a strict Py_True/Py_False cast that throws on non-bool) so a falsy marker is not mistaken for a basis; read modes from any iterable, not only a list. - IntTupleUtils eq/ne: return an explicit getLeafStatic(0/1) leaf instead of materializeConstantLeaf, consistent with the sibling branches. - IntTupleUtils div(): correct the comment -- a basis divisor is reached via complement() of a rank>=2 identity layout, not logical_divide. - compositionImpl: a non-tiling divisor yields a 0-extent complement mode; assert before the % so it fails with a named message instead of a SIGFPE. Pre-existing and not basis-specific (a non-coalescible integer layout with the same divisor crashes identically). Add a rank-3 identity logical_divide regression with a valid tiler. - test_layout_algebra: normalize IR whitespace before matching; add an E() input-validation regression. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
| } else if (args.is_none()) { | ||
| return IntTupleAttr::getLeafNone(ctx); | ||
| } else if (nb::hasattr(args, "__fly_basis__") && | ||
| PyObject_IsTrue(nb::object(args.attr("__fly_basis__")).ptr()) == 1) { |
There was a problem hiding this comment.
This still leaves a crash path for the falsy-marker case. With the current branch, a custom object that exposes __fly_basis__ = False is no longer parsed as a basis leaf, but passing it through fx.make_stride() segfaults the Python process instead of raising the normal invalid-argument error.
Minimal repro against this PR head:
from flydsl import ir
from flydsl import language as fx
class FalsyMarker:
__fly_basis__ = False
value = 1
modes = [0]
with ir.Context() as ctx, ir.Location.unknown(ctx):
fx.make_stride(FalsyMarker())This exits with SIGSEGV/139 in my local build. Since this code explicitly handles truthy-vs-falsy markers, the falsy marker should be covered by a regression and should fail safely, for example by explicitly rejecting it before the generic _CAPIPtr fallback or otherwise making that fallback safe for arbitrary Python objects.
There was a problem hiding this comment.
Fixed in 7255fff. Root cause was in the generic fallback's error message: it called nb::type_name(args), and nb_type_name expects a type object — it casts to PyTypeObject* and calls PyType_GetName/PyType_HasFeature. Passing an instance (your FalsyMarker(), or any non-stride object) reinterprets it as a type and segfaults (on Python < 3.11 the __name__ lookup on the instance also yields a NULL that then feeds PyUnicode_FromFormat). Switched to Py_TYPE(args.ptr())->tp_name, which is always valid, so the fallback now fails safely for any object:
>>> fx.make_stride(FalsyMarker()) # __fly_basis__ = False
ValueError: Expected I32, got: FalsyMarker
>>> fx.make_stride(object())
ValueError: Expected I32, got: object
It was pre-existing (any invalid stride element hit the same path — a bare object() crashed identically), but the falsy-marker handling is what made it reachable, so good catch. Regression added in test_make_stride_rejects_non_stride_object, covering both the falsy marker and a bare object().
…605 review) The generic fallback in IntTupleAttrBuilder reported the rejected value's type via nb::type_name(args), but nb_type_name expects a *type* object -- passing an instance reinterprets it as a PyTypeObject and segfaults (and on Python <3.11 the __name__ lookup on the instance yields a NULL that feeds PyUnicode_FromFormat). Use Py_TYPE(args.ptr())->tp_name, which is always valid, so any non-stride object (e.g. an object exposing a falsy __fly_basis__ marker) raises a clean ValueError. Add a regression. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Closes #574 (the implementable core — see the scope note in the issue thread).
What
Completes
BasisAttr(scaled-basis / CuTeE<I>) leaf support inIntTupleBuilder<IntTupleAttr>, finishing what #195 started (it extended onlymul/safeDiv/ceilDiv), plus a Python construction surface.Implemented — basis is well-defined and reachable:
div(Basis, Int)→(value/k)·E: divide the coefficient, keep modes. Reuses the existingintSafeDiv(BasisAttr, IntAttr); no new scalar overload.eq/ne: compare basis leaves by(coefficient, modes); a basis monomial never equals a plain integer leaf.Kept integer-only, now with precise assert messages (
mod,lt/le/gt/ge,min/max,shapeDiv,logicalAnd/Or/Not,applySwizzle,applyCoordSwizzle): for each, basis is either ill-posed (no total order on free-module generators; bit-XOR on a symbol; symmetric integer divisibility) or structurally unreachable. Implementing formulas there would be untested dead code — per-op reasoning is in the #574 thread.Python:
fx.E(mode, *, value=1)andfx.make_basis_stride(value, modes), wired through the int-tuple builder via a__fly_basis__duck-type marker.fx.make_layout(fx.make_shape(4, 8), fx.make_stride(fx.E(0), fx.E(1)))is byte-identical tomake_identity_layout((4, 8)).Two notes
div(Basis, Basis)is reachable for rank≥3 identity layouts viacomplement(LayoutUtils.h:lastStridebecomes basis aftermul(minStride, shape), so the nextdiv(minStride, lastStride)is basis-by-basis). It has no quotient mode, so it's rejected with a named assert rather than miscomputing a stride.emitOpErroris out of scope here.IntTupleBuilder<IntTupleAttr>carries only anMLIRContext*, noLocation, so located diagnostics aren't reachable at this layer — asserts are the deliberate invariant (theIntTupleValueAdaptortwin asserts too). The user-facing op-layer gate belongs to the sibling issue [Compiler]: Adding verification to the layout-algebra ops — what's the right approach? #583. This PR and [Compiler]: Adding verification to the layout-algebra ops — what's the right approach? #583 touch disjoint files (IntTupleUtils.cppasserts vsFlyOps.cppinference).Verification
bash scripts/build.shwith assertions live — clean.tests/mlirFileCheck tests pass, incl. the newLayoutAlgebra/basis.mlir(div + identitylogical_divideinference) andequalfolding appended toTransforms/layout_lowering.mlir.tests/unit/test_layout_algebra.py: 24 passed / 1 pre-existing skip, incl. the newfx.Esurface and a fullconvert-fly-to-rocdlpipeline test; broader unit sweep 398 passed, 0 failed.🤖 Generated with Claude Code