Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions core/src/main/scala/dimwit/tensor/Axis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,48 @@ import dimwit.|*|

import scala.compiletime.{constValue, erasedValue, summonInline}

case class AxisExtent[T](axis: Axis[T], size: Int):
def *[T2](other: AxisExtent[T2]): AxisExtent[T |*| T2] =
AxisExtent(Axis[T |*| T2], size * other.size)

// Axis selectors for indexing operations
/** Instances of this class represent an axis in a tensor with a specific label `L`.
* Axis objects are used whenever an axis needs to be selected at the value level,
* such as when indexing into a tensor or defining the shape of a tensor.
*/
final class Axis[L: Label]:
def extent(size: Int): AxisExtent[L] = AxisExtent(this, size)
def ->(size: Int): AxisExtent[L] = this.extent(size)
def at(index: Int): AxisAtIndex[L] = AxisAtIndex(this, index)
def at(range: Range): AxisAtRange[L] = AxisAtRange(this, range)
def at(indices: Seq[Int]): AxisAtIndices[L] = AxisAtIndices(this, indices)
def at(index: Tensor0[Int]): AxisAtTensorIndex[L] = AxisAtTensorIndex(this, index)
def at[I <: NonEmptyTuple](indices: I): AxisAtTupleIndices[L, I] = AxisAtTupleIndices(this, indices)
def as[U](newAxis: Axis[U]): (Axis[L], Axis[U]) = (this, newAxis)

/** Represents the extent of an axis, which is a combination of an Axis and its size. */
case class AxisExtent[L: Label](axis: Axis[L], size: Int):

/** Combines this AxisExtent with another AxisExtent to create a new AxisExtent that represents the combined axes and their sizes.
*
* @param other The other AxisExtent to combine with this one.
* @return A new AxisExtent representing the combined axes and their sizes.
*/
def *[L2: Label](other: AxisExtent[L2]): AxisExtent[L |*| L2] =
AxisExtent(Axis[L |*| L2], size * other.size)

/** Trait hierarchy to represent different ways to select an axis in a tensor, such as by index, range, or specific indices.
*/
sealed trait AxisSelector[L]:
def axis: Axis[L]

/** Represent an axis selection by a single index.
*/
case class AxisAtIndex[L](axis: Axis[L], index: Int) extends AxisSelector[L]

/** Represent an axis selection by a range of indices. */
case class AxisAtRange[L](axis: Axis[L], range: Range) extends AxisSelector[L]

/** Represent an axis selection by a sequence of specific indices. */
case class AxisAtIndices[L](axis: Axis[L], indices: Seq[Int]) extends AxisSelector[L]
case class AxisAtTupleIndices[L, I <: NonEmptyTuple](axis: Axis[L], indices: I) extends AxisSelector[L]

/* Represent an axis selection by a tensor containing indices. This allows for dynamic indexing based on the contents of the tensor. */
case class AxisAtTensorIndex[L](axis: Axis[L], index: Tensor0[Int]) extends AxisSelector[L]

object Axis:

def apply[A]: Axis[A] = new AxisImpl[A]()

/** Represents an axis with A. This maps the type-level to a runtime representation. */
sealed trait Axis[A]:
def extent(size: Int): AxisExtent[A] = AxisExtent(this, size)
def ->(size: Int): AxisExtent[A] = this.extent(size)
def at(index: Int): AxisAtIndex[A] = AxisAtIndex(this, index)
def at[I <: NonEmptyTuple](indices: I): AxisAtTupleIndices[A, I] = AxisAtTupleIndices(this, indices)
def at(range: Range): AxisAtRange[A] = AxisAtRange(this, range)
def at(indices: Seq[Int]): AxisAtIndices[A] = AxisAtIndices(this, indices)
def at(index: Tensor0[Int]): AxisAtTensorIndex[A] = AxisAtTensorIndex(this, index)
def as[U](newAxis: Axis[U]): (Axis[A], Axis[U]) = (this, newAxis)

class AxisImpl[A] extends Axis[A]
/* Represent an axis selection by a tuple containing indices. */
case class AxisAtTupleIndices[L, I <: NonEmptyTuple](axis: Axis[L], indices: I) extends AxisSelector[L]
88 changes: 78 additions & 10 deletions core/src/main/scala/dimwit/tensor/Shape.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,37 @@ import scala.annotation.publicInBinary
import ShapeTypeHelpers.AxisIndex
import dimwit.tensor.{Labels, Label}

/** Represents the (typed) Shape of a tensor with runtime labels
/** Represents the Shape of a tensor. Conceptually, a shape is an order list of AxisExtents,
* where each AxisExtent is a label associated with a size.
*/
final case class Shape[T <: Tuple: Labels] @publicInBinary private (
val dimensions: List[Int]
):

/** Returns the labels of the shape
*/
lazy val labels: List[String] = summon[Labels[T]].names

/** Returns the rank of the shape, which is the number of dimensions.
*/
def rank: Int = dimensions.size

/** Returns the total number of elements in the tensor represented by this shape.
*/
def size: Int = dimensions.foldLeft(1)((acc, d) => acc * d.asInstanceOf[Int])
def extent[L](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = AxisExtent(axis, this(axis))

/** Returns the extent of the specified axis.
*
* @param axis The axis for which to retrieve the extent.
* @return The extent of the specified axis.
*/
def extent[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] = AxisExtent(axis, this(axis))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L Label evidence added. See other comment.


/** Returns the size of the specified axis.
*
* @param axis The axis for which to retrieve the size.
* @return The size of the specified axis.
*/
def apply[L](axis: Axis[L])(using ev: AxisIndex[T, L]): Int = this.dimensions(ev.index)

override def toString: String =
Expand All @@ -32,44 +52,92 @@ final case class Shape[T <: Tuple: Labels] @publicInBinary private (

object Shape:

private[tensor] type ExtractLabels[Args <: Tuple] <: Tuple = Args match
private[tensor] type ExtractLabels[Extents <: Tuple] <: Tuple = Extents match
case EmptyTuple => EmptyTuple
case AxisExtent[l] *: tail => l *: ExtractLabels[tail]

/** Creates an empty shape with no dimensions.
* @return An empty shape with no dimensions.
*/
def empty: Shape[EmptyTuple] = new Shape(Nil)

/** Creates a shape with a single dimension.
*
* @param dim The extent of the single dimension.
* @return A shape with one dimension.
*/
def apply[L: Label](dim: AxisExtent[L]): Shape[L *: EmptyTuple] =
Shape.fromTuple(Tuple1(dim))

def apply[A <: Tuple](args: A)(using n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] =
Shape.fromTuple(args)

def fromTuple[A <: Tuple](args: A)(using n: Labels[ExtractLabels[A]]): Shape[ExtractLabels[A]] =
val sizes = args.toList.collect:
/** Create a shape from a tuple of AxisExtents.
*
* @param axisExtends A tuple of AxisExtents from which to create the shape.
* @return A shape with dimensions corresponding to the sizes of the AxisExtents in the tuple.
*/
def apply[Extents <: Tuple](axisExtends: Extents)(using n: Labels[ExtractLabels[Extents]]): Shape[ExtractLabels[Extents]] =
fromTuple(axisExtends)

/** Create a shape from a tuple of AxisExtents.
* @param axisExtents A tuple of AxisExtents from which to create the shape.
* @return A shape with dimensions corresponding to the sizes of the AxisExtents in the tuple.
*/
def fromTuple[Extents <: Tuple](axisExtents: Extents)(using n: Labels[ExtractLabels[Extents]]): Shape[ExtractLabels[Extents]] =
val sizes = axisExtents.toList.collect:
case ae: AxisExtent[?] => ae.size
new Shape(sizes)

private[tensor] def fromSeq[T <: Tuple: Labels](dims: Seq[Int]) = new Shape[T](dims.toList)

/** Type alias for an empty shape (rank 0). */
type Shape0 = Shape[EmptyTuple]

/** Type alias for a shape with one dimension (rank 1). */
type Shape1[L] = Shape[L *: EmptyTuple]

/** Type alias for a shape with two dimensions (rank 2). */
type Shape2[L1, L2] = Shape[L1 *: L2 *: EmptyTuple]

/** Type alias for a shape with three dimensions (rank 3). */
type Shape3[L1, L2, L3] = Shape[L1 *: L2 *: L3 *: EmptyTuple]

/** Companion object for Shape0, providing a convenient way to create an empty shape. */
val Shape0 = Shape.empty

/** Companion object for Shape1, providing a convenient way to create a shape with one dimension. */
object Shape1:

/** Creates a shape with a single dimension.
*
* @param dim The extent of the single dimension.
* @return A shape with one dimension.
*/
def apply[L: Label](dim: AxisExtent[L]): Shape[Tuple1[L]] = Shape(dim)

/** Companion object for Shape2, providing a convenient way to create a shape with two dimensions. */
object Shape2:

/** Creates a shape with two dimensions.
*
* @param dim1 The extent of the first dimension.
* @param dim2 The extent of the second dimension.
* @return A shape with two dimensions.
*/
def apply[L1: Label, L2: Label](
dim1: AxisExtent[L1],
dim2: AxisExtent[L2]
): Shape[(L1, L2)] = Shape.fromTuple(dim1, dim2)
): Shape[(L1, L2)] = Shape.fromTuple((dim1, dim2))

object Shape3:

/** Creates a shape with three dimensions.
*
* @param dim1 The extent of the first dimension.
* @param dim2 The extent of the second dimension.
* @param dim3 The extent of the third dimension.
* @return A shape with three dimensions.
*/
def apply[L1: Label, L2: Label, L3: Label](
dim1: AxisExtent[L1],
dim2: AxisExtent[L2],
dim3: AxisExtent[L3]
): Shape[(L1, L2, L3)] = Shape.fromTuple(dim1, dim2, dim3)
): Shape[(L1, L2, L3)] = Shape.fromTuple((dim1, dim2, dim3))
2 changes: 1 addition & 1 deletion core/src/main/scala/dimwit/tensor/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Tensor[T <: Tuple: Labels, V] private[dimwit] (
s"TracerTensor(${shape.toString})"
case _ => jaxValue.toString()

def extent[L](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] =
def extent[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): AxisExtent[L] =
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was label evidence added on purpose. I suggest removing it if not necessary in the method.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here I thought that adding a label for documentation could help. It is not really needed but in most other method when we have this type parameter it also needs to be a label. so why not adding it here? Again, I have no strong opinion.

shape.extent(axis)

private val jaxTypeName: String = py.Dynamic.global.`type`(jaxValue).`__name__`.as[String]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dimwit.autodiff

import dimwit.*
import dimwit.Conversions.given
import dimwit.autodiff.FloatTree.*
import dimwit.autodiff.FloatTree.ops.*
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers
Expand Down
1 change: 1 addition & 0 deletions examples/src/main/scala/basic/Autoencoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import dimwit.jax.Jax
import nn.ActivationFunctions.sigmoid
import dimwit.random.Random.Key
import dimwit.autodiff.*
import dimwit.autodiff.FloatTree.*
import examples.dataset.MNISTLoader

import MNISTLoader.{Sample, TrainSample, TestSample, Height, Width}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import examples.timed

import dimwit.*
import dimwit.autodiff.*
import dimwit.autodiff.FloatTree.*
import dimwit.Conversions.given
import dimwit.stats.Normal
import dimwit.random.Random
Expand Down
1 change: 1 addition & 0 deletions nn/src/main/scala/nn/GradientOptimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package nn
import dimwit.*
import dimwit.Conversions.given
import dimwit.autodiff.FloatTree.ops.*
import dimwit.autodiff.FloatTree.*
import dimwit.autodiff.*
import dimwit.jax.Jax
import dimwit.jax.Jit
Expand Down