From d8277d138723c600d17b6ec90131ae89295e75c9 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 21 Apr 2026 13:35:08 +0200 Subject: [PATCH 1/3] adding api doc and cleanup of Shape and related concepts - Add API doc to all classes - Reorder concepts in Axis - Add Label instance to axis - Make Axis a concrete class --- core/src/main/scala/dimwit/tensor/Axis.scala | 55 ++++++++----- core/src/main/scala/dimwit/tensor/Shape.scala | 77 ++++++++++++++++++- .../src/main/scala/dimwit/tensor/Tensor.scala | 2 +- 3 files changed, 108 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/dimwit/tensor/Axis.scala b/core/src/main/scala/dimwit/tensor/Axis.scala index 0972cee..be9d493 100644 --- a/core/src/main/scala/dimwit/tensor/Axis.scala +++ b/core/src/main/scala/dimwit/tensor/Axis.scala @@ -4,31 +4,44 @@ import dimwit.|*| import scala.compiletime.{constValue, erasedValue, summonInline} -case class AxisExtent[T](axis: Axis[T], size: Int): - def *[T2](other: AxisExtent[T2]): AxisExtent[T |*| T2] = - AxisExtent(Axis[T |*| T2], size * other.size) - -// Axis selectors for indexing operations +/** Instances of this class represent an axis in a tensor with a specific label `L`. + * Axis objects are used whenever an axis needs to be selected at the value level, + * such as when indexing into a tensor or defining the shape of a tensor. + */ +final class Axis[L: Label]: + def extent(size: Int): AxisExtent[L] = AxisExtent(this, size) + def ->(size: Int): AxisExtent[L] = this.extent(size) + def at(index: Int): AxisAtIndex[L] = AxisAtIndex(this, index) + def at(range: Range): AxisAtRange[L] = AxisAtRange(this, range) + def at(indices: Seq[Int]): AxisAtIndices[L] = AxisAtIndices(this, indices) + def at(index: Tensor0[Int]): AxisAtTensorIndex[L] = AxisAtTensorIndex(this, index) + def as[U](newAxis: Axis[U]): (Axis[L], Axis[U]) = (this, newAxis) + +/** Represents the extent of an axis, which is a combination of an Axis and its size. */ +case class AxisExtent[L: Label](axis: Axis[L], size: Int): + + /** Combines this AxisExtent with another AxisExtent to create a new AxisExtent that represents the combined axes and their sizes. + * + * @param other The other AxisExtent to combine with this one. + * @return A new AxisExtent representing the combined axes and their sizes. + */ + def *[L2: Label](other: AxisExtent[L2]): AxisExtent[L |*| L2] = + AxisExtent(Axis[L |*| L2], size * other.size) + +/** Trait hierarchy to represent different ways to select an axis in a tensor, such as by index, range, or specific indices. + */ sealed trait AxisSelector[L]: def axis: Axis[L] +/** Represent an axis selection by a single index. + */ case class AxisAtIndex[L](axis: Axis[L], index: Int) extends AxisSelector[L] -case class AxisAtRange[L](axis: Axis[L], range: Range) extends AxisSelector[L] -case class AxisAtIndices[L](axis: Axis[L], indices: Seq[Int]) extends AxisSelector[L] -case class AxisAtTensorIndex[L](axis: Axis[L], index: Tensor0[Int]) extends AxisSelector[L] -object Axis: - - def apply[A]: Axis[A] = new AxisImpl[A]() +/** Represent an axis selection by a range of indices. */ +case class AxisAtRange[L](axis: Axis[L], range: Range) extends AxisSelector[L] -/** Represents an axis with A. This maps the type-level to a runtime representation. */ -sealed trait Axis[A]: - def extent(size: Int): AxisExtent[A] = AxisExtent(this, size) - def ->(size: Int): AxisExtent[A] = this.extent(size) - def at(index: Int): AxisAtIndex[A] = AxisAtIndex(this, index) - def at(range: Range): AxisAtRange[A] = AxisAtRange(this, range) - def at(indices: Seq[Int]): AxisAtIndices[A] = AxisAtIndices(this, indices) - def at(index: Tensor0[Int]): AxisAtTensorIndex[A] = AxisAtTensorIndex(this, index) - def as[U](newAxis: Axis[U]): (Axis[A], Axis[U]) = (this, newAxis) +/** Represent an axis selection by a sequence of specific indices. */ +case class AxisAtIndices[L](axis: Axis[L], indices: Seq[Int]) extends AxisSelector[L] -class AxisImpl[A] extends Axis[A] +/* Represent an axis selection by a tensor containing indices. This allows for dynamic indexing based on the contents of the tensor. */ +case class AxisAtTensorIndex[L](axis: Axis[L], index: Tensor0[Int]) extends AxisSelector[L] diff --git a/core/src/main/scala/dimwit/tensor/Shape.scala b/core/src/main/scala/dimwit/tensor/Shape.scala index f806efe..ac1227d 100644 --- a/core/src/main/scala/dimwit/tensor/Shape.scala +++ b/core/src/main/scala/dimwit/tensor/Shape.scala @@ -4,18 +4,39 @@ import scala.collection.View.Empty import scala.annotation.publicInBinary import ShapeTypeHelpers.AxisIndex import dimwit.tensor.{Labels, Label} +import scala.Tuple.Union -/** Represents the (typed) Shape of a tensor with runtime labels +/** Represents the Shape of a tensor. Conceptually, a shape is an order list of AxisExtents, + * where each AxisExtent is a label associated with a size. */ final case class Shape[T <: Tuple: Labels] @publicInBinary private ( val dimensions: List[Int] ): + /** Returns the labels of the shape + */ lazy val labels: List[String] = summon[Labels[T]].names + /** Returns the rank of the shape, which is the number of dimensions. + */ def rank: Int = dimensions.size + + /** Returns the total number of elements in the tensor represented by this shape. + */ def size: Int = dimensions.foldLeft(1)((acc, d) => acc * d.asInstanceOf[Int]) - def extent[L](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = AxisExtent(axis, this(axis)) + + /** Returns the extent of the specified axis. + * + * @param axis The axis for which to retrieve the extent. + * @return The extent of the specified axis. + */ + def extent[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = AxisExtent(axis, this(axis)) + + /** Returns the size of the specified axis. + * + * @param axis The axis for which to retrieve the size. + * @return The size of the specified axis. + */ def apply[L](axis: Axis[L])(using ev: AxisIndex[T, L]): Int = this.dimensions(ev.index) override def toString: String = @@ -36,38 +57,86 @@ object Shape: case EmptyTuple => EmptyTuple case AxisExtent[l] *: tail => l *: ExtractLabels[tail] + /** Creates an empty shape with no dimensions. + * @return An empty shape with no dimensions. + */ def empty: Shape[EmptyTuple] = new Shape(Nil) + /** Creates a shape with a single dimension. + * + * @param dim The extent of the single dimension. + * @return A shape with one dimension. + */ def apply[L: Label](dim: AxisExtent[L]): Shape[L *: EmptyTuple] = Shape.fromTuple(Tuple1(dim)) - def apply[A <: Tuple](args: A)(using n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] = + /** Create a shape from a tuple of AxisExtents. The labels of the shape are extracted from the AxisExtents in the tuple. + * + * @param args A tuple of AxisExtents from which to create the shape. + * @return A shape with dimensions corresponding to the sizes of the AxisExtents in the + */ + def apply[A <: Tuple](args: A)(using ev: Union[A] <:< AxisExtent[?], n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] = Shape.fromTuple(args) - def fromTuple[A <: Tuple](args: A)(using n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] = + /** Create a shape from a tuple of AxisExtents. The labels of the shape are extracted from the AxisExtents in the tuple. + * @param args A tuple of AxisExtents from which to create the shape. + * @return A shape with dimensions corresponding to the sizes of the AxisExtents in the tuple. + */ + def fromTuple[A <: Tuple](args: A)(using ev: Union[A] <:< AxisExtent[?], n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] = val sizes = args.toList.collect: case ae: AxisExtent[?] => ae.size new Shape(sizes) private[tensor] def fromSeq[T <: Tuple: Labels](dims: Seq[Int]) = new Shape[T](dims.toList) +/** Type alias for an empty shape (rank 0). */ type Shape0 = Shape[EmptyTuple] + +/** Type alias for a shape with one dimension (rank 1). */ type Shape1[L] = Shape[L *: EmptyTuple] + +/** Type alias for a shape with two dimensions (rank 2). */ type Shape2[L1, L2] = Shape[L1 *: L2 *: EmptyTuple] + +/** Type alias for a shape with three dimensions (rank 3). */ type Shape3[L1, L2, L3] = Shape[L1 *: L2 *: L3 *: EmptyTuple] +/** Companion object for Shape0, providing a convenient way to create an empty shape. */ val Shape0 = Shape.empty +/** Companion object for Shape1, providing a convenient way to create a shape with one dimension. */ object Shape1: + + /** Creates a shape with a single dimension. + * + * @param dim The extent of the single dimension. + * @return A shape with one dimension. + */ def apply[L: Label](dim: AxisExtent[L]): Shape[Tuple1[L]] = Shape(dim) +/** Companion object for Shape2, providing a convenient way to create a shape with two dimensions. */ object Shape2: + + /** Creates a shape with two dimensions. + * + * @param dim1 The extent of the first dimension. + * @param dim2 The extent of the second dimension. + * @return A shape with two dimensions. + */ def apply[L1: Label, L2: Label]( dim1: AxisExtent[L1], dim2: AxisExtent[L2] ): Shape[(L1, L2)] = Shape.fromTuple(dim1, dim2) object Shape3: + + /** Creates a shape with three dimensions. + * + * @param dim1 The extent of the first dimension. + * @param dim2 The extent of the second dimension. + * @param dim3 The extent of the third dimension. + * @return A shape with three dimensions. + */ def apply[L1: Label, L2: Label, L3: Label]( dim1: AxisExtent[L1], dim2: AxisExtent[L2], diff --git a/core/src/main/scala/dimwit/tensor/Tensor.scala b/core/src/main/scala/dimwit/tensor/Tensor.scala index bc5dc86..e81504e 100644 --- a/core/src/main/scala/dimwit/tensor/Tensor.scala +++ b/core/src/main/scala/dimwit/tensor/Tensor.scala @@ -48,7 +48,7 @@ class Tensor[T <: Tuple: Labels, V] private[dimwit] ( s"TracerTensor(${shape.toString})" case _ => jaxValue.toString() - def extent[L](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = + def extent[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = shape.extent(axis) private val jaxTypeName: String = py.Dynamic.global.`type`(jaxValue).`__name__`.as[String] From 57375a062483096cff3c33643fef39a2ae228b0b Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 21 Apr 2026 15:06:45 +0200 Subject: [PATCH 2/3] regenerate AGENTS.md --- AGENTS.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index e2d98c7..412e969 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -89,7 +89,7 @@ val badShape = Shape(Axis[UndefinedLabel] -> 10) // Ensure that all axis types repl.MdocSession.MdocApp.UndefinedLabel are defined with 'derives Label' (e.g. 'trait T derives Label') // // val badShape = Shape(Axis[UndefinedLabel] -> 10) -// ^ +// ^ ``` ```scala @@ -196,7 +196,7 @@ val noLabel = Tensor(Shape1(Axis[NoLabel] -> 2)).fill(1.0f) // Ensure that all axis types repl.MdocSession.MdocApp.NoLabel are defined with 'derives Label' (e.g. 'trait T derives Label') // // val noLabel = Tensor(Shape1(Axis[NoLabel] -> 2)).fill(1.0f) -// ^ +// ^ ``` ```scala @@ -563,11 +563,16 @@ val wrong = t.vmap(Axis[A])(wrongFunc) // error: // Not found: type A // error: +// +// An axis label was given or inferred, which does not have a Label instance. +// Ensure that all axis types are defined with 'derives Label' (e.g. 'trait T derives Label') +// +// error: // Not found: type B // error: // -// An axis label Any was given or inferred, which does not have a Label instance. -// Ensure that all axis types Any are defined with 'derives Label' (e.g. 'trait T derives Label') +// An axis label was given or inferred, which does not have a Label instance. +// Ensure that all axis types are defined with 'derives Label' (e.g. 'trait T derives Label') // // error: // Not found: type A From 6fb182f2a7293265fbd3cec694d470125fec9084 Mon Sep 17 00:00:00 2001 From: Marcel Luethi Date: Tue, 28 Apr 2026 14:43:35 +0200 Subject: [PATCH 3/3] cleanup Co-authored-by: Copilot --- core/src/main/scala/dimwit/tensor/Axis.scala | 1 + core/src/main/scala/dimwit/tensor/Shape.scala | 25 +++++++++---------- .../autodiff/FloatTensorTreeSuite.scala | 1 + .../src/main/scala/basic/Autoencoder.scala | 1 + .../complex/VariationalAutoencoder.scala | 1 + nn/src/main/scala/nn/GradientOptimizer.scala | 1 + 6 files changed, 17 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/dimwit/tensor/Axis.scala b/core/src/main/scala/dimwit/tensor/Axis.scala index 281f9cf..be80572 100644 --- a/core/src/main/scala/dimwit/tensor/Axis.scala +++ b/core/src/main/scala/dimwit/tensor/Axis.scala @@ -15,6 +15,7 @@ final class Axis[L: Label]: def at(range: Range): AxisAtRange[L] = AxisAtRange(this, range) def at(indices: Seq[Int]): AxisAtIndices[L] = AxisAtIndices(this, indices) def at(index: Tensor0[Int]): AxisAtTensorIndex[L] = AxisAtTensorIndex(this, index) + def at[I <: NonEmptyTuple](indices: I): AxisAtTupleIndices[L, I] = AxisAtTupleIndices(this, indices) def as[U](newAxis: Axis[U]): (Axis[L], Axis[U]) = (this, newAxis) /** Represents the extent of an axis, which is a combination of an Axis and its size. */ diff --git a/core/src/main/scala/dimwit/tensor/Shape.scala b/core/src/main/scala/dimwit/tensor/Shape.scala index ac1227d..dec02a9 100644 --- a/core/src/main/scala/dimwit/tensor/Shape.scala +++ b/core/src/main/scala/dimwit/tensor/Shape.scala @@ -4,7 +4,6 @@ import scala.collection.View.Empty import scala.annotation.publicInBinary import ShapeTypeHelpers.AxisIndex import dimwit.tensor.{Labels, Label} -import scala.Tuple.Union /** Represents the Shape of a tensor. Conceptually, a shape is an order list of AxisExtents, * where each AxisExtent is a label associated with a size. @@ -53,7 +52,7 @@ final case class Shape[T <: Tuple: Labels] @publicInBinary private ( object Shape: - private[tensor] type ExtractLabels[Args <: Tuple] <: Tuple = Args match + private[tensor] type ExtractLabels[Extents <: Tuple] <: Tuple = Extents match case EmptyTuple => EmptyTuple case AxisExtent[l] *: tail => l *: ExtractLabels[tail] @@ -70,20 +69,20 @@ object Shape: def apply[L: Label](dim: AxisExtent[L]): Shape[L *: EmptyTuple] = Shape.fromTuple(Tuple1(dim)) - /** Create a shape from a tuple of AxisExtents. The labels of the shape are extracted from the AxisExtents in the tuple. + /** Create a shape from a tuple of AxisExtents. * - * @param args A tuple of AxisExtents from which to create the shape. - * @return A shape with dimensions corresponding to the sizes of the AxisExtents in the + * @param axisExtends A tuple of AxisExtents from which to create the shape. + * @return A shape with dimensions corresponding to the sizes of the AxisExtents in the tuple. */ - def apply[A <: Tuple](args: A)(using ev: Union[A] <:< AxisExtent[?], n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] = - Shape.fromTuple(args) + def apply[Extents <: Tuple](axisExtends: Extents)(using n: Labels[ExtractLabels[Extents]]): Shape[ExtractLabels[Extents]] = + fromTuple(axisExtends) - /** Create a shape from a tuple of AxisExtents. The labels of the shape are extracted from the AxisExtents in the tuple. - * @param args A tuple of AxisExtents from which to create the shape. + /** Create a shape from a tuple of AxisExtents. + * @param axisExtents A tuple of AxisExtents from which to create the shape. * @return A shape with dimensions corresponding to the sizes of the AxisExtents in the tuple. */ - def fromTuple[A <: Tuple](args: A)(using ev: Union[A] <:< AxisExtent[?], n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] = - val sizes = args.toList.collect: + def fromTuple[Extents <: Tuple](axisExtents: Extents)(using n: Labels[ExtractLabels[Extents]]): Shape[ExtractLabels[Extents]] = + val sizes = axisExtents.toList.collect: case ae: AxisExtent[?] => ae.size new Shape(sizes) @@ -126,7 +125,7 @@ object Shape2: def apply[L1: Label, L2: Label]( dim1: AxisExtent[L1], dim2: AxisExtent[L2] - ): Shape[(L1, L2)] = Shape.fromTuple(dim1, dim2) + ): Shape[(L1, L2)] = Shape.fromTuple((dim1, dim2)) object Shape3: @@ -141,4 +140,4 @@ object Shape3: dim1: AxisExtent[L1], dim2: AxisExtent[L2], dim3: AxisExtent[L3] - ): Shape[(L1, L2, L3)] = Shape.fromTuple(dim1, dim2, dim3) + ): Shape[(L1, L2, L3)] = Shape.fromTuple((dim1, dim2, dim3)) diff --git a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala index 71b8995..082d98d 100644 --- a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala @@ -2,6 +2,7 @@ package dimwit.autodiff import dimwit.* import dimwit.Conversions.given +import dimwit.autodiff.FloatTree.* import dimwit.autodiff.FloatTree.ops.* import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers diff --git a/examples/src/main/scala/basic/Autoencoder.scala b/examples/src/main/scala/basic/Autoencoder.scala index 6de7bb3..395b375 100644 --- a/examples/src/main/scala/basic/Autoencoder.scala +++ b/examples/src/main/scala/basic/Autoencoder.scala @@ -13,6 +13,7 @@ import dimwit.jax.Jax import nn.ActivationFunctions.sigmoid import dimwit.random.Random.Key import dimwit.autodiff.* +import dimwit.autodiff.FloatTree.* import examples.dataset.MNISTLoader import MNISTLoader.{Sample, TrainSample, TestSample, Height, Width} diff --git a/examples/src/main/scala/complex/VariationalAutoencoder.scala b/examples/src/main/scala/complex/VariationalAutoencoder.scala index 86920ef..2e7785a 100644 --- a/examples/src/main/scala/complex/VariationalAutoencoder.scala +++ b/examples/src/main/scala/complex/VariationalAutoencoder.scala @@ -4,6 +4,7 @@ import examples.timed import dimwit.* import dimwit.autodiff.* +import dimwit.autodiff.FloatTree.* import dimwit.Conversions.given import dimwit.stats.Normal import dimwit.random.Random diff --git a/nn/src/main/scala/nn/GradientOptimizer.scala b/nn/src/main/scala/nn/GradientOptimizer.scala index b31195f..5237e1b 100644 --- a/nn/src/main/scala/nn/GradientOptimizer.scala +++ b/nn/src/main/scala/nn/GradientOptimizer.scala @@ -3,6 +3,7 @@ package nn import dimwit.* import dimwit.Conversions.given import dimwit.autodiff.FloatTree.ops.* +import dimwit.autodiff.FloatTree.* import dimwit.autodiff.* import dimwit.jax.Jax import dimwit.jax.Jit