diff --git a/core/src/main/scala/dimwit/tensor/Axis.scala b/core/src/main/scala/dimwit/tensor/Axis.scala index 9414a59..be80572 100644 --- a/core/src/main/scala/dimwit/tensor/Axis.scala +++ b/core/src/main/scala/dimwit/tensor/Axis.scala @@ -4,34 +4,48 @@ import dimwit.|*| import scala.compiletime.{constValue, erasedValue, summonInline} -case class AxisExtent[T](axis: Axis[T], size: Int): - def *[T2](other: AxisExtent[T2]): AxisExtent[T |*| T2] = - AxisExtent(Axis[T |*| T2], size * other.size) - -// Axis selectors for indexing operations +/** Instances of this class represent an axis in a tensor with a specific label `L`. + * Axis objects are used whenever an axis needs to be selected at the value level, + * such as when indexing into a tensor or defining the shape of a tensor. + */ +final class Axis[L: Label]: + def extent(size: Int): AxisExtent[L] = AxisExtent(this, size) + def ->(size: Int): AxisExtent[L] = this.extent(size) + def at(index: Int): AxisAtIndex[L] = AxisAtIndex(this, index) + def at(range: Range): AxisAtRange[L] = AxisAtRange(this, range) + def at(indices: Seq[Int]): AxisAtIndices[L] = AxisAtIndices(this, indices) + def at(index: Tensor0[Int]): AxisAtTensorIndex[L] = AxisAtTensorIndex(this, index) + def at[I <: NonEmptyTuple](indices: I): AxisAtTupleIndices[L, I] = AxisAtTupleIndices(this, indices) + def as[U](newAxis: Axis[U]): (Axis[L], Axis[U]) = (this, newAxis) + +/** Represents the extent of an axis, which is a combination of an Axis and its size. */ +case class AxisExtent[L: Label](axis: Axis[L], size: Int): + + /** Combines this AxisExtent with another AxisExtent to create a new AxisExtent that represents the combined axes and their sizes. + * + * @param other The other AxisExtent to combine with this one. + * @return A new AxisExtent representing the combined axes and their sizes. 
+ */ + def *[L2: Label](other: AxisExtent[L2]): AxisExtent[L |*| L2] = + AxisExtent(Axis[L |*| L2], size * other.size) + +/** Trait hierarchy to represent different ways to select an axis in a tensor, such as by index, range, or specific indices. + */ sealed trait AxisSelector[L]: def axis: Axis[L] +/** Represents an axis selection by a single index. + */ case class AxisAtIndex[L](axis: Axis[L], index: Int) extends AxisSelector[L] + +/** Represents an axis selection by a range of indices. */ case class AxisAtRange[L](axis: Axis[L], range: Range) extends AxisSelector[L] + +/** Represents an axis selection by a sequence of specific indices. */ case class AxisAtIndices[L](axis: Axis[L], indices: Seq[Int]) extends AxisSelector[L] -case class AxisAtTupleIndices[L, I <: NonEmptyTuple](axis: Axis[L], indices: I) extends AxisSelector[L] +/** Represents an axis selection by an index stored in a tensor (`Tensor0[Int]`). This allows for dynamic indexing based on the contents of the tensor. */ case class AxisAtTensorIndex[L](axis: Axis[L], index: Tensor0[Int]) extends AxisSelector[L] -object Axis: - - def apply[A]: Axis[A] = new AxisImpl[A]() - -/** Represents an axis with A. This maps the type-level to a runtime representation. */ -sealed trait Axis[A]: - def extent(size: Int): AxisExtent[A] = AxisExtent(this, size) - def ->(size: Int): AxisExtent[A] = this.extent(size) - def at(index: Int): AxisAtIndex[A] = AxisAtIndex(this, index) - def at[I <: NonEmptyTuple](indices: I): AxisAtTupleIndices[A, I] = AxisAtTupleIndices(this, indices) - def at(range: Range): AxisAtRange[A] = AxisAtRange(this, range) - def at(indices: Seq[Int]): AxisAtIndices[A] = AxisAtIndices(this, indices) - def at(index: Tensor0[Int]): AxisAtTensorIndex[A] = AxisAtTensorIndex(this, index) - def as[U](newAxis: Axis[U]): (Axis[A], Axis[U]) = (this, newAxis) - -class AxisImpl[A] extends Axis[A] +/** Represents an axis selection by a tuple containing indices. 
*/ +case class AxisAtTupleIndices[L, I <: NonEmptyTuple](axis: Axis[L], indices: I) extends AxisSelector[L] diff --git a/core/src/main/scala/dimwit/tensor/Shape.scala b/core/src/main/scala/dimwit/tensor/Shape.scala index f806efe..dec02a9 100644 --- a/core/src/main/scala/dimwit/tensor/Shape.scala +++ b/core/src/main/scala/dimwit/tensor/Shape.scala @@ -5,17 +5,37 @@ import scala.annotation.publicInBinary import ShapeTypeHelpers.AxisIndex import dimwit.tensor.{Labels, Label} -/** Represents the (typed) Shape of a tensor with runtime labels +/** Represents the Shape of a tensor. Conceptually, a shape is an ordered list of AxisExtents, + * where each AxisExtent is a label associated with a size. */ final case class Shape[T <: Tuple: Labels] @publicInBinary private ( val dimensions: List[Int] ): + /** Returns the labels of the shape + */ lazy val labels: List[String] = summon[Labels[T]].names + /** Returns the rank of the shape, which is the number of dimensions. + */ def rank: Int = dimensions.size + + /** Returns the total number of elements in the tensor represented by this shape. + */ def size: Int = dimensions.foldLeft(1)((acc, d) => acc * d.asInstanceOf[Int]) - def extent[L](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = AxisExtent(axis, this(axis)) + + /** Returns the extent of the specified axis. + * + * @param axis The axis for which to retrieve the extent. + * @return The extent of the specified axis. + */ + def extent[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = AxisExtent(axis, this(axis)) + + /** Returns the size of the specified axis. + * + * @param axis The axis for which to retrieve the size. + * @return The size of the specified axis. 
+ */ def apply[L](axis: Axis[L])(using ev: AxisIndex[T, L]): Int = this.dimensions(ev.index) override def toString: String = @@ -32,44 +52,92 @@ final case class Shape[T <: Tuple: Labels] @publicInBinary private ( object Shape: - private[tensor] type ExtractLabels[Args <: Tuple] <: Tuple = Args match + private[tensor] type ExtractLabels[Extents <: Tuple] <: Tuple = Extents match case EmptyTuple => EmptyTuple case AxisExtent[l] *: tail => l *: ExtractLabels[tail] + /** Creates an empty shape with no dimensions. + * @return An empty shape with no dimensions. + */ def empty: Shape[EmptyTuple] = new Shape(Nil) + /** Creates a shape with a single dimension. + * + * @param dim The extent of the single dimension. + * @return A shape with one dimension. + */ def apply[L: Label](dim: AxisExtent[L]): Shape[L *: EmptyTuple] = Shape.fromTuple(Tuple1(dim)) - def apply[A <: Tuple](args: A)(using n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] = - Shape.fromTuple(args) - - def fromTuple[A <: Tuple](args: A)(using n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] = - val sizes = args.toList.collect: + /** Create a shape from a tuple of AxisExtents. + * + * @param axisExtends A tuple of AxisExtents from which to create the shape. + * @return A shape with dimensions corresponding to the sizes of the AxisExtents in the tuple. + */ + def apply[Extents <: Tuple](axisExtends: Extents)(using n: Labels[ExtractLabels[Extents]]): Shape[ExtractLabels[Extents]] = + fromTuple(axisExtends) + + /** Create a shape from a tuple of AxisExtents. + * @param axisExtents A tuple of AxisExtents from which to create the shape. + * @return A shape with dimensions corresponding to the sizes of the AxisExtents in the tuple. + */ + def fromTuple[Extents <: Tuple](axisExtents: Extents)(using n: Labels[ExtractLabels[Extents]]): Shape[ExtractLabels[Extents]] = + val sizes = axisExtents.toList.collect: case ae: AxisExtent[?] 
=> ae.size new Shape(sizes) private[tensor] def fromSeq[T <: Tuple: Labels](dims: Seq[Int]) = new Shape[T](dims.toList) +/** Type alias for an empty shape (rank 0). */ type Shape0 = Shape[EmptyTuple] + +/** Type alias for a shape with one dimension (rank 1). */ type Shape1[L] = Shape[L *: EmptyTuple] + +/** Type alias for a shape with two dimensions (rank 2). */ type Shape2[L1, L2] = Shape[L1 *: L2 *: EmptyTuple] + +/** Type alias for a shape with three dimensions (rank 3). */ type Shape3[L1, L2, L3] = Shape[L1 *: L2 *: L3 *: EmptyTuple] +/** Value counterpart of the Shape0 type alias: the empty shape (same as `Shape.empty`). */ val Shape0 = Shape.empty +/** Companion object for Shape1, providing a convenient way to create a shape with one dimension. */ object Shape1: + + /** Creates a shape with a single dimension. + * + * @param dim The extent of the single dimension. + * @return A shape with one dimension. + */ def apply[L: Label](dim: AxisExtent[L]): Shape[Tuple1[L]] = Shape(dim) +/** Companion object for Shape2, providing a convenient way to create a shape with two dimensions. */ object Shape2: + + /** Creates a shape with two dimensions. + * + * @param dim1 The extent of the first dimension. + * @param dim2 The extent of the second dimension. + * @return A shape with two dimensions. + */ def apply[L1: Label, L2: Label]( dim1: AxisExtent[L1], dim2: AxisExtent[L2] - ): Shape[(L1, L2)] = Shape.fromTuple(dim1, dim2) + ): Shape[(L1, L2)] = Shape.fromTuple((dim1, dim2)) object Shape3: + + /** Creates a shape with three dimensions. + * + * @param dim1 The extent of the first dimension. + * @param dim2 The extent of the second dimension. + * @param dim3 The extent of the third dimension. + * @return A shape with three dimensions. 
+ */ def apply[L1: Label, L2: Label, L3: Label]( dim1: AxisExtent[L1], dim2: AxisExtent[L2], dim3: AxisExtent[L3] - ): Shape[(L1, L2, L3)] = Shape.fromTuple(dim1, dim2, dim3) + ): Shape[(L1, L2, L3)] = Shape.fromTuple((dim1, dim2, dim3)) diff --git a/core/src/main/scala/dimwit/tensor/Tensor.scala b/core/src/main/scala/dimwit/tensor/Tensor.scala index bc5dc86..e81504e 100644 --- a/core/src/main/scala/dimwit/tensor/Tensor.scala +++ b/core/src/main/scala/dimwit/tensor/Tensor.scala @@ -48,7 +48,7 @@ class Tensor[T <: Tuple: Labels, V] private[dimwit] ( s"TracerTensor(${shape.toString})" case _ => jaxValue.toString() - def extent[L](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = + def extent[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = shape.extent(axis) private val jaxTypeName: String = py.Dynamic.global.`type`(jaxValue).`__name__`.as[String] diff --git a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala index 71b8995..082d98d 100644 --- a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala @@ -2,6 +2,7 @@ package dimwit.autodiff import dimwit.* import dimwit.Conversions.given +import dimwit.autodiff.FloatTree.* import dimwit.autodiff.FloatTree.ops.* import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers diff --git a/examples/src/main/scala/basic/Autoencoder.scala b/examples/src/main/scala/basic/Autoencoder.scala index 6de7bb3..395b375 100644 --- a/examples/src/main/scala/basic/Autoencoder.scala +++ b/examples/src/main/scala/basic/Autoencoder.scala @@ -13,6 +13,7 @@ import dimwit.jax.Jax import nn.ActivationFunctions.sigmoid import dimwit.random.Random.Key import dimwit.autodiff.* +import dimwit.autodiff.FloatTree.* import examples.dataset.MNISTLoader import MNISTLoader.{Sample, TrainSample, TestSample, Height, Width} diff --git 
a/examples/src/main/scala/complex/VariationalAutoencoder.scala b/examples/src/main/scala/complex/VariationalAutoencoder.scala index 86920ef..2e7785a 100644 --- a/examples/src/main/scala/complex/VariationalAutoencoder.scala +++ b/examples/src/main/scala/complex/VariationalAutoencoder.scala @@ -4,6 +4,7 @@ import examples.timed import dimwit.* import dimwit.autodiff.* +import dimwit.autodiff.FloatTree.* import dimwit.Conversions.given import dimwit.stats.Normal import dimwit.random.Random diff --git a/nn/src/main/scala/nn/GradientOptimizer.scala b/nn/src/main/scala/nn/GradientOptimizer.scala index b31195f..5237e1b 100644 --- a/nn/src/main/scala/nn/GradientOptimizer.scala +++ b/nn/src/main/scala/nn/GradientOptimizer.scala @@ -3,6 +3,7 @@ package nn import dimwit.* import dimwit.Conversions.given import dimwit.autodiff.FloatTree.ops.* +import dimwit.autodiff.FloatTree.* import dimwit.autodiff.* import dimwit.jax.Jax import dimwit.jax.Jit