From a7cc080de47f2da1f75bf31a2d3c5c0d5f8a4877 Mon Sep 17 00:00:00 2001 From: Benjamin Meyer Date: Sun, 10 May 2026 12:59:50 +0200 Subject: [PATCH] Add new Tensor API with precision types: Float16, Float32, Float64, ... instead of "Float" --- AGENTS.md | 406 +++++++++++------- README.md | 4 +- .../main/scala/dimwit/autodiff/Autodiff.scala | 13 +- .../scala/dimwit/autodiff/FloatTree.scala | 112 +++-- .../src/main/scala/dimwit/autodiff/Grad.scala | 12 +- core/src/main/scala/dimwit/jax/JaxDType.scala | 4 + core/src/main/scala/dimwit/package.scala | 7 +- .../src/main/scala/dimwit/random/Random.scala | 3 +- .../scala/dimwit/stats/Distributions.scala | 24 +- .../stats/IndependentDistributions.scala | 86 ++-- .../stats/MultivariateDistributions.scala | 33 +- .../stats/UnivariateDistributions.scala | 9 +- .../scala/dimwit/tensor/ArrayWriter.scala | 109 ++--- core/src/main/scala/dimwit/tensor/Axis.scala | 5 +- core/src/main/scala/dimwit/tensor/DType.scala | 87 ++++ .../src/main/scala/dimwit/tensor/Tensor.scala | 251 +++++++++-- .../main/scala/dimwit/tensor/TensorOps.scala | 206 ++++++--- core/src/main/scala/dimwit/tensor/VType.scala | 16 + core/src/main/scala/dimwit/tensor/Value.scala | 45 -- .../scala/dimwit/autodiff/AutodiffSuite.scala | 25 +- .../autodiff/FloatTensorTreeSuite.scala | 73 +++- .../scala/dimwit/autodiff/PyTreeSuite.scala | 13 +- .../dimwit/autodiff/TensorTreeSuite.scala | 12 +- core/src/test/scala/dimwit/jax/JitSuite.scala | 28 +- .../dimwit/memory/DimWitMemorySuite.scala | 10 +- core/src/test/scala/dimwit/package.scala | 20 +- .../scala/dimwit/python/PyWrapSuite.scala | 32 +- .../scala/dimwit/random/RandomSuite.scala | 1 - .../dimwit/stats/DistributionSuite.scala | 36 +- .../dimwit/tensor/TensorCovarianceSuite.scala | 11 +- .../dimwit/tensor/TensorCreationSuite.scala | 15 +- .../dimwit/tensor/TensorOpsBinarySuite.scala | 10 +- .../tensor/TensorOpsBroadcastSuite.scala | 1 + .../tensor/TensorOpsContractionSuite.scala | 2 +- .../tensor/TensorOpsConvolutionSuite.scala | 1 - .../tensor/TensorOpsElementwiseSuite.scala | 19 +- .../tensor/TensorOpsFunctionalSuite.scala | 9 +- .../tensor/TensorOpsReductionSuite.scala | 5 +- .../tensor/TensorOpsStructureSuite.scala | 3 +- .../tensor/TensorWithValueClassSuite.scala | 12 +- docs/quickstart.md | 68 +-- .../src/main/scala/basic/Autoencoder.scala | 14 +- .../main/scala/basic/LogisticRegression.scala | 32 +- .../main/scala/basic/MLClassifierMNist.scala | 51 ++- .../scala/basic/MLClassifierMNistCNN.scala | 32 +- .../src/main/scala/basic/Playground.scala | 2 +- examples/src/main/scala/complex/GPT2.scala | 90 ++-- .../src/main/scala/complex/GPT2Train.scala | 99 ++--- .../complex/VariationalAutoencoder.scala | 20 +- .../src/main/scala/dataset/MNISTLoader.scala | 18 +- mdocs/AGENTS.md | 107 ++--- mdocs/README.md | 4 +- mdocs/docs/quickstart.md | 54 +-- nn/src/main/scala/nn/Conv2DLayer.scala | 16 +- nn/src/main/scala/nn/GradientOptimizer.scala | 46 +- nn/src/main/scala/nn/LinearLayer.scala | 19 +- nn/src/main/scala/nn/Loss.scala | 8 +- .../main/scala/nn/TransposeConv2DLayer.scala | 6 +- 58 files changed, 1432 insertions(+), 1024 deletions(-) create mode 100644 core/src/main/scala/dimwit/tensor/VType.scala delete mode 100644 core/src/main/scala/dimwit/tensor/Value.scala diff --git a/AGENTS.md b/AGENTS.md index 412e969d..40c66cc5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -169,18 +169,18 @@ val t2dNested = Tensor2(Axis[A], Axis[B]).fromArray( ```scala // Scalar (0D) -val scalar: Tensor0[Float] = Tensor0(42.0f) +val scalar: Tensor0[Float32] = Tensor0(42.0f) // Vector (1D) -val vector: Tensor1[Feature, Float] = Tensor1(Axis[Feature]).fromArray(Array(1.0f, 2.0f, 3.0f)) +val vector: Tensor1[Feature, Float32] = Tensor1(Axis[Feature]).fromArray(Array(1.0f, 2.0f, 3.0f)) // Matrix (2D) -val matrix: Tensor2[Batch, Feature, Float] = Tensor2(Axis[Batch], Axis[Feature]).fromArray( +val matrix: Tensor2[Batch, Feature, Float32] = Tensor2(Axis[Batch], Axis[Feature]).fromArray( Array(Array(1.0f, 2.0f), Array(3.0f, 4.0f)) ) // 3D Tensor -val tensor3d: Tensor3[Batch, Feature, Hidden, Float] = +val tensor3d: Tensor3[Batch, Feature, Hidden, Float32] = Tensor(Shape3(Axis[Batch] -> 2, Axis[Feature] -> 3, Axis[Hidden] -> 4)).fill(0.0f) ``` @@ -279,20 +279,20 @@ val data = Tensor2(Axis[A], Axis[B]).fromArray( ) // Reduce to scalar -val totalSum: Tensor0[Float] = data.sum -val totalMean: Tensor0[Float] = data.mean -val totalMax: Tensor0[Float] = data.max -val totalMin: Tensor0[Float] = data.min -val totalStd: Tensor0[Float] = data.std +val totalSum: Tensor0[Float32] = data.sum +val totalMean: Tensor0[Float32] = data.mean +val totalMax: Tensor0[Float32] = data.max +val totalMin: Tensor0[Float32] = data.min +val totalStd: Tensor0[Float32] = data.std println(s"Sum: ${totalSum.item}") // 21.0 // Reduce along axis A (across rows) -val sumA: Tensor1[B, Float] = data.sum(Axis[A]) +val sumA: Tensor1[B, Float32] = data.sum(Axis[A]) println(s"Sum along A: ${sumA}") // [5.0, 7.0, 9.0] // Reduce along axis B (across columns) -val sumB: Tensor1[A, Float] = data.sum(Axis[B]) +val sumB: Tensor1[A, Float32] = data.sum(Axis[B]) println(s"Sum along B: ${sumB}") // [6.0, 15.0] // Mean along axes @@ -300,8 +300,8 @@ val meanA = data.mean(Axis[A]) // [2.5, 3.5, 4.5] val meanB = data.mean(Axis[B]) // [2.0, 5.0] // Argmax / Argmin (returns indices) -val argmaxB: Tensor1[A, Int] = data.argmax(Axis[B]) -val argminB: Tensor1[A, Int] = data.argmin(Axis[B]) +val argmaxB: Tensor1[A, Int32] = data.argmax(Axis[B]) +val argminB: Tensor1[A, Int32] = data.argmin(Axis[B]) ``` **Error: Reducing on non-existent axis** @@ -335,8 +335,12 @@ val wrong = t.sum(Axis[C]) // ^ // error: // Conflicting definitions: -// val t: dimwit.tensor.Tensor[(MdocApp0.this.A, MdocApp0.this.B), Float] in class MdocApp0 at line 53 and -// val t: dimwit.tensor.Tensor[(MdocApp0.this.A, MdocApp0.this.B), Float] in class MdocApp0 at line 88 +// val t: +// dimwit.tensor.Tensor2[MdocApp0.this.A, MdocApp0.this.B, +// dimwit.tensor.DType.Float32] in class MdocApp0 at line 53 and +// val t: +// dimwit.tensor.Tensor2[MdocApp0.this.A, MdocApp0.this.B, +// dimwit.tensor.DType.Float32] in class MdocApp0 at line 88 // ``` @@ -372,11 +376,16 @@ val t = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f))) val wrong = t + 5.0f // Use +! instead // error: // Found: (5.0f : Float) -// Required: dimwit.tensor.Tensor[(MdocApp0.this.A, MdocApp0.this.B), Float] +// Required: dimwit.tensor.Tensor[(MdocApp0.this.A, MdocApp0.this.B), +// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)] // error: // Conflicting definitions: -// val t: dimwit.tensor.Tensor[(MdocApp0.this.A, MdocApp0.this.B), Float] in class MdocApp0 at line 53 and -// val t: dimwit.tensor.Tensor[(MdocApp0.this.A, MdocApp0.this.B), Float] in class MdocApp0 at line 97 +// val t: +// dimwit.tensor.Tensor2[MdocApp0.this.A, MdocApp0.this.B, +// dimwit.tensor.DType.Float32] in class MdocApp0 at line 53 and +// val t: +// dimwit.tensor.Tensor2[MdocApp0.this.A, MdocApp0.this.B, +// dimwit.tensor.DType.Float32] in class MdocApp0 at line 97 // ``` @@ -395,7 +404,7 @@ trait D derives Label // Dot product (vector · vector) val v1 = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f, 3.0f)) val v2 = Tensor1(Axis[A]).fromArray(Array(4.0f, 5.0f, 6.0f)) -val dotProduct: Tensor0[Float] = v1.dot(Axis[A])(v2) +val dotProduct: Tensor0[Float32] = v1.dot(Axis[A])(v2) println(s"Dot product: ${dotProduct.item}") // 32.0 // Matrix-vector multiplication @@ -464,13 +473,21 @@ val wrong = m1.dot(Axis[B])(m2) // ^ // error: // Conflicting definitions: -// val m1: dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B), Float] in class MdocApp1 at line 119 and -// val m1: dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B), Float] in class MdocApp1 at line 122 +// val m1: +// dimwit.tensor.Tensor2[MdocApp1.this.A, MdocApp1.this.B, +// dimwit.tensor.DType.Float32] in class MdocApp1 at line 119 and +// val m1: +// dimwit.tensor.Tensor2[MdocApp1.this.A, MdocApp1.this.B, +// dimwit.tensor.DType.Float32] in class MdocApp1 at line 122 // // error: // Conflicting definitions: -// val m2: dimwit.tensor.Tensor[(MdocApp1.this.B, MdocApp1.this.C), Float] in class MdocApp1 at line 120 and -// val m2: dimwit.tensor.Tensor[(MdocApp1.this.C, MdocApp1.this.D), Float] in class MdocApp1 at line 123 +// val m2: +// dimwit.tensor.Tensor2[MdocApp1.this.B, MdocApp1.this.C, +// dimwit.tensor.DType.Float32] in class MdocApp1 at line 120 and +// val m2: +// dimwit.tensor.Tensor2[MdocApp1.this.C, MdocApp1.this.D, +// dimwit.tensor.DType.Float32] in class MdocApp1 at line 123 // ``` @@ -484,7 +501,7 @@ val original = Tensor2(Axis[A], Axis[B]).fromArray( ) // Transpose -val transposed: Tensor2[B, A, Float] = original.transpose +val transposed: Tensor2[B, A, Float32] = original.transpose println(s"Original shape: ${original.shape}") println(s"Transposed shape: ${transposed.shape}") @@ -532,20 +549,20 @@ val data = Tensor2(Axis[Batch], Axis[Feature]).fromArray( ) // Normalize each sample (row) independently -def normalize(x: Tensor1[Feature, Float]): Tensor1[Feature, Float] = +def normalize(x: Tensor1[Feature, Float32]): Tensor1[Feature, Float32] = val mean = x.mean val std = x.std + Tensor0(1e-6f) // Avoid division by zero (x -! mean) /! std -val normalized: Tensor2[Batch, Feature, Float] = data.vmap(Axis[Batch])(normalize) +val normalized: Tensor2[Batch, Feature, Float32] = data.vmap(Axis[Batch])(normalize) println(s"Normalized data: ${normalized}") // Sum each row -val rowSums: Tensor1[Batch, Float] = data.vmap(Axis[Batch])(_.sum) +val rowSums: Tensor1[Batch, Float32] = data.vmap(Axis[Batch])(_.sum) println(s"Row sums: ${rowSums}") // [6.0, 15.0, 24.0] // vmap over columns (note: axis moves to front) -val colSums: Tensor1[Feature, Float] = data.vmap(Axis[Feature])(_.sum) +val colSums: Tensor1[Feature, Float32] = data.vmap(Axis[Feature])(_.sum) println(s"Column sums: ${colSums}") // [12.0, 15.0, 18.0] // Identity vmap doesn't change data, only axis order @@ -558,7 +575,7 @@ val reordered = data.vmap(Axis[Feature])(x => x) // Same as data.transpose ```scala val t = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f))) // ERROR: Function expects Tensor2, but vmap provides Tensor1 -def wrongFunc(x: Tensor2[A, B, Float]): Tensor0[Float] = x.sum +def wrongFunc(x: Tensor2[A, B, Float32]): Tensor0[Float32] = x.sum val wrong = t.vmap(Axis[A])(wrongFunc) // error: // Not found: type A @@ -576,11 +593,11 @@ val wrong = t.vmap(Axis[A])(wrongFunc) // // error: // Not found: type A -// def wrongFunc(x: Tensor2[A, B, Float]): Tensor0[Float] = x.sum +// def wrongFunc(x: Tensor2[A, B, Float32]): Tensor0[Float32] = x.sum // ^ // error: // Not found: type B -// def wrongFunc(x: Tensor2[A, B, Float]): Tensor0[Float] = x.sum +// def wrongFunc(x: Tensor2[A, B, Float32]): Tensor0[Float32] = x.sum // ^ ``` @@ -596,10 +613,10 @@ val t1 = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f), Array(3.0f val t2 = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(10.0f, 20.0f), Array(30.0f, 40.0f))) // Compute L2 distance between corresponding rows -def l2Distance(v1: Tensor1[B, Float], v2: Tensor1[B, Float]): Tensor0[Float] = +def l2Distance(v1: Tensor1[B, Float32], v2: Tensor1[B, Float32]): Tensor0[Float32] = (v1 - v2).pow(Tensor0(2.0f)).sum.sqrt -val distances: Tensor1[A, Float] = zipvmap(Axis[A])(t1, t2)(l2Distance) +val distances: Tensor1[A, Float32] = zipvmap(Axis[A])(t1, t2)(l2Distance) println(s"L2 distances: ${distances}") // zipvmap with 4 tensors @@ -663,7 +680,7 @@ trait A derives Label import dimwit.autodiff.Autodiff // Scalar function: f(x) = x² -def f(x: Tensor0[Float]): Tensor0[Float] = x * x +def f(x: Tensor0[Float32]): Tensor0[Float32] = x * x val df = Autodiff.grad(f) val x = Tensor0(3.0f) @@ -671,7 +688,7 @@ val gradient = df(x) println(s"df/dx at x=3: ${gradient.value.item}") // 6.0 // Vector function: f(x) = sum(x²) -def g(x: Tensor1[A, Float]): Tensor0[Float] = (x * x).sum +def g(x: Tensor1[A, Float32]): Tensor0[Float32] = (x * x).sum val dg = Autodiff.grad(g) val xVec = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f, 3.0f)) @@ -682,16 +699,16 @@ println(s"dg/dx: ${vecGradient}") // [2.0, 4.0, 6.0] ### Higher-Order Derivatives ```scala -def f2(x: Tensor0[Float]): Tensor0[Float] = x * x +def f2(x: Tensor0[Float32]): Tensor0[Float32] = x * x // First derivative val df2 = Autodiff.grad(f2) // Second derivative -val ddf2 = Autodiff.grad((x: Tensor0[Float]) => df2(x).value) +val ddf2 = Autodiff.grad((x: Tensor0[Float32]) => df2(x).value) // Third derivative -val dddf2 = Autodiff.grad((x: Tensor0[Float]) => ddf2(x).value) +val dddf2 = Autodiff.grad((x: Tensor0[Float32]) => ddf2(x).value) val x2 = Tensor0(3.0f) println(s"f''(3) = ${ddf2(x2).value.item}") // 2.0 @@ -702,7 +719,7 @@ println(s"f'''(3) = ${dddf2(x2).value.item}") // 0.0 ```scala // f(x, y) = (x + 2y)² -def twoParam(x: Tensor1[A, Float], y: Tensor1[A, Float]): Tensor0[Float] = +def twoParam(x: Tensor1[A, Float32], y: Tensor1[A, Float32]): Tensor0[Float32] = ((x + (y *! Tensor0(2.0f))).pow(Tensor0(2.0f))).sum val dtwoParam = Autodiff.grad(twoParam) @@ -725,7 +742,7 @@ trait Batch derives Label trait Feature derives Label // Gradient works seamlessly with vmap -def batched(x: Tensor2[Batch, Feature, Float]): Tensor0[Float] = +def batched(x: Tensor2[Batch, Feature, Float32]): Tensor0[Float32] = x.vmap(Axis[Batch])(_.sum).sum val dBatched = Autodiff.grad(batched) @@ -747,17 +764,17 @@ trait Feature derives Label trait Hidden derives Label case class LinearParams( - weight: Tensor2[Feature, Hidden, Float], - bias: Tensor1[Hidden, Float] + weight: Tensor2[Feature, Hidden, Float32], + bias: Tensor1[Hidden, Float32] ) // Define a model -def linear(params: LinearParams)(input: Tensor1[Feature, Float]): Tensor1[Hidden, Float] = +def linear(params: LinearParams)(input: Tensor1[Feature, Float32]): Tensor1[Hidden, Float32] = val weighted = params.weight.transpose.dot(Axis[Feature])(input) weighted + params.bias // Define loss function -def loss(data: Tensor1[Feature, Float], target: Tensor1[Hidden, Float])(params: LinearParams): Tensor0[Float] = +def loss(data: Tensor1[Feature, Float32], target: Tensor1[Hidden, Float32])(params: LinearParams): Tensor0[Float32] = val prediction = linear(params)(data) (prediction - target).pow(Tensor0(2.0f)).sum @@ -779,15 +796,14 @@ println(s"Bias gradient shape: ${paramGradients.bias.shape}") ```scala // ERROR: Cannot differentiate with respect to Int tensors -def intFunc(x: Tensor1[A, Int]): Tensor0[Int] = x.sum +def intFunc(x: Tensor1[A, Int32]): Tensor0[Int32] = x.sum val wrong = Autodiff.grad(intFunc) // error: // Not found: type A -// def intFunc(x: Tensor1[A, Int]): Tensor0[Int] = x.sum +// def intFunc(x: Tensor1[A, Int32]): Tensor0[Int32] = x.sum // ^ // error: -// No given instance of type dimwit.autodiff.TensorTree[dimwit.tensor.Tensor1[, Int] -// ] was found for parameter inTree of method grad in object Autodiff +// Operation only valid for Floating tensors. // val wrong = Autodiff.grad(intFunc) // ^ ``` @@ -801,7 +817,7 @@ import dimwit.autodiff.* trait A derives Label // Jacobian of f: R² -> R², f(x) = 2x -def linearMap(x: Tensor1[A, Float]): Tensor1[A, Float] = x *! Tensor0(2.0f) +def linearMap(x: Tensor1[A, Float32]): Tensor1[A, Float32] = x *! Tensor0(2.0f) val jacobian = Autodiff.jacobian(linearMap) val xJac = Tensor1(Axis[A]).fromArray(Array(1.0f, 1.0f)) @@ -828,11 +844,11 @@ trait Feature derives Label trait Batch derives Label // Define model parameters -case class SimpleModelParams(w: Tensor1[Feature, Float], b: Tensor0[Float]) +case class SimpleModelParams(w: Tensor1[Feature, Float32], b: Tensor0[Float32]) // Define loss function -def mse(data: Tensor2[Batch, Feature, Float], labels: Tensor1[Batch, Float]) - (params: SimpleModelParams): Tensor0[Float] = +def mse(data: Tensor2[Batch, Feature, Float32], labels: Tensor1[Batch, Float32]) + (params: SimpleModelParams): Tensor0[Float32] = val predictions = data.vmap(Axis[Batch]) { sample => sample.dot(Axis[Feature])(params.w) + params.b } @@ -897,11 +913,11 @@ val yData = Tensor1(Axis[Sample]).fromArray( ) // Model parameters -case class RegressionParams(slope: Tensor1[InputDim, Float], intercept: Tensor0[Float]) +case class RegressionParams(slope: Tensor1[InputDim, Float32], intercept: Tensor0[Float32]) // Loss function (MSE) -def regressionLoss(x: Tensor2[Sample, InputDim, Float], y: Tensor1[Sample, Float]) - (params: RegressionParams): Tensor0[Float] = +def regressionLoss(x: Tensor2[Sample, InputDim, Float32], y: Tensor1[Sample, Float32]) + (params: RegressionParams): Tensor0[Float32] = val predictions = x.vmap(Axis[Sample]) { xi => xi.dot(Axis[InputDim])(params.slope) + params.intercept } @@ -941,7 +957,7 @@ import dimwit.jax.Jit trait A derives Label // Define a complex function -def complexComputation(x: Tensor1[A, Float]): Tensor1[A, Float] = +def complexComputation(x: Tensor1[A, Float32]): Tensor1[A, Float32] = val y = x.exp val z = y.log val w = z.sin @@ -968,7 +984,7 @@ import dimwit.jitDonating import dimwit.jitDonatingUnsafe // jitDonating allows reusing input memory -def inPlaceOp(x: Tensor1[A, Float]): Tensor1[A, Float] = x *! Tensor0(2.0f) +def inPlaceOp(x: Tensor1[A, Float32]): Tensor1[A, Float32] = x *! Tensor0(2.0f) val (jitDonate, jitF, jitReclaim) = jitDonating(inPlaceOp) val inputDonate = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) val donated = jitDonate(inputDonate) @@ -1036,17 +1052,26 @@ trait A derives Label trait B derives Label trait C derives Label trait D derives Label +``` + +```scala // ERROR: Cannot perform floating-point operations on Int tensors val intTensor = Tensor1(Axis[A]).fromArray(Array(1, 2, 3)) val wrong = intTensor.exp // exp requires IsFloating constraint // error: -// value exp is not a member of dimwit.tensor.Tensor[Tuple1[MdocApp11.this.A], Int]. +// value exp is not a member of dimwit.tensor.Tensor1[MdocApp11.this.A, dimwit.tensor.DType.Int32]. // An extension method was tried, but could not be fully constructed: // -// dimwit.exp[Tuple1[MdocApp11.this.A], Int](this.intTensor)( +// dimwit.exp[Tuple1[MdocApp11.this.A], +// (dimwit.tensor.DType.Int32 : dimwit.tensor.DType)](this.intTensor)( // dimwit.tensor.Labels.concat[MdocApp11.this.A, EmptyTuple.type]( // this.A.derived$Label, dimwit.tensor.Labels.namesOfEmpty), -// /* missing */summon[dimwit.tensor.TensorOps.IsFloating[Int]]) +// /* missing */ +// summon[ +// dimwit.tensor.TensorOps.IsFloating[ +// (dimwit.tensor.DType.Int32 : dimwit.tensor.DType)] +// ] +// ) // // failed with: // @@ -1057,12 +1082,26 @@ val wrong = intTensor.exp // exp requires IsFloating constraint // ERROR: Cannot compute mean of Boolean tensor val boolTensor = Tensor1(Axis[A]).fromArray(Array(true, false, true)) val wrong = boolTensor.mean -// error: -// Not found: Tensor1 -// error: -// Not found: type A -// error: -// Not found: Axis +// error: +// value mean is not a member of dimwit.tensor.Tensor1[MdocApp11.this.A, dimwit.tensor.DType.Bool]. +// An extension method was tried, but could not be fully constructed: +// +// dimwit.mean[Tuple1[MdocApp11.this.A], +// (dimwit.tensor.DType.Bool : dimwit.tensor.DType)](this.boolTensor)( +// dimwit.tensor.Labels.concat[MdocApp11.this.A, EmptyTuple.type]( +// this.A.derived$Label, dimwit.tensor.Labels.namesOfEmpty), +// /* missing */ +// summon[ +// dimwit.tensor.TensorOps.IsFloating[ +// (dimwit.tensor.DType.Bool : dimwit.tensor.DType)] +// ] +// ) +// +// failed with: +// +// Operation only valid for Floating tensors. +// val wrong = boolTensor.mean +// ^^^^^^^^^^^^^^^ ``` ### Shape Mismatches @@ -1073,19 +1112,17 @@ val t1 = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) val t2 = Tensor1(Axis[B]).fromArray(Array(3.0f, 4.0f, 5.0f)) val wrong = t1 + t2 // Different labels AND different sizes // error: -// Not found: Tensor1 -// error: -// Not found: type A -// error: -// Not found: Axis -// error: -// Not found: Tensor1 -// val wrong = t1 + t2 // Different labels AND different sizes -// ^ -// error: -// Not found: type B -// error: -// Not found: Axis +// +// A tuple of axis labels Tuple1[MdocApp11.this.A | MdocApp11.this.B] was given or inferred that does not have a valid Labels instance. +// +// Ensure that all of the types in the tuple have a 'derives Label' clause. +// . +// I found: +// +// dimwit.tensor.Labels.concat[head, tail]( +// /* missing */summon[dimwit.tensor.Label[head]], ???) +// +// But no implicit values were found that match type dimwit.tensor.Label[head]. ``` ```scala @@ -1094,25 +1131,24 @@ val m1 = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f))) // Shape val m2 = Tensor2(Axis[C], Axis[D]).fromArray(Array(Array(3.0f), Array(4.0f))) // Shape: (2, 1) val wrong = m1.dot(Axis[B])(m2) // Axis[B] not in m2 // error: -// Not found: Tensor2 -// error: -// Not found: type A -// error: -// Not found: Axis -// error: -// Not found: type B -// error: -// Not found: Axis -// error: -// Not found: Tensor2 -// error: -// Not found: type C -// error: -// Not found: Axis -// error: -// Not found: type D -// error: -// Not found: Axis +// Axis[MdocApp11.this.B] not found in Tensor[(MdocApp11.this.C, MdocApp11.this.D)]. +// I found: +// +// dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[ +// (MdocApp11.this.C, MdocApp11.this.D), MdocApp11.this.B, Tuple]( +// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.C, +// MdocApp11.this.D *: EmptyTuple.type, MdocApp11.this.B]( +// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.D, +// EmptyTuple.type, MdocApp11.this.B]( +// dimwit.tensor.ShapeTypeHelpers.AxisIndex.concatRight[A, B², L]) +// ), +// ???) +// +// But given instance concatRight in object AxisIndex does not match type dimwit.tensor.ShapeTypeHelpers.AxisIndex[EmptyTuple.type, MdocApp11.this.B] +// +// where: B is a trait in class MdocApp11 +// B² is a type variable with constraint <: Tuple +// . ``` ### Broadcast vs Non-Broadcast Confusion @@ -1122,15 +1158,9 @@ val wrong = m1.dot(Axis[B])(m2) // Axis[B] not in m2 val t = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f))) val wrong = t + 10.0f // Should use +! for scalar broadcast // error: -// Not found: Tensor2 -// error: -// Not found: type A -// error: -// Not found: Axis -// error: -// Not found: type B -// error: -// Not found: Axis +// Found: (10.0f : Float) +// Required: dimwit.tensor.Tensor[(MdocApp11.this.A, MdocApp11.this.B), +// (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)] ``` ```scala @@ -1138,19 +1168,36 @@ val wrong = t + 10.0f // Should use +! for scalar broadcast val t1 = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) val t2 = Tensor1(Axis[A]).fromArray(Array(3.0f, 4.0f)) // This works but is semantically wrong (use + instead) -// val result = t1 +! t2 // Compiles but misleading -// error: -// Not found: Tensor1 -// error: -// Not found: type A -// error: -// Not found: Axis -// error: -// Not found: Tensor1 -// error: -// Not found: type A -// error: -// Not found: Axis +val wrong = t1 +! t2 +// error: +// Cannot broadcast tensors of shapes Tuple1[MdocApp11.this.A] and Tuple1[MdocApp11.this.A]. If same shape no broadcasting allowed!. +// I found: +// +// dimwit.tensor.TensorOpsUtil.Broadcast.broadcastLeft[Tuple1[MdocApp11.this.A], +// Tuple1[MdocApp11.this.A], (dimwit.tensor.DType.Float32 : dimwit.tensor.DType)] +// ( +// dimwit.tensor.Labels.concat[MdocApp11.this.A, EmptyTuple.type]( +// this.A.derived$Label, dimwit.tensor.Labels.namesOfEmpty), +// dimwit.tensor.Labels.concat[MdocApp11.this.A, EmptyTuple.type]( +// this.A.derived$Label, dimwit.tensor.Labels.namesOfEmpty), +// dimwit.tensor.TupleHelpers.StrictSubset.derive[Tuple1[MdocApp11.this.A], +// Tuple1[MdocApp11.this.A]]( +// dimwit.tensor.TupleHelpers.Subset.head[MdocApp11.this.A, EmptyTuple.type, +// Tuple1[MdocApp11.this.A]]( +// dimwit.tensor.TupleHelpers.SetMember.found[MdocApp11.this.A, +// EmptyTuple.type], +// dimwit.tensor.TupleHelpers.Subset.empty[Tuple1[MdocApp11.this.A]]), +// /* missing */ +// summon[ +// scala.util.NotGiven[Tuple1[MdocApp11.this.A] =:= +// Tuple1[MdocApp11.this.A]] +// ] +// ) +// ) +// +// But no implicit values were found that match type scala.util.NotGiven[Tuple1[MdocApp11.this.A] =:= Tuple1[MdocApp11.this.A]]. +// val wrong = t1 +! t2 +// ^^ ``` ### Axis Errors @@ -1160,15 +1207,26 @@ val t2 = Tensor1(Axis[A]).fromArray(Array(3.0f, 4.0f)) val t = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f))) val wrong = t.sum(Axis[C]) // Axis[C] not in tensor // error: -// Not found: Tensor2 -// error: -// Not found: type A -// error: -// Not found: Axis -// error: -// Not found: type B -// error: -// Not found: Axis +// Axis[MdocApp11.this.C] not found in Tensor[(MdocApp11.this.A, MdocApp11.this.B)]. +// I found: +// +// dimwit.tensor.ShapeTypeHelpers.AxisRemover.bridge[ +// (MdocApp11.this.A, MdocApp11.this.B), MdocApp11.this.C, Tuple]( +// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.A, +// MdocApp11.this.B *: EmptyTuple.type, MdocApp11.this.C]( +// dimwit.tensor.ShapeTypeHelpers.AxisIndex.tail[MdocApp11.this.B, +// EmptyTuple.type, MdocApp11.this.C]( +// dimwit.tensor.ShapeTypeHelpers.AxisIndex.concatRight[A², B², L]) +// ), +// ???) +// +// But given instance concatRight in object AxisIndex does not match type dimwit.tensor.ShapeTypeHelpers.AxisIndex[EmptyTuple.type, MdocApp11.this.C] +// +// where: A is a trait in class MdocApp11 +// A² is a type variable with constraint <: Tuple +// B is a trait in class MdocApp11 +// B² is a type variable with constraint <: Tuple +// . ``` ```scala @@ -1176,53 +1234,69 @@ val wrong = t.sum(Axis[C]) // Axis[C] not in tensor val t = Tensor2(Axis[A], Axis[B]).fill(1.0f) val wrong = t.vmap(Axis[C])(_.sum) // Axis[C] doesn't exist // error: -// Not found: Tensor2 -// error: -// Not found: type A -// error: -// Not found: Axis -// error: -// Not found: type B -// error: -// Not found: Axis +// value fill is not a member of dimwit.tensor.Tensor2.DefaultsFactory[MdocApp11.this.A, MdocApp11.this.B] ``` ### Gradient Errors ```scala // ERROR: Cannot differentiate Integer functions -def intFunc(x: Tensor0[Int]): Tensor0[Int] = x + x +def intFunc(x: Tensor0[Int32]): Tensor0[Int32] = x + x val wrong = Autodiff.grad(intFunc) // error: -// Not found: type Tensor0 -// def intFunc(x: Tensor0[Int]): Tensor0[Int] = x + x -// ^^^^^^^ -// error: -// Not found: type Tensor0 -// def intFunc(x: Tensor0[Int]): Tensor0[Int] = x + x -// ^^^^^^^ -// error: -// Not found: Autodiff +// Operation only valid for Floating tensors. // val wrong = Autodiff.grad(intFunc) -// ^^^^^^^^ +// ^ ``` ```scala // ERROR: Function doesn't return scalar for grad -def nonScalar(x: Tensor1[A, Float]): Tensor1[A, Float] = x * x +def nonScalar(x: Tensor1[A, Float32]): Tensor1[A, Float32] = x * x val wrong = Autodiff.grad(nonScalar) // Use jacobian instead // error: -// Not found: type Tensor1 -// error: -// Not found: type A -// error: -// Not found: type Tensor1 -// val wrong = Autodiff.grad(nonScalar) // Use jacobian instead -// ^ -// error: -// Not found: type A -// error: -// Not found: Autodiff +// None of the overloaded alternatives of method grad in object Autodiff with types +// [Input, V] +// (f: Input => dimwit.tensor.Tensor0[V]) +// (using evidence$1: dimwit.tensor.TensorOps.IsFloating[V], +// inTree: dimwit.autodiff.TensorTree[Input], outTree: +// dimwit.autodiff.TensorTree[dimwit.tensor.Tensor0[V]]): Input => +// dimwit.autodiff.Grad[Input] +// [T1, T2, T3, V²] +// (f: (T1, T2, T3) => dimwit.tensor.Tensor0[V²]) +// (using evidence$1²: dimwit.tensor.TensorOps.IsFloating[V²], +// t1Tree: dimwit.autodiff.TensorTree[T1], +// t2Tree: dimwit.autodiff.TensorTree[T2], +// t3Tree: dimwit.autodiff.TensorTree[T3], outTree²: +// dimwit.autodiff.TensorTree[dimwit.tensor.Tensor0[V²]]): (T1, T2, T3) => +// dimwit.autodiff.Grad[(T1, T2, T3)] +// [T1², T2², V³] +// (f: (T1², T2²) => dimwit.tensor.Tensor0[V³]) +// (using evidence$1³: dimwit.tensor.TensorOps.IsFloating[V³], +// t1Tree²: dimwit.autodiff.TensorTree[T1²], +// t2Tree²: dimwit.autodiff.TensorTree[T2²], outTree³: +// dimwit.autodiff.TensorTree[dimwit.tensor.Tensor0[V³]]): (T1², T2²) => +// dimwit.autodiff.Grad[(T1², T2²)] +// match arguments (dimwit.tensor.Tensor1[MdocApp11.this.A, dimwit.Float32] => +// dimwit.tensor.Tensor1[MdocApp11.this.A, dimwit.Float32]) +// +// where: T1 is a type variable +// T1² is a type variable +// T2 is a type variable +// T2² is a type variable +// V is a type variable +// V² is a type variable +// V³ is a type variable +// evidence$1 is a reference to a value parameter +// evidence$1² is a reference to a value parameter +// evidence$1³ is a reference to a value parameter +// outTree is a reference to a value parameter +// outTree² is a reference to a value parameter +// outTree³ is a reference to a value parameter +// t1Tree is a reference to a value parameter +// t1Tree² is a reference to a value parameter +// t2Tree is a reference to a value parameter +// t2Tree² is a reference to a value parameter +// ``` ### Device Mismatches @@ -1259,7 +1333,7 @@ val embeddings = Tensor(Shape3(Axis[Batch] -> 8, Axis[SeqLen] -> 128, Axis[Embed trait Feature derives Label // GOOD: Clear type signatures -def process(input: Tensor2[Batch, Feature, Float]): Tensor1[Batch, Float] = +def process(input: Tensor2[Batch, Feature, Float32]): Tensor1[Batch, Float32] = input.vmap(Axis[Batch])(_.sum) // AVOID: Opaque Tensor types without explicit parameters @@ -1274,9 +1348,9 @@ trait Output derives Label // GOOD: Structured parameters with TensorTree case class ModelParams( - encoder: Tensor2[InputDim, Hidden, Float], - decoder: Tensor2[Hidden, Output, Float], - bias: Tensor1[Output, Float] + encoder: Tensor2[InputDim, Hidden, Float32], + decoder: Tensor2[Hidden, Output, Float32], + bias: Tensor1[Output, Float32] ) // AVOID: Tuples or loose parameters @@ -1324,7 +1398,7 @@ import dimwit.jax.Jit.jit trait Input derives Label -val simpleFunc = (x: Tensor1[Input, Float]) => x *! Tensor0(2.0f) +val simpleFunc = (x: Tensor1[Input, Float32]) => x *! Tensor0(2.0f) // GOOD: JIT for repeated calls val jitFunc = jit(simpleFunc) diff --git a/README.md b/README.md index 2f2e7aaf..29da3cf6 100644 --- a/README.md +++ b/README.md @@ -45,11 +45,11 @@ val t = Tensor( ) // Function to normalize a single feature vector -def normalize(x: Tensor1[Feature, Float]) : Tensor1[Feature, Float] = +def normalize(x: Tensor1[Feature, Float32]) : Tensor1[Feature, Float32] = (x -! x.mean) /! x.std // Apply the normalization function across the Batch dimension -val normalized: Tensor2[Batch, Feature, Float] = +val normalized: Tensor2[Batch, Feature, Float32] = t.vmap(Axis[Batch])(normalize) ``` diff --git a/core/src/main/scala/dimwit/autodiff/Autodiff.scala b/core/src/main/scala/dimwit/autodiff/Autodiff.scala index 2880775c..b26028da 100644 --- a/core/src/main/scala/dimwit/autodiff/Autodiff.scala +++ b/core/src/main/scala/dimwit/autodiff/Autodiff.scala @@ -7,6 +7,7 @@ import dimwit.tensor.TupleHelpers.PrimeConcatType import dimwit.jax.Jax import me.shadaj.scalapy.py import dimwit.tensor.Label +import dimwit.tensor.TensorOps.IsFloating object Autodiff: @@ -22,10 +23,10 @@ object Autodiff: case Tensor[inS, v2] => Tensor[PrimeConcatType[OutShape, inS], V] // TODO replace with TupledFunction when available (no longer experimental) - def grad[T1, T2, V](f: (T1, T2) => Tensor0[V])(using t1Tree: TensorTree[T1], t2Tree: TensorTree[T2], outTree: TensorTree[Tensor0[V]]): (T1, T2) => Grad[(T1, T2)] = (t1, t2) => grad(f.tupled)((t1, t2)) - def grad[T1, T2, T3, V](f: (T1, T2, T3) => Tensor0[V])(using t1Tree: TensorTree[T1], t2Tree: TensorTree[T2], t3Tree: TensorTree[T3], outTree: TensorTree[Tensor0[V]]): (T1, T2, T3) => Grad[(T1, T2, T3)] = (t1, t2, t3) => grad(f.tupled)((t1, t2, t3)) + def grad[T1, T2, V: IsFloating](f: (T1, T2) => Tensor0[V])(using t1Tree: TensorTree[T1], t2Tree: TensorTree[T2], outTree: TensorTree[Tensor0[V]]): (T1, T2) => Grad[(T1, T2)] = (t1, t2) => grad(f.tupled)((t1, t2)) + def grad[T1, T2, T3, V: IsFloating](f: (T1, T2, T3) => Tensor0[V])(using t1Tree: TensorTree[T1], t2Tree: TensorTree[T2], t3Tree: TensorTree[T3], outTree: TensorTree[Tensor0[V]]): (T1, T2, T3) => Grad[(T1, T2, T3)] = (t1, t2, t3) => grad(f.tupled)((t1, t2, t3)) - def grad[Input, V](f: Input => Tensor0[V])(using + def grad[Input, V: IsFloating](f: Input => Tensor0[V])(using inTree: TensorTree[Input], outTree: TensorTree[Tensor0[V]] ): Input => Grad[Input] = @@ -42,10 +43,10 @@ object Autodiff: val pyGrad = gpy(pyParams) Grad(inTree.fromPyTree(pyGrad).asInstanceOf[Input]) - def valueAndGrad[T1, T2, V](f: (T1, T2) => Tensor0[V])(using t1Tree: TensorTree[T1], t2Tree: TensorTree[T2], outTree: TensorTree[Tensor0[V]]): (T1, T2) => (Tensor0[V], Grad[(T1, T2)]) = (t1, t2) => valueAndGrad(f.tupled)((t1, t2)) - def valueAndGrad[T1, T2, T3, V](f: (T1, T2, T3) => Tensor0[V])(using t1Tree: TensorTree[T1], t2Tree: TensorTree[T2], t3Tree: TensorTree[T3], outTree: TensorTree[Tensor0[V]]): (T1, T2, T3) => (Tensor0[V], Grad[(T1, T2, T3)]) = (t1, t2, t3) => valueAndGrad(f.tupled)((t1, t2, t3)) + def valueAndGrad[T1, T2, V: IsFloating](f: (T1, T2) => Tensor0[V])(using t1Tree: TensorTree[T1], t2Tree: TensorTree[T2], outTree: TensorTree[Tensor0[V]]): (T1, T2) => (Tensor0[V], Grad[(T1, T2)]) = (t1, t2) => valueAndGrad(f.tupled)((t1, t2)) + def valueAndGrad[T1, T2, T3, V: IsFloating](f: (T1, T2, T3) => Tensor0[V])(using t1Tree: TensorTree[T1], t2Tree: TensorTree[T2], t3Tree: TensorTree[T3], outTree: TensorTree[Tensor0[V]]): (T1, T2, T3) => (Tensor0[V], Grad[(T1, T2, T3)]) = (t1, t2, t3) => valueAndGrad(f.tupled)((t1, t2, t3)) - def valueAndGrad[Input, V](f: Input => Tensor0[V])(using + def valueAndGrad[Input, V: IsFloating](f: Input => Tensor0[V])(using inTree: TensorTree[Input], outTree: TensorTree[Tensor0[V]] ): Input => (Tensor0[V], Grad[Input]) = diff --git a/core/src/main/scala/dimwit/autodiff/FloatTree.scala b/core/src/main/scala/dimwit/autodiff/FloatTree.scala index b01d9b9f..f1d0a755 100644 --- a/core/src/main/scala/dimwit/autodiff/FloatTree.scala +++ b/core/src/main/scala/dimwit/autodiff/FloatTree.scala @@ -6,73 +6,97 @@ import scala.deriving.* import scala.compiletime.* import scala.util.NotGiven -/** A marker trait for structures that are trees of Float tensors. +/** A marker trait for structures that are trees of floating-point tensors. * The given instances give evidence that the tensors are - * really of type float + * of type V, constrained by IsFloating. */ -trait FloatTree[P] +trait FloatTree[P, V] object FloatTree: - given [Q <: Tuple]: FloatTree[Tensor[Q, Float]] with {} + // 1. Base case for Tensors + given [Q <: Tuple, V: IsFloating]: FloatTree[Tensor[Q, V], V] with {} - given listInstance[A](using FloatTree[A]): FloatTree[List[A]] with {} + // 2. Inductive base cases for Tuples + // This allows the compiler to step through the case class fields and lock in V. + given emptyTuple[V]: FloatTree[EmptyTuple, V] with {} - given mapInstance[K, A](using FloatTree[A]): FloatTree[Map[K, A]] with {} + given consTuple[H, T <: Tuple, V](using + h: FloatTree[H, V], + t: FloatTree[T, V] + ): FloatTree[H *: T, V] with {} - inline given derived[P <: Product](using m: Mirror.ProductOf[P]): FloatTree[P] = - summonAll[Tuple.Map[m.MirroredElemTypes, FloatTree]] - FloatTreeImpl[P]() - class FloatTreeImpl[P] extends FloatTree[P] + // 3. Standard collections + given listInstance[A, V](using FloatTree[A, V]): FloatTree[List[A], V] with {} - extension [P](p: P)(using tt: TensorTree[P], af: FloatTree[P]) - /** Maps a function over the TensorTree, as for a regula rtensor tree, - * but provides knowledge that tensors are of type float + given mapInstance[K, A, V](using FloatTree[A, V]): FloatTree[Map[K, A], V] with {} + + inline given derived[P <: Product, V](using + evNotTuple: NotGiven[P <:< Tuple], + m: Mirror.ProductOf[P], + evElems: FloatTree[m.MirroredElemTypes, V] + ): FloatTree[P, V] = + FloatTreeImpl[P, V]() + + class FloatTreeImpl[P, V] extends FloatTree[P, V] + + extension [P, V](p: P)(using tt: TensorTree[P], ft: FloatTree[P, V], isF: IsFloating[V]) + /** Maps a function over the TensorTree, as for a regular tensor tree, + * but provides knowledge that tensors are of type V */ - def map(f: [T <: Tuple] => Labels[T] ?=> (Tensor[T, Float] => Tensor[T, Float])): P = - tt.map(p, [T <: Tuple, V] => (n: Labels[T]) ?=> (t: Tensor[T, V]) => f[T](using n)(t.asInstanceOf[Tensor[T, Float]]).asInstanceOf[Tensor[T, V]]) + def map[NewV](f: [T <: Tuple] => Labels[T] ?=> (Tensor[T, V] => Tensor[T, NewV])): P = + tt.map(p, [T <: Tuple, V0] => (n: Labels[T]) ?=> (t: Tensor[T, V0]) => f[T](using n)(t.asInstanceOf[Tensor[T, V]]).asInstanceOf[Tensor[T, V0]]) /** Zipmaps a function over the TensorTree, as for tensor tree, - * but provides knowledge that tensors are of type float + * but provides knowledge that tensors are of type V */ - def zipMap(p2: P, f: [T <: Tuple] => Labels[T] ?=> ((Tensor[T, Float], Tensor[T, Float]) => Tensor[T, Float])): P = + def zipMap(p2: P, f: [T <: Tuple] => Labels[T] ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): P = tt.zipMap( p, p2, - [T <: Tuple, V] => (n: Labels[T]) ?=> (t1: Tensor[T, V], t2: Tensor[T, V]) => f[T](using n)(t1.asInstanceOf[Tensor[T, Float]], t2.asInstanceOf[Tensor[T, Float]]).asInstanceOf[Tensor[T, V]] + [T <: Tuple, V0] => (n: Labels[T]) ?=> (t1: Tensor[T, V0], t2: Tensor[T, V0]) => f[T](using n)(t1.asInstanceOf[Tensor[T, V]], t2.asInstanceOf[Tensor[T, V]]).asInstanceOf[Tensor[T, V0]] ) - /** Arithmetic and math operations for tensor trees of floats. + /** Arithmetic and math operations for tensor trees of floating-point types. */ object ops: // helper typeclass - trait IsFloatTensor[P] - object IsFloatTensor: - given [T <: Tuple]: IsFloatTensor[Tensor[T, Float]] with {} + trait IsFloatingTensor[P, V] + object IsFloatingTensor: + given [T <: Tuple, V: IsFloating]: IsFloatingTensor[Tensor[T, V], V] with {} // Scalar broadcast extensions (Tensor0 op Tree) - extension (p2: Tensor0[Float]) - def ++![P: TensorTree: FloatTree](p1: P): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => a +! p2) - def --![P: TensorTree: FloatTree](p1: P): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => a -! p2) - def **![P: TensorTree: FloatTree](p1: P): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => a *! p2) - def `//!`[P: TensorTree: FloatTree](p1: P): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => a /! p2) + extension [V: IsFloating](p2: Tensor0[V]) + def ++![P](p1: P)(using TensorTree[P], FloatTree[P, V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => a +! p2) + def --![P](p1: P)(using TensorTree[P], FloatTree[P, V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => a -! p2) + def **![P](p1: P)(using TensorTree[P], FloatTree[P, V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => a *! p2) + def `//!`[P](p1: P)(using TensorTree[P], FloatTree[P, V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => a /! p2) // Tree extensions (Tree op Tree, Tree op Scalar, and math ops) - // Excluded for bare Tensor[T, Float] to avoid conflicts with tensor's own operators - extension [P](p1: P)(using tt: TensorTree[P], af: FloatTree[P], ev: NotGiven[IsFloatTensor[P]]) - def ++(p2: P): P = p1.zipMap(p2, [T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float], b: Tensor[T, Float]) => a + b) - def ++!(p2: Tensor0[Float]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => a +! p2) - def --(p2: P): P = p1.zipMap(p2, [T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float], b: Tensor[T, Float]) => a - b) - def --!(p2: Tensor0[Float]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => a -! p2) - def **(p2: P): P = p1.zipMap(p2, [T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float], b: Tensor[T, Float]) => a * b) - def **!(p2: Tensor0[Float]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => a *! p2) - def `//`(p2: P): P = p1.zipMap(p2, [T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float], b: Tensor[T, Float]) => a / b) - def `//!`(p2: Tensor0[Float]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => a /! p2) - - def sqrt: P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => TensorOps.sqrt(a)) - def pow(exponent: Tensor0[Float]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => TensorOps.pow(a)(exponent)) - def scale(scalar: Tensor0[Float]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => TensorOps.scale(a)(scalar)) - def sign: P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => TensorOps.sign(a)) - - def fillCopy(value: Float): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, Float]) => Tensor(a.shape).fill(value)) + // Excluded for bare Tensor[T, V] to avoid conflicts with tensor's own operators + extension [P, V](p1: P)(using tt: TensorTree[P], ft: FloatTree[P, V], isF: IsFloating[V], ev: NotGiven[IsFloatingTensor[P, V]]) + def ++(p2: P): P = p1.zipMap(p2, [T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V], b: Tensor[T, V]) => a + b) + def ++!(p2: Tensor0[V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => a +! p2) + def --(p2: P): P = p1.zipMap(p2, [T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V], b: Tensor[T, V]) => a - b) + def --!(p2: Tensor0[V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => a -! p2) + def **(p2: P): P = p1.zipMap(p2, [T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V], b: Tensor[T, V]) => a * b) + def **!(p2: Tensor0[V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => a *! p2) + def `//`(p2: P): P = p1.zipMap(p2, [T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V], b: Tensor[T, V]) => a / b) + def `//!`(p2: Tensor0[V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => a /! p2) + + def sqrt: P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => TensorOps.sqrt(a)) + + def pow(exponent: Float): P = pow(Tensor0(VType[V])(exponent)) + def pow(exponent: Tensor0[V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => TensorOps.pow(a)(exponent)) + def scale(scalar: Tensor0[V]): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => TensorOps.scale(a)(scalar)) + def sign: P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => TensorOps.sign(a)) + + def fillCopy(value: Float): P = p1.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => Tensor(a.shape, VType[V]).fill(value)) + + extension [F[_], V](p: F[V])(using tt: TensorTree[F[V]], ft: FloatTree[F[V], V], isF: IsFloating[V]) + + def asFloats[NewV: IsFloating](vtype: VType[NewV])(using m: Mirror.ProductOf[F[NewV]]): F[NewV] = + p.map([T <: Tuple] => (n: Labels[T]) ?=> (a: Tensor[T, V]) => a.asFloat(vtype)).asInstanceOf[F[NewV]] + +type FloatTreeFor[V] = [P] =>> FloatTree[P, V] diff --git a/core/src/main/scala/dimwit/autodiff/Grad.scala b/core/src/main/scala/dimwit/autodiff/Grad.scala index 7adcec9e..472e3d56 100644 --- a/core/src/main/scala/dimwit/autodiff/Grad.scala +++ b/core/src/main/scala/dimwit/autodiff/Grad.scala @@ -2,6 +2,7 @@ package dimwit.autodiff import dimwit.* import dimwit.jax.Jax +import scala.deriving.Mirror /** Type-level tag marking a parameter structure as gradients. * @@ -36,4 +37,13 @@ object Grad: def fromPyTree(pyVal: Jax.PyAny): Grad[T] = Grad(ev.fromPyTree(pyVal)) // FloatTree witness for gradient math (++, --, scale, etc.) - given [T](using FloatTree[T]): FloatTree[Grad[T]] with {} + // given [T, V: IsFloating](using FloatTree[T, V]): FloatTree[Grad[T], V] with {} + + // Bridge extension so we can call .asFloats directly on Grad[Params[V]] + extension [F[_], V](g: Grad[F[V]])(using + tt: TensorTree[F[V]], + ft: FloatTree[F[V], V], + isF: IsFloating[V] + ) + def asFloats[NewV: IsFloating](vtype: VType[NewV])(using m: Mirror.ProductOf[F[NewV]]): Grad[F[NewV]] = + Grad(dimwit.FloatTree.ops.asFloats(g.value)(vtype)) diff --git a/core/src/main/scala/dimwit/jax/JaxDType.scala b/core/src/main/scala/dimwit/jax/JaxDType.scala index 109696b9..34d9ee0a 100644 --- a/core/src/main/scala/dimwit/jax/JaxDType.scala +++ b/core/src/main/scala/dimwit/jax/JaxDType.scala @@ -19,6 +19,8 @@ object JaxDType: try val dtypeStr = jaxDtype.name.as[String] dtypeStr match + case "bfloat16" => DType.BFloat16 + case "float16" => DType.Float16 case "float32" => DType.Float32 case "float64" => DType.Float64 case "int32" => DType.Int32 @@ -43,6 +45,8 @@ object JaxDType: try val jnp = Jax.jnp dtype match + case DType.BFloat16 => jnp.bfloat16 + case DType.Float16 => jnp.float16 case DType.Float32 => jnp.float32 case DType.Float64 => jnp.float64 case DType.Int32 => jnp.int32 diff --git a/core/src/main/scala/dimwit/package.scala b/core/src/main/scala/dimwit/package.scala index 8084f72a..d3b98120 100644 --- a/core/src/main/scala/dimwit/package.scala +++ b/core/src/main/scala/dimwit/package.scala @@ -61,10 +61,10 @@ package object dimwit: export dimwit.tensor.{Tensor, Tensor0, Tensor1, Tensor2, Tensor3} export dimwit.tensor.{Shape, Shape0, Shape1, Shape2, Shape3} export dimwit.tensor.DType + export dimwit.tensor.DType.{BFloat16, Float16, Float32, Float64, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, Bool} + export dimwit.tensor.{ VType, - ExecutionType, - ExecutionTypeFor, Label, Labels, Axis, @@ -89,7 +89,8 @@ package object dimwit: export dimwit.jax.EagerCleanup.eagerCleanup object Conversions: - export dimwit.tensor.Tensor0.{float2FloatTensor, int2IntTensor, int2FloatTensor, boolean2BooleanTensor} + export dimwit.tensor.Tensor0.{boolean2BooleanTensor, byte2IntegerTensor, short2IntegerTensor, int2IntegerTensor, long2IntegerTensor, float2FloatingTensor, int2FloatingTensor, double2FloatingTensor} + // Export random object export dimwit.random.Random export dimwit.random.Random.Key diff --git a/core/src/main/scala/dimwit/random/Random.scala b/core/src/main/scala/dimwit/random/Random.scala index 63e4b1db..ab586cee 100644 --- a/core/src/main/scala/dimwit/random/Random.scala +++ b/core/src/main/scala/dimwit/random/Random.scala @@ -1,6 +1,7 @@ package dimwit.random import dimwit.tensor.* +import dimwit.tensor.DType.Int32 import dimwit.tensor.TensorOps.* import dimwit.jax.{Jax, JaxDType} import dimwit.autodiff.TensorTree @@ -83,7 +84,7 @@ object Random: * @return * A 1D tensor containing a random permutation of [0, 1, ..., n-1] */ - def permutation[L: Label](dim: AxisExtent[L])(key: Key): Tensor1[L, Int] = + def permutation[L: Label](dim: AxisExtent[L])(key: Key): Tensor1[L, Int32] = liftPyTensor(Jax.jrandom.permutation(key.jaxKey, dim.size)) object Key: diff --git a/core/src/main/scala/dimwit/stats/Distributions.scala b/core/src/main/scala/dimwit/stats/Distributions.scala index 2c237dfb..a20f7090 100644 --- a/core/src/main/scala/dimwit/stats/Distributions.scala +++ b/core/src/main/scala/dimwit/stats/Distributions.scala @@ -7,32 +7,32 @@ import dimwit.jax.Jax.scipy_stats as jstats import dimwit.jax.Jax.PyDynamic import dimwit.tensor.TensorOps -opaque type LogProb = Float -opaque type Prob = Float +opaque type LogProb = Float32 +opaque type Prob = Float32 object LogProb: - given IsFloating[LogProb] = summon[IsFloating[Float]] + given IsFloating[LogProb] = summon[IsFloating[Float32]] - def apply[T <: Tuple: Labels](t: Tensor[T, Float]): Tensor[T, LogProb] = t + def apply[T <: Tuple: Labels](t: Tensor[T, Float32]): Tensor[T, LogProb] = t extension [T <: Tuple: Labels](t: Tensor[T, LogProb]) def exp: Tensor[T, Prob] = TensorOps.exp(t) - def log: Tensor[T, Float] = TensorOps.log(t) // Lose LogProb if we log again - def asFloat: Tensor[T, Float] = t + def log: Tensor[T, Float32] = TensorOps.log(t) // Lose LogProb if we log again + def asFloat: Tensor[T, Float32] = t object Prob: - given IsFloating[Prob] = summon[IsFloating[Float]] + given IsFloating[Prob] = summon[IsFloating[Float32]] - def apply[T <: Tuple: Labels](t: Tensor[T, Float]): Tensor[T, Prob] = t + def apply[T <: Tuple: Labels](t: Tensor[T, Float32]): Tensor[T, Prob] = t extension [T <: Tuple: Labels](t: Tensor[T, Prob]) - def exp: Tensor[T, Float] = TensorOps.exp(t) // Lose Prob if we exp again + def exp: Tensor[T, Float32] = TensorOps.exp(t) // Lose Prob if we exp again def log: Tensor[T, LogProb] = TensorOps.log(t) - def asFloat: Tensor[T, Float] = t + def asFloat: Tensor[T, Float32] = t trait Distribution[EventShape <: Tuple: Labels, V]: @@ -58,7 +58,7 @@ trait Distribution[EventShape <: Tuple: Labels, V]: * @tparam EventShape Shape of the tensor of independent values * @tparam V Value type */ -trait IndependentDistribution[EventShape <: Tuple: Labels, V: ExecutionType] extends Distribution[EventShape, V]: +trait IndependentDistribution[EventShape <: Tuple: Labels, V] extends Distribution[EventShape, V]: /** Element-wise log probabilities (primitive operation) */ def elementWiseLogProb(x: Tensor[EventShape, V]): Tensor[EventShape, LogProb] @@ -78,7 +78,7 @@ object IndependentDistribution: * Each element of the resulting tensor is an independent sample from * the same univariate distribution. */ - def fromUnivariate[EventShape <: Tuple: Labels, V: ExecutionType]( + def fromUnivariate[EventShape <: Tuple: Labels, V]( shape: Shape[EventShape], univariate: UnivariateDistribution[V] ): IndependentDistribution[EventShape, V] = diff --git a/core/src/main/scala/dimwit/stats/IndependentDistributions.scala b/core/src/main/scala/dimwit/stats/IndependentDistributions.scala index 59d7a4c1..cccdaf0c 100644 --- a/core/src/main/scala/dimwit/stats/IndependentDistributions.scala +++ b/core/src/main/scala/dimwit/stats/IndependentDistributions.scala @@ -1,6 +1,7 @@ package dimwit.stats import dimwit.* +import dimwit.DType.Float32 import dimwit.jax.Jax.scipy_stats as jstats import dimwit.jax.Jax import dimwit.jax.Jax.PyDynamic @@ -10,47 +11,48 @@ import dimwit.random.Random import dimwit.python.PyBridge.liftPyTensor /** Normal (Gaussian) distribution */ -class Normal[T <: Tuple: Labels](val loc: Tensor[T, Float], val scale: Tensor[T, Float]) extends IndependentDistribution[T, Float]: +class Normal[T <: Tuple: Labels, V: IsFloating](val loc: Tensor[T, V], val scale: Tensor[T, V]) extends IndependentDistribution[T, V]: - override def elementWiseLogProb(x: Tensor[T, Float]): Tensor[T, LogProb] = + override def elementWiseLogProb(x: Tensor[T, V]): Tensor[T, LogProb] = liftPyTensor(jstats.norm.logpdf(x.jaxValue, loc = loc.jaxValue, scale = scale.jaxValue)) - override def sample(key: Random.Key): Tensor[T, Float] = + override def sample(key: Random.Key): Tensor[T, V] = val standardNormal = liftPyTensor(loc.shape, loc.vtype)(Jax.jrandom.normal(key.jaxKey, loc.shape.dimensions.toPythonProxy)) standardNormal * scale + loc object Normal: /** Create a Normal distribution from location and scale tensors */ - def apply[T <: Tuple: Labels](loc: Tensor[T, Float], scale: Tensor[T, Float]): Normal[T] = + def apply[T <: Tuple: Labels, V: IsFloating](loc: Tensor[T, V], scale: Tensor[T, V]): Normal[T, V] = require(loc.shape.dimensions == scale.shape.dimensions, "loc and scale must have the same dimensions") new Normal(loc, scale) - def isotropic[T <: Tuple: Labels](loc: Tensor[T, Float], scale: Tensor0[Float]): Normal[T] = new Normal(loc = loc, scale = scale.broadcastTo(loc.shape)) - def standardIsotropic[T <: Tuple: Labels](shape: Shape[T], scale: Tensor0[Float]): Normal[T] = isotropic(loc = Tensor(shape).fill(0f), scale = scale) + def isotropic[T <: Tuple: Labels, V: IsFloating](loc: Tensor[T, V], scale: Tensor0[V]): Normal[T, V] = new Normal(loc = loc, scale = scale.broadcastTo(loc.shape)) + def standardIsotropic[T <: Tuple: Labels, V: IsFloating](shape: Shape[T], scale: Tensor0[V]): Normal[T, V] = isotropic(loc = Tensor(shape, VType[V]).fill(0f), scale = scale) + def standardIsotropic[T <: Tuple: Labels](shape: Shape[T], scale: Float): Normal[T, Float32] = isotropic(loc = Tensor(shape).fill(0f), scale = Tensor0(scale)) /** Sample from standard normal distribution N(0, 1) */ - def standardSample(key: Random.Key): Tensor0[Float] = new Normal(Tensor0(0f), Tensor0(1f)).sample(key) - def standardNormal[T <: Tuple: Labels](shape: Shape[T])(using executionType: ExecutionType[Float]): Normal[T] = Normal.standardIsotropic(shape, scale = Tensor0(1f)) + def standardSample(key: Random.Key): Tensor0[Float32] = new Normal(Tensor0(0f), Tensor0(1f)).sample(key) + def standardNormal[T <: Tuple: Labels](shape: Shape[T]): Normal[T, Float32] = Normal.standardIsotropic(shape, scale = Tensor0(VType[Float32])(1f)) /** Uniform distribution */ -class Uniform[T <: Tuple: Labels](val low: Tensor[T, Float], val high: Tensor[T, Float]) extends IndependentDistribution[T, Float]: +class Uniform[T <: Tuple: Labels, V: IsFloating](val low: Tensor[T, V], val high: Tensor[T, V]) extends IndependentDistribution[T, V]: - override def elementWiseLogProb(x: Tensor[T, Float]): Tensor[T, LogProb] = + override def elementWiseLogProb(x: Tensor[T, V]): Tensor[T, LogProb] = liftPyTensor(jstats.uniform.logpdf(x.jaxValue, loc = low.jaxValue, scale = (high - low).jaxValue)) - override def sample(key: Random.Key): Tensor[T, Float] = + override def sample(key: Random.Key): Tensor[T, V] = liftPyTensor( Jax.jrandom.uniform(key.jaxKey, shape = low.shape.dimensions.toPythonProxy, minval = low.jaxValue, maxval = high.jaxValue) ) /** Uniform distribution */ -class DiscreteUniform[T <: Tuple: Labels](val min: Tensor[T, Int], val max: Tensor[T, Int]) extends IndependentDistribution[T, Int]: +class DiscreteUniform[T <: Tuple: Labels](val min: Tensor[T, Int32], val max: Tensor[T, Int32]) extends IndependentDistribution[T, Int32]: - override def elementWiseLogProb(x: Tensor[T, Int]): Tensor[T, LogProb] = + override def elementWiseLogProb(x: Tensor[T, Int32]): Tensor[T, LogProb] = liftPyTensor(jstats.randint.logpmf(x.jaxValue, low = min.jaxValue, high = max.jaxValue)) - override def sample(key: Random.Key): Tensor[T, Int] = + override def sample(key: Random.Key): Tensor[T, Int32] = liftPyTensor( Jax.jrandom.randint(key.jaxKey, shape = min.shape.dimensions.toPythonProxy, minval = min.jaxValue, maxval = max.jaxValue) ) @@ -58,22 +60,22 @@ class DiscreteUniform[T <: Tuple: Labels](val min: Tensor[T, Int], val max: Tens object Uniform: /** Create a Uniform distribution from low and high tensors */ - def apply[T <: Tuple: Labels](low: Tensor[T, Float], high: Tensor[T, Float]): Uniform[T] = + def apply[T <: Tuple: Labels, V: IsFloating](low: Tensor[T, V], high: Tensor[T, V]): Uniform[T, V] = require(low.shape.dimensions == high.shape.dimensions, "Low and high must have the same dimensions") new Uniform(low, high) /** Create a discrete Uniform distribution from low and high int tensors */ - def apply[T <: Tuple: Labels](min: Tensor[T, Int], max: Tensor[T, Int]): DiscreteUniform[T] = + def apply[T <: Tuple: Labels](min: Tensor[T, Int32], max: Tensor[T, Int32]): DiscreteUniform[T] = require(min.shape.dimensions == max.shape.dimensions, "min and max must have the same dimensions") new DiscreteUniform(min, max) /** Bernoulli distribution */ -class Bernoulli[T <: Tuple: Labels](val probs: Tensor[T, Prob]) extends IndependentDistribution[T, Boolean]: +class Bernoulli[T <: Tuple: Labels](val probs: Tensor[T, Prob]) extends IndependentDistribution[T, Bool]: - override def elementWiseLogProb(x: Tensor[T, Boolean]): Tensor[T, LogProb] = + override def elementWiseLogProb(x: Tensor[T, Bool]): Tensor[T, LogProb] = liftPyTensor(jstats.bernoulli.logpmf(x.jaxValue, p = probs.jaxValue)) - override def sample(key: Random.Key): Tensor[T, Boolean] = + override def sample(key: Random.Key): Tensor[T, Bool] = liftPyTensor(Jax.jrandom.bernoulli(key.jaxKey, p = probs.jaxValue)) object Bernoulli: @@ -83,69 +85,69 @@ object Bernoulli: new Bernoulli(probs) /** Binomial distribution - number of successes in n independent Bernoulli trials */ -class Binomial[T <: Tuple: Labels](val n: Tensor0[Int], val probs: Tensor[T, Prob]) extends IndependentDistribution[T, Int]: +class Binomial[T <: Tuple: Labels, V: IsInteger](val n: Tensor0[V], val probs: Tensor[T, Prob]) extends IndependentDistribution[T, V]: - override def elementWiseLogProb(x: Tensor[T, Int]): Tensor[T, LogProb] = + override def elementWiseLogProb(x: Tensor[T, V]): Tensor[T, LogProb] = liftPyTensor(jstats.binom.logpmf(x.jaxValue, n = n.jaxValue, p = probs.jaxValue)) - override def sample(key: Random.Key): Tensor[T, Int] = - liftPyTensor(probs.shape, VType[Int])( + override def sample(key: Random.Key): Tensor[T, V] = + liftPyTensor(probs.shape, VType[V])( Jax.jrandom.binomial(key.jaxKey, n = n.jaxValue, p = probs.jaxValue) ) object Binomial: /** Create a Binomial distribution from number of trials and probability tensor */ - def apply[T <: Tuple: Labels](n: Tensor0[Int], probs: Tensor[T, Prob]): Binomial[T] = + def apply[T <: Tuple: Labels, V: IsInteger](n: Tensor0[V], probs: Tensor[T, Prob]): Binomial[T, V] = new Binomial(n, probs) /** Cauchy distribution */ -class Cauchy[T <: Tuple: Labels](val loc: Tensor[T, Float], val scale: Tensor[T, Float]) extends IndependentDistribution[T, Float]: +class Cauchy[T <: Tuple: Labels, V: IsFloating](val loc: Tensor[T, V], val scale: Tensor[T, V]) extends IndependentDistribution[T, V]: - override def elementWiseLogProb(x: Tensor[T, Float]): Tensor[T, LogProb] = + override def elementWiseLogProb(x: Tensor[T, V]): Tensor[T, LogProb] = liftPyTensor(jstats.cauchy.logpdf(x.jaxValue, loc = loc.jaxValue, scale = scale.jaxValue)) - override def sample(k: Random.Key): Tensor[T, Float] = + override def sample(k: Random.Key): Tensor[T, V] = liftPyTensor(Jax.jrandom.cauchy(k.jaxKey, shape = loc.shape.dimensions.toPythonProxy)) * scale + loc object Cauchy: /** Create a Cauchy distribution from location and scale tensors */ - def apply[T <: Tuple: Labels](loc: Tensor[T, Float], scale: Tensor[T, Float]): Cauchy[T] = + def apply[T <: Tuple: Labels, V: IsFloating](loc: Tensor[T, V], scale: Tensor[T, V]): Cauchy[T, V] = require(loc.shape.dimensions == scale.shape.dimensions, "Location and scale must have the same dimensions") new Cauchy(loc, scale) /** Half-normal distribution */ -class HalfNormal[T <: Tuple: Labels](val loc: Tensor[T, Float], val scale: Tensor[T, Float]) extends IndependentDistribution[T, Float]: +class HalfNormal[T <: Tuple: Labels, V: IsFloating](val loc: Tensor[T, V], val scale: Tensor[T, V]) extends IndependentDistribution[T, V]: - override def elementWiseLogProb(x: Tensor[T, Float]): Tensor[T, LogProb] = + override def elementWiseLogProb(x: Tensor[T, V]): Tensor[T, LogProb] = // Half-normal logpdf = log(2) + norm.logpdf for x >= loc, -inf otherwise val rawLogProb = liftPyTensor(x.shape, VType[LogProb])( Jax.jnp.log(2.0) + jstats.norm.logpdf(x.jaxValue, loc = loc.jaxValue, scale = scale.jaxValue) ) val valid = x >= loc - val negInf = LogProb(Tensor.like(x).fill(Float.NegativeInfinity)) + val negInf = LogProb(Tensor.like(x).fill(Float.NegativeInfinity).asFloat32) where(valid, rawLogProb, negInf) - override def sample(k: Random.Key): Tensor[T, Float] = + override def sample(k: Random.Key): Tensor[T, V] = // Half-normal: |N(0,1)| * scale + loc - val normal = liftPyTensor(loc.shape, VType[Float])(Jax.jrandom.normal(k.jaxKey, shape = loc.shape.dimensions.toPythonProxy)) + val normal = liftPyTensor(loc.shape, VType[V])(Jax.jrandom.normal(k.jaxKey, shape = loc.shape.dimensions.toPythonProxy)) normal.abs * scale + loc object HalfNormal: /** Create a half-normal distribution from location and scale tensors */ - def apply[T <: Tuple: Labels](loc: Tensor[T, Float], scale: Tensor[T, Float]): HalfNormal[T] = + def apply[T <: Tuple: Labels, V: IsFloating](loc: Tensor[T, V], scale: Tensor[T, V]): HalfNormal[T, V] = require(loc.shape.dimensions == scale.shape.dimensions, "Mean and scale must have the same dimensions") new HalfNormal(loc, scale) /** Student's t-distribution */ -class StudentT[T <: Tuple: Labels](val df: Tensor0[Float], val loc: Tensor[T, Float], val scale: Tensor[T, Float]) extends IndependentDistribution[T, Float]: +class StudentT[T <: Tuple: Labels, V: IsFloating](val df: Tensor0[V], val loc: Tensor[T, V], val scale: Tensor[T, V]) extends IndependentDistribution[T, V]: - override def elementWiseLogProb(x: Tensor[T, Float]): Tensor[T, LogProb] = + override def elementWiseLogProb(x: Tensor[T, V]): Tensor[T, LogProb] = liftPyTensor(jstats.t.logpdf(x.jaxValue, df = df.jaxValue, loc = loc.jaxValue, scale = scale.jaxValue)) - override def sample(k: Random.Key): Tensor[T, Float] = + override def sample(k: Random.Key): Tensor[T, V] = liftPyTensor( Jax.jrandom.t(k.jaxKey, df = df.jaxValue, shape = loc.shape.dimensions.toPythonProxy) ) * scale + loc @@ -153,17 +155,17 @@ class StudentT[T <: Tuple: Labels](val df: Tensor0[Float], val loc: Tensor[T, Fl object StudentT: /** Create a Student's t-distribution from parameters */ - def apply[T <: Tuple: Labels](df: Tensor0[Float], loc: Tensor[T, Float], scale: Tensor[T, Float]): StudentT[T] = + def apply[T <: Tuple: Labels, V: IsFloating](df: Tensor0[V], loc: Tensor[T, V], scale: Tensor[T, V]): StudentT[T, V] = require(loc.shape.dimensions == scale.shape.dimensions, "loc, and scale must have the same dimensions") new StudentT(df, loc, scale) /** Beta distribution */ -class Beta[T <: Tuple: Labels](val alpha: Tensor[T, Float], val beta: Tensor[T, Float]) extends IndependentDistribution[T, Float]: +class Beta[T <: Tuple: Labels, V: IsFloating](val alpha: Tensor[T, V], val beta: Tensor[T, V]) extends IndependentDistribution[T, V]: - override def elementWiseLogProb(x: Tensor[T, Float]): Tensor[T, LogProb] = + override def elementWiseLogProb(x: Tensor[T, V]): Tensor[T, LogProb] = liftPyTensor(jstats.beta.logpdf(x.jaxValue, a = alpha.jaxValue, b = beta.jaxValue)) - override def sample(k: Random.Key): Tensor[T, Float] = + override def sample(k: Random.Key): Tensor[T, V] = liftPyTensor( Jax.jrandom.beta(k.jaxKey, a = alpha.jaxValue, b = beta.jaxValue, shape = alpha.shape.dimensions.toPythonProxy) ) @@ -171,6 +173,6 @@ class Beta[T <: Tuple: Labels](val alpha: Tensor[T, Float], val beta: Tensor[T, object Beta: /** Create a Beta distribution from alpha and beta tensors */ - def apply[T <: Tuple: Labels](alpha: Tensor[T, Float], beta: Tensor[T, Float]): Beta[T] = + def apply[T <: Tuple: Labels, V: IsFloating](alpha: Tensor[T, V], beta: Tensor[T, V]): Beta[T, V] = require(alpha.shape.dimensions == beta.shape.dimensions, "alpha and beta must have the same dimensions") new Beta(alpha, beta) diff --git a/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala b/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala index f8c93522..ed66bc7b 100644 --- a/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala +++ b/core/src/main/scala/dimwit/stats/MultivariateDistributions.scala @@ -6,20 +6,21 @@ import dimwit.jax.Jax import dimwit.jax.Jax.scipy_stats as jstats import dimwit.jax.Jax.PyDynamic import dimwit.python.PyBridge.liftPyTensor +import me.shadaj.scalapy.readwrite.Reader /** Distribution over a vector of random variables. */ trait MultivariateDistribution[L: Label, V] extends Distribution[Tuple1[L], V] -class MVNormal[L: Label]( - val mean: Tensor1[L, Float], - val covariance: Tensor2[L, Prime[L], Float] -) extends MultivariateDistribution[L, Float]: +class MVNormal[L: Label, V: IsFloating]( + val mean: Tensor1[L, V], + val covariance: Tensor2[L, Prime[L], V] +) extends MultivariateDistribution[L, V]: - override def logProb(x: Tensor1[L, Float]): Tensor0[LogProb] = + override def logProb(x: Tensor1[L, V]): Tensor0[LogProb] = liftPyTensor(jstats.multivariate_normal.logpdf(x.jaxValue, mean = mean.jaxValue, cov = covariance.jaxValue)) - override def sample(k: Random.Key): Tensor1[L, Float] = + override def sample(k: Random.Key): Tensor1[L, V] = liftPyTensor( Jax.jrandom.multivariate_normal( k.jaxKey, @@ -28,14 +29,14 @@ class MVNormal[L: Label]( ) ) -class Dirichlet[L: Label]( - val concentration: Tensor1[L, Float] -) extends MultivariateDistribution[L, Float]: +class Dirichlet[L: Label, V: IsFloating]( + val concentration: Tensor1[L, V] +) extends MultivariateDistribution[L, V]: - override def logProb(x: Tensor1[L, Float]): Tensor0[LogProb] = + override def logProb(x: Tensor1[L, V]): Tensor0[LogProb] = liftPyTensor(jstats.dirichlet.logpdf(x.jaxValue, alpha = concentration.jaxValue)) - override def sample(k: Random.Key): Tensor1[L, Float] = + override def sample(k: Random.Key): Tensor1[L, V] = liftPyTensor( Jax.jrandom.dirichlet( k.jaxKey, @@ -44,16 +45,16 @@ class Dirichlet[L: Label]( ) class Multinomial[L: Label]( - val n: Tensor0[Int], + val n: Tensor0[Int32], val probs: Tensor1[L, Prob] -) extends MultivariateDistribution[L, Int]: +) extends MultivariateDistribution[L, Int32]: - private val categorical: Categorical[L] = Categorical(probs) + private val categorical: Categorical[L] = Categorical[L](probs) - override def logProb(x: Tensor1[L, Int]): Tensor0[LogProb] = + override def logProb(x: Tensor1[L, Int32]): Tensor0[LogProb] = liftPyTensor(jstats.multinomial.logpmf(x.jaxValue, n = n.jaxValue, p = probs.jaxValue)) - override def sample(key: Random.Key): Tensor1[L, Int] = + override def sample(key: Random.Key): Tensor1[L, Int32] = // Sample from categorical n times using splitvmap, then bincount trait Draws derives Label val draws = key.splitvmap(Axis[Draws] -> n.item)(k => categorical.sample(k)) diff --git a/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala b/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala index 013cd2a2..3bf66df5 100644 --- a/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala +++ b/core/src/main/scala/dimwit/stats/UnivariateDistributions.scala @@ -1,6 +1,7 @@ package dimwit.stats import dimwit.* +import dimwit.DType.{Int32, Float32} import dimwit.random.* import dimwit.jax.Jax import dimwit.python.PyBridge.liftPyTensor @@ -13,16 +14,16 @@ import dimwit.python.PyBridge.liftPyTensor */ type UnivariateDistribution[V] = Distribution[EmptyTuple, V] -class Categorical[L: Label](val probs: Tensor1[L, Prob]) extends UnivariateDistribution[Int]: +class Categorical[L: Label](val probs: Tensor1[L, Prob]) extends UnivariateDistribution[Int32]: private val logProbs: Tensor1[L, LogProb] = probs.log - override def logProb(x: Tensor0[Int]): Tensor0[LogProb] = + override def logProb(x: Tensor0[Int32]): Tensor0[LogProb] = liftPyTensor(logProbs.jaxValue.__getitem__(x.jaxValue)) - override def sample(key: Key): Tensor0[Int] = + override def sample(key: Key): Tensor0[Int32] = liftPyTensor(Jax.jrandom.categorical(key.jaxKey, logProbs.jaxValue)) object Categorical: def apply[L: Label](probs: Tensor1[L, Prob]): Categorical[L] = new Categorical(probs) - def fromFloat[L: Label](probs: Tensor1[L, Float]): Categorical[L] = new Categorical(Prob(probs)) + def fromFloat[L: Label](probs: Tensor1[L, Float32]): Categorical[L] = new Categorical(Prob(probs)) diff --git a/core/src/main/scala/dimwit/tensor/ArrayWriter.scala b/core/src/main/scala/dimwit/tensor/ArrayWriter.scala index 8ff22e22..fcd4524e 100644 --- a/core/src/main/scala/dimwit/tensor/ArrayWriter.scala +++ b/core/src/main/scala/dimwit/tensor/ArrayWriter.scala @@ -8,43 +8,9 @@ import me.shadaj.scalapy.py.SeqConverters import me.shadaj.scalapy.readwrite.Writer import me.shadaj.scalapy.interpreter.PyValue import dimwit.jax.Jax - -trait WriterEvidence[A]: - type V - -/** Type class for providing evidence that a scalar of type A can be converted to type V using a ScalaPy Writer in a Tensor context. - * The type A is the input scalar type, allowing to define an internal precision (dtype) based on the scalar type. - * For example creating a Tensor[?, Int] from a scalar of type Byte with internal dtype uint8, int8, int16 or int32 (based on given ExecutionType[Byte]). - * - * The type V is the value type of the resulting Tensor (should be Boolean, Int or Float; or custom opaque types). - */ -object WriterEvidence: - - type Aux[A, V0] = WriterEvidence[A] { type V = V0 } - -// Helper to instantiate - def apply[A, V0]: Aux[A, V0] = new WriterEvidence[A]: - type V = V0 - - given Aux[Float, Float] = apply - given Aux[Int, Int] = apply - given Aux[Boolean, Boolean] = apply - given Aux[Double, Float] = apply // Double casts to Float - given Aux[Byte, Int] = apply // Byte casts to Int - -/** Type class for creating Tensors of different value types from arrays of different base types. - * While allowing to define an internal precision (dtype) based on the array type. - * For example creating a Tensor[?, Int] from an Array[Byte] with internal dtype uint8, int8, int16 or int32 (based on given ExecutionType[Byte]). - * - * @param A The base type of the input array. - * @param V The value type of the resulting Tensor. - */ -trait ArrayWriter[A]: - type V - def fromArray[T <: Tuple: Labels](shape: Shape[T])(values: Array[A]): Tensor[T, V] +import dimwit.tensor.TensorOps.{IsBoolean, IsInteger, IsFloating} object ArrayWriter: - type Aux[A, V0] = ArrayWriter[A] { type V = V0 } val base64Loader = py.eval("lambda b64, shape, dtype: __import__('jax').numpy.array(__import__('numpy').frombuffer(__import__('base64').b64decode(b64), dtype=dtype).reshape(shape))") @@ -52,42 +18,37 @@ object ArrayWriter: val b64String = Base64.getEncoder.encodeToString(byteArray) Tensor(base64Loader(b64String, shape.dimensions.toPythonProxy, jaxDType)) - given (using ExecutionType[Double]): ArrayWriter.Aux[Double, Float] = new ArrayWriter[Double]: - type V = Float - def fromArray[T <: Tuple: Labels](shape: Shape[T])(values: Array[Double]): Tensor[T, Float] = - require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") - val dtype = ExecutionType[Double].dtype - val byteArray = dtype.write(values) - byteArrayToTensor(shape, byteArray, dtype.jaxType) - - given (using ExecutionType[Float]): ArrayWriter.Aux[Float, Float] = new ArrayWriter[Float]: - type V = Float - def fromArray[T <: Tuple: Labels](shape: Shape[T])(values: Array[Float]): Tensor[T, Float] = - require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") - val dtype = ExecutionType[Float].dtype - val byteArray = dtype.write(values) - byteArrayToTensor(shape, byteArray, dtype.jaxType) - - given (using ExecutionType[Int]): ArrayWriter.Aux[Int, Int] = new ArrayWriter[Int]: - type V = Int - def fromArray[T <: Tuple: Labels](shape: Shape[T])(values: Array[Int]): Tensor[T, Int] = - require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") - val dtype = ExecutionType[Int].dtype - val byteArray = dtype.write(values) - byteArrayToTensor(shape, byteArray, dtype.jaxType) - - given (using ExecutionType[Byte]): ArrayWriter.Aux[Byte, Int] = new ArrayWriter[Byte]: - type V = Int - def fromArray[T <: Tuple: Labels](shape: Shape[T])(values: Array[Byte]): Tensor[T, Int] = - require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") - val dtype = ExecutionType[Byte].dtype - val byteArray = dtype.write(values) - byteArrayToTensor(shape, byteArray, dtype.jaxType) - - given (using ExecutionType[Boolean]): ArrayWriter.Aux[Boolean, Boolean] = new ArrayWriter[Boolean]: - type V = Boolean - def fromArray[T <: Tuple: Labels](shape: Shape[T])(values: Array[Boolean]): Tensor[T, Boolean] = - require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") - val dtype = ExecutionType[Boolean].dtype - val byteArray = dtype.write(values) - byteArrayToTensor(shape, byteArray, dtype.jaxType) + def fromArray[T <: Tuple: Labels, V: IsFloating](shape: Shape[T], values: Array[Double]): Tensor[T, V] = + require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") + val dtype = IsFloating[V].dtype + byteArrayToTensor(shape, dtype.write(values), dtype.jaxType) + + def fromArray[T <: Tuple: Labels, V: IsFloating](shape: Shape[T], values: Array[Float]): Tensor[T, V] = + require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") + val dtype = IsFloating[V].dtype + byteArrayToTensor(shape, dtype.write(values), dtype.jaxType) + + def fromArray[T <: Tuple: Labels, V: IsInteger](shape: Shape[T], values: Array[Int]): Tensor[T, V] = + require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") + val dtype = IsInteger[V].dtype + byteArrayToTensor(shape, dtype.write(values), dtype.jaxType) + + def fromArray[T <: Tuple: Labels, V: IsInteger](shape: Shape[T], values: Array[Long]): Tensor[T, V] = + require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") + val dtype = IsInteger[V].dtype + byteArrayToTensor(shape, dtype.write(values), dtype.jaxType) + + def fromArray[T <: Tuple: Labels, V: IsInteger](shape: Shape[T], values: Array[Byte]): Tensor[T, V] = + require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") + val dtype = IsInteger[V].dtype + byteArrayToTensor(shape, dtype.write(values), dtype.jaxType) + + def fromArray[T <: Tuple: Labels, V: IsInteger](shape: Shape[T], values: Array[Short]): Tensor[T, V] = + require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") + val dtype = IsInteger[V].dtype + byteArrayToTensor(shape, dtype.write(values), dtype.jaxType) + + def fromArray[T <: Tuple: Labels, V: IsBoolean](shape: Shape[T], values: Array[Boolean]): Tensor[T, V] = + require(values.length == shape.size, s"Values length ${values.length} does not match shape size ${shape.size}") + val dtype = IsBoolean[V].dtype + byteArrayToTensor(shape, dtype.write(values), dtype.jaxType) diff --git a/core/src/main/scala/dimwit/tensor/Axis.scala b/core/src/main/scala/dimwit/tensor/Axis.scala index be80572b..7ea19d8a 100644 --- a/core/src/main/scala/dimwit/tensor/Axis.scala +++ b/core/src/main/scala/dimwit/tensor/Axis.scala @@ -1,6 +1,7 @@ package dimwit.tensor import dimwit.|*| +import dimwit.tensor.DType.Int32 import scala.compiletime.{constValue, erasedValue, summonInline} @@ -14,7 +15,7 @@ final class Axis[L: Label]: def at(index: Int): AxisAtIndex[L] = AxisAtIndex(this, index) def at(range: Range): AxisAtRange[L] = AxisAtRange(this, range) def at(indices: Seq[Int]): AxisAtIndices[L] = AxisAtIndices(this, indices) - def at(index: Tensor0[Int]): AxisAtTensorIndex[L] = AxisAtTensorIndex(this, index) + def at(index: Tensor0[Int32]): AxisAtTensorIndex[L] = AxisAtTensorIndex(this, index) def at[I <: NonEmptyTuple](indices: I): AxisAtTupleIndices[L, I] = AxisAtTupleIndices(this, indices) def as[U](newAxis: Axis[U]): (Axis[L], Axis[U]) = (this, newAxis) @@ -45,7 +46,7 @@ case class AxisAtRange[L](axis: Axis[L], range: Range) extends AxisSelector[L] case class AxisAtIndices[L](axis: Axis[L], indices: Seq[Int]) extends AxisSelector[L] /* Represent an axis selection by a tensor containing indices. This allows for dynamic indexing based on the contents of the tensor. */ -case class AxisAtTensorIndex[L](axis: Axis[L], index: Tensor0[Int]) extends AxisSelector[L] +case class AxisAtTensorIndex[L](axis: Axis[L], index: Tensor0[Int32]) extends AxisSelector[L] /* Represent an axis selection by a tuple containing indices. */ case class AxisAtTupleIndices[L, I <: NonEmptyTuple](axis: Axis[L], indices: I) extends AxisSelector[L] diff --git a/core/src/main/scala/dimwit/tensor/DType.scala b/core/src/main/scala/dimwit/tensor/DType.scala index 4e1ccbae..42e5b86d 100644 --- a/core/src/main/scala/dimwit/tensor/DType.scala +++ b/core/src/main/scala/dimwit/tensor/DType.scala @@ -2,8 +2,61 @@ package dimwit.tensor import dimwit.jax.JaxDType import java.nio.ByteBuffer import java.nio.ByteOrder +import dimwit.tensor.TensorOps.{IsFloating, IsInteger, IsBoolean} + +object DType: + + type UInt8 = UInt8.type + given uint8IsFloating: IsInteger[UInt8] with + def dtype: DType = DType.UInt8 + + type UInt16 = UInt16.type + given uint16IsInteger: IsInteger[UInt16] with + def dtype: DType = DType.UInt16 + + type UInt32 = UInt32.type + given uint32IsInteger: IsInteger[UInt32] with + def dtype: DType = DType.UInt32 + + type Int8 = Int8.type + given int8IsInteger: IsInteger[Int8] with + def dtype: DType = DType.Int8 + + type Int16 = Int16.type + given int16IsInteger: IsInteger[Int16] with + def dtype: DType = DType.Int16 + + type Int32 = Int32.type + given int32IsInteger: IsInteger[Int32] with + def dtype: DType = DType.Int32 + + type Int64 = Int64.type + given int64IsInteger: IsInteger[Int64] with + def dtype: DType = DType.Int64 + + type Float16 = Float16.type + given float16IsFloating: IsFloating[Float16] with + def dtype: DType = DType.Float16 + + type BFloat16 = BFloat16.type + given bfloat16IsFloating: IsFloating[BFloat16] with + def dtype: DType = DType.BFloat16 + + type Float32 = Float32.type + given float32IsFloating: IsFloating[Float32] with + def dtype: DType = DType.Float32 + + type Float64 = Float64.type + given float64IsFloating: IsFloating[Float64] with + def dtype: DType = DType.Float64 + + type Bool = Bool.type + given boolIsBoolean: IsBoolean[Bool] with + def dtype: DType = DType.Bool enum DType(val name: String, val size: Int): + case BFloat16 extends DType("bfloat16", 2) + case Float16 extends DType("float16", 2) case Float32 extends DType("float32", 4) case Float64 extends DType("float64", 8) case Int32 extends DType("int32", 4) @@ -28,6 +81,12 @@ enum DType(val name: String, val size: Int): // write values into buffer according to this DType (this, values) match + case (Float16, arr: Array[Float]) => + val sb = buffer.asShortBuffer() + var i = 0 + while i < arr.length do + sb.put(floatToFloat16(arr(i))) + i += 1 // --- Float32 Target --- case (Float32, arr: Array[Float]) => buffer.asFloatBuffer().put(arr) @@ -74,3 +133,31 @@ enum DType(val name: String, val size: Int): throw new IllegalArgumentException(s"Conversion from ${values.getClass.getSimpleName} to DType $name is not supported or implemented.") buffer.array() + + private def floatToFloat16(f: Float): Short = + // TODO replace with java.lang.Float.floatToFloat16 when we can require Java 20+ + val bits = java.lang.Float.floatToIntBits(f) + val sign = (bits >>> 16) & 0x8000 + var valBits = bits & 0x7fffffff + + if valBits >= 0x47800000 then + // NaN or Infinity + if (valBits & 0x7f800000) == 0x7f800000 then + if (valBits & 0x007fffff) != 0 then + return (sign | 0x7c00 | (valBits & 0x007fffff) >>> 13).toShort // NaN + return (sign | 0x7c00).toShort // Infinity + return (sign | 0x7bff).toShort // Overflow + + if valBits >= 0x38800000 then + // Normalized number + return (sign | valBits - 0x38000000 >>> 13).toShort + + if valBits < 0x33000000 then + // Underflow to zero + return sign.toShort + + // Denormalized number + valBits = (valBits & 0x007fffff) | 0x00800000 + val shift = 113 - (bits >>> 23 & 0xff) + valBits = if shift < 24 then valBits >>> shift else 0 + (sign | valBits).toShort diff --git a/core/src/main/scala/dimwit/tensor/Tensor.scala b/core/src/main/scala/dimwit/tensor/Tensor.scala index e81504ed..9cbeae48 100644 --- a/core/src/main/scala/dimwit/tensor/Tensor.scala +++ b/core/src/main/scala/dimwit/tensor/Tensor.scala @@ -5,7 +5,7 @@ import scala.compiletime.{erasedValue, summonFrom} import dimwit.jax.Jax import dimwit.jax.JaxDType import dimwit.jax.Jax.PyDynamic -import dimwit.tensor.{Label, Labels, ExecutionType, VType} +import dimwit.tensor.{Label, Labels, VType} import me.shadaj.scalapy.py import me.shadaj.scalapy.py.SeqConverters import dimwit.random.Random @@ -17,6 +17,8 @@ import dimwit.Prime import ShapeTypeHelpers.AxisIndex import dimwit.hardware.Device import me.shadaj.scalapy.readwrite.Writer.stringWriter.given +import dimwit.tensor.TensorOps.{IsBoolean, IsInteger, IsFloating} +import DType.* class Tensor[T <: Tuple: Labels, V] private[dimwit] ( private[dimwit] val jaxValue: Jax.PyDynamic @@ -60,25 +62,70 @@ object Tensor: type IndicesOf[T <: Tuple] = Tuple.Map[T, [_] =>> Int] - case class Factory[T <: Tuple: Labels](val shape: Shape[T]): - - def fill[A: ExecutionType: Writer, V](value: A)(using ev: WriterEvidence.Aux[A, V]): Tensor[T, V] = - Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = ExecutionType[A].dtype.jaxType)) - - def fromArray[A: ExecutionType, V](values: Array[A])(using t2a: ArrayWriter.Aux[A, V]): Tensor[T, V] = - t2a.fromArray[T](shape)(values) + case class DefaultsFactory[T <: Tuple: Labels](shape: Shape[T]): + + def fill(value: Float): Tensor[T, Float32] = Tensor(shape, VType[Float32]).fill(value) + def fill(value: Double): Tensor[T, Float64] = Tensor(shape, VType[Float64]).fill(value) + def fromArray(values: Array[Float]): Tensor[T, Float32] = Tensor(shape, VType[Float32]).fromArray(values) + def fromArray(values: Array[Double]): Tensor[T, Float64] = Tensor(shape, VType[Float64]).fromArray(values) + + def fill(value: Byte): Tensor[T, Int8] = Tensor(shape, VType[Int8]).fill(value) + def fill(value: Short): Tensor[T, Int16] = Tensor(shape, VType[Int16]).fill(value) + def fill(value: Int): Tensor[T, Int32] = Tensor(shape, VType[Int32]).fill(value) + def fill(value: Long): Tensor[T, Int64] = Tensor(shape, VType[Int64]).fill(value) + def fromArray(values: Array[Byte]): Tensor[T, Int8] = Tensor(shape, VType[Int8]).fromArray(values) + def fromArray(values: Array[Short]): Tensor[T, Int16] = Tensor(shape, VType[Int16]).fromArray(values) + def fromArray(values: Array[Int]): Tensor[T, Int32] = Tensor(shape, VType[Int32]).fromArray(values) + def fromArray(values: Array[Long]): Tensor[T, Int64] = Tensor(shape, VType[Int64]).fromArray(values) + + def fill(value: Boolean): Tensor[T, Bool] = Tensor(shape, VType[Bool]).fill(value) + def fromArray(values: Array[Boolean]): Tensor[T, Bool] = Tensor(shape, VType[Bool]).fromArray(values) + + case class TypedFactory[T <: Tuple: Labels, V](shape: Shape[T], vtype: VType[V]): + + // --- Boolean --- + def fill(value: Boolean)(using IsBoolean[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) + def fromArray(values: Array[Boolean])(using IsBoolean[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) + + // --- Integer --- + def fill(value: Byte)(using IsInteger[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) + def fill(value: Short)(using IsInteger[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value.toInt, dtype = vtype.dtype.jaxType)) + def fill(value: Int)(using IsInteger[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) + def fill(value: Long)(using IsInteger[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) + def fromArray(values: Array[Byte])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) + def fromArray(values: Array[Short])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) + def fromArray(values: Array[Int])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) + def fromArray(values: Array[Long])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) + + // --- Floating --- + def fill(value: Float)(using IsFloating[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) + def fill(value: Double)(using IsFloating[V]): Tensor[T, V] = Tensor(Jax.jnp.full(shape.dimensions.toPythonProxy, value, dtype = vtype.dtype.jaxType)) + def fromArray(values: Array[Float])(using IsFloating[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) + def fromArray(values: Array[Double])(using IsFloating[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](shape, values) case class LikeFactory[T <: Tuple: Labels, V](val other: Tensor[T, V]): - def fill[A: Writer](value: A)(using ev: WriterEvidence.Aux[A, V]): Tensor[T, V] = - Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + def fill(value: Boolean): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + def fromArray(values: Array[Boolean])(using IsBoolean[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) + + def fill(value: Byte): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + def fill(value: Short): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value.toInt, dtype = other.dtype.jaxType)) + def fill(value: Int): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + def fill(value: Long): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + def fromArray(values: Array[Byte])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) + def fromArray(values: Array[Short])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) + def fromArray(values: Array[Int])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) + def fromArray(values: Array[Long])(using IsInteger[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) - def fromArray[A](values: Array[A])(using t2a: ArrayWriter.Aux[A, V]): Tensor[T, V] = - given ExecutionType[A] = ExecutionTypeFor[A](other.dtype) // fix the underlying dtype to match the other tensor's dtype - summon[ArrayWriter[A]].fromArray[T](other.shape)(values) + def fill(value: Float): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + def fill(value: Double): Tensor[T, V] = Tensor(Jax.jnp.full(other.shape.dimensions.toPythonProxy, value, dtype = other.dtype.jaxType)) + def fromArray(values: Array[Float])(using IsFloating[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) + def fromArray(values: Array[Double])(using IsFloating[V]): Tensor[T, V] = ArrayWriter.fromArray[T, V](other.shape, values) private[dimwit] def apply[T <: Tuple: Labels, V](jaxValue: Jax.PyDynamic): Tensor[T, V] = new Tensor(jaxValue) - def apply[T <: Tuple: Labels](shape: Shape[T]): Tensor.Factory[T] = Tensor.Factory(shape) + + def apply[T <: Tuple: Labels](shape: Shape[T]): DefaultsFactory[T] = DefaultsFactory(shape) + def apply[T <: Tuple: Labels, V](shape: Shape[T], vtype: VType[V]): TypedFactory[T, V] = TypedFactory(shape, vtype) def like[T <: Tuple: Labels, V](template: Tensor[T, V]): Tensor.LikeFactory[T, V] = Tensor.LikeFactory(template) type Tensor0[V] = Tensor[EmptyTuple, V] @@ -89,44 +136,174 @@ type Tensor4[L1, L2, L3, L4, V] = Tensor[(L1, L2, L3, L4), V] object Tensor0: - given float2FloatTensor: Conversion[Float, Tensor0[Float]] = (x: Float) => Tensor0(x) - given int2IntTensor: Conversion[Int, Tensor0[Int]] = (x: Int) => Tensor0(x) - given int2FloatTensor: Conversion[Int, Tensor0[Float]] = (x: Int) => Tensor0(x.toFloat) - given boolean2BooleanTensor: Conversion[Boolean, Tensor0[Boolean]] = (x: Boolean) => Tensor0(x) + given boolean2BooleanTensor[V: IsBoolean]: Conversion[Boolean, Tensor0[V]] with + def apply(value: Boolean): Tensor0[V] = Tensor0(VType[V])(value) + + given byte2IntegerTensor[V: IsInteger]: Conversion[Byte, Tensor0[V]] with + def apply(value: Byte): Tensor0[V] = Tensor0(VType[V])(value) + + given short2IntegerTensor[V: IsInteger]: Conversion[Short, Tensor0[V]] with + def apply(value: Short): Tensor0[V] = Tensor0(VType[V])(value) + + given int2IntegerTensor[V: IsInteger]: Conversion[Int, Tensor0[V]] with + def apply(value: Int): Tensor0[V] = Tensor0(VType[V])(value) + + given int2FloatingTensor[V: IsFloating]: Conversion[Int, Tensor0[V]] with + def apply(value: Int): Tensor0[V] = Tensor0(VType[V])(value.toFloat) + + given long2IntegerTensor[V: IsInteger]: Conversion[Long, Tensor0[V]] with + def apply(value: Long): Tensor0[V] = Tensor0(VType[V])(value) + + given float2FloatingTensor[V: IsFloating]: Conversion[Float, Tensor0[V]] with + def apply(value: Float): Tensor0[V] = Tensor0(VType[V])(value) + + given double2FloatingTensor[V: IsFloating]: Conversion[Double, Tensor0[V]] with + def apply(value: Double): Tensor0[V] = Tensor0(VType[V])(value) + + object DefaultsFactory: + + def apply(value: Boolean): Tensor0[Bool] = Tensor0(VType[Bool])(value) + + def apply(value: Byte): Tensor0[Int8] = Tensor0(VType[Int8])(value) + def apply(value: Short): Tensor0[Int16] = Tensor0(VType[Int16])(value) + def apply(value: Int): Tensor0[Int32] = Tensor0(VType[Int32])(value) + def apply(value: Long): Tensor0[Int64] = Tensor0(VType[Int64])(value) + + def apply(value: Float): Tensor0[Float32] = Tensor0(VType[Float32])(value) + def apply(value: Double): Tensor0[Float64] = Tensor0(VType[Float64])(value) + + case class TypedFactory[V](vtype: VType[V]): + + // --- Boolean --- + def apply(value: Boolean)(using IsBoolean[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) + + // --- Integer --- + def apply(value: Byte)(using IsInteger[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) + def apply(value: Short)(using IsInteger[V]): Tensor0[V] = Tensor(Jax.jnp.array(value.toInt, dtype = vtype.dtype.jaxType)) + def apply(value: Int)(using IsInteger[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) + def apply(value: Long)(using IsInteger[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) + + // --- Floating --- + def apply(value: Float)(using IsFloating[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) + def apply(value: Double)(using IsFloating[V]): Tensor0[V] = Tensor(Jax.jnp.array(value, dtype = vtype.dtype.jaxType)) + + export DefaultsFactory.* + def apply[V](vtype: VType[V]): TypedFactory[V] = TypedFactory(vtype) - def apply[V: ExecutionType: Writer](value: V): Tensor0[V] = Tensor(Jax.jnp.full(Shape0.dimensions.toPythonProxy, value, dtype = ExecutionType[V].dtype.jaxType)) def like[V: Writer](template: Tensor0[V])(value: V): Tensor0[V] = Tensor(Jax.jnp.full(Shape0.dimensions.toPythonProxy, value, dtype = template.dtype.jaxType)) + def likeDType[V, T <: Tuple](template: Tensor[T, V])(value: Float): Tensor0[V] = Tensor(Jax.jnp.full(Shape0.dimensions.toPythonProxy, value, dtype = template.dtype.jaxType)) def apply[V](jaxValue: Jax.PyDynamic): Tensor0[V] = Tensor(jaxValue) object Tensor1: - case class Factory[L: Label](val axis: Axis[L]): - private def createShape(l: Int): Shape1[L] = Shape1(AxisExtent(axis, l)) - def fromArray[A: ExecutionType, V](values: Array[A])(using t2a: ArrayWriter.Aux[A, V]): Tensor[Tuple1[L], V] = Tensor(createShape(values.length)).fromArray(values) + case class DefaultsFactory[L: Label](axis: Axis[L]): + + // --- Boolean --- + def fromArray(values: Array[Boolean]): Tensor1[L, Bool] = Tensor1(axis, VType[Bool]).fromArray(values) + + // --- Integer --- + def fromArray(values: Array[Byte]): Tensor1[L, Int8] = Tensor1(axis, VType[Int8]).fromArray(values) + def fromArray(values: Array[Short]): Tensor1[L, Int16] = Tensor1(axis, VType[Int16]).fromArray(values) + def fromArray(values: Array[Int]): Tensor1[L, Int32] = Tensor1(axis, VType[Int32]).fromArray(values) + def fromArray(values: Array[Long]): Tensor1[L, Int64] = Tensor1(axis, VType[Int64]).fromArray(values) + + // --- Floating --- + def fromArray(values: Array[Float]): Tensor1[L, Float32] = Tensor1(axis, VType[Float32]).fromArray(values) + def fromArray(values: Array[Double]): Tensor1[L, Float64] = Tensor1(axis, VType[Float64]).fromArray(values) - def apply[L: Label](axis: Axis[L]): Tensor1.Factory[L] = Tensor1.Factory(axis) - def fromPy[L: Label, V](axis: Axis[L], vtype: VType[V])(jaxValue: Jax.PyDynamic): Tensor1[L, V] = new Tensor(jaxValue) + case class TypedFactory[L: Label, V](axis: Axis[L], vtype: VType[V]): + + // --- Boolean --- + def fromArray(values: Array[Boolean])(using IsBoolean[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) + + // --- Integer --- + def fromArray(values: Array[Byte])(using IsInteger[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) + def fromArray(values: Array[Short])(using IsInteger[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) + def fromArray(values: Array[Int])(using IsInteger[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) + def fromArray(values: Array[Long])(using IsInteger[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) + + // --- Floating --- + def fromArray(values: Array[Float])(using IsFloating[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) + def fromArray(values: Array[Double])(using IsFloating[V]): Tensor1[L, V] = ArrayWriter.fromArray[Tuple1[L], V](Shape1(axis -> values.length), values) + + def apply[L: Label](axis: Axis[L]): DefaultsFactory[L] = DefaultsFactory(axis) + def apply[L: Label, V](axis: Axis[L], vtype: VType[V]): TypedFactory[L, V] = TypedFactory(axis, vtype) object Tensor2: - case class Factory[L1: Label, L2: Label](val axis1: Axis[L1], val axis2: Axis[L2]): - private def createShape[V](values: Array[Array[V]]): Shape2[L1, L2] = Shape2(AxisExtent(axis1, values.length), AxisExtent(axis2, values.head.length)) - def fromArray[A: ClassTag: ExecutionType, V](values: Array[Array[A]])(using t2a: ArrayWriter.Aux[A, V]): Tensor[(L1, L2), V] = Tensor(createShape(values)).fromArray(values.flatten) + type Array2D[V] = Array[Array[V]] + + case class DefaultsFactory[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2]): + + def fromArray(values: Array2D[Boolean]): Tensor2[L1, L2, Bool] = Tensor2(axis1, axis2, VType[Bool]).fromArray(values) + + def fromArray(values: Array2D[Byte]): Tensor2[L1, L2, Int8] = Tensor2(axis1, axis2, VType[Int8]).fromArray(values) + def fromArray(values: Array2D[Short]): Tensor2[L1, L2, Int16] = Tensor2(axis1, axis2, VType[Int16]).fromArray(values) + def fromArray(values: Array2D[Int]): Tensor2[L1, L2, Int32] = Tensor2(axis1, axis2, VType[Int32]).fromArray(values) + def fromArray(values: Array2D[Long]): Tensor2[L1, L2, Int64] = Tensor2(axis1, axis2, VType[Int64]).fromArray(values) + + def fromArray(values: Array2D[Float]): Tensor2[L1, L2, Float32] = Tensor2(axis1, axis2, VType[Float32]).fromArray(values) + def fromArray(values: Array2D[Double]): Tensor2[L1, L2, Float64] = Tensor2(axis1, axis2, VType[Float64]).fromArray(values) + + case class TypedFactory[L1: Label, L2: Label, V](axis1: Axis[L1], axis2: Axis[L2], vtype: VType[V]): + + private def createShape[V](values: Array2D[V]): Shape2[L1, L2] = Shape2(AxisExtent(axis1, values.length), AxisExtent(axis2, values.head.length)) + + // --- Boolean --- + def fromArray(values: Array2D[Boolean])(using IsBoolean[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) + + // --- Integer --- + def fromArray(values: Array2D[Byte])(using IsInteger[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) + def fromArray(values: Array2D[Short])(using IsInteger[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) + def fromArray(values: Array2D[Int])(using IsInteger[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) + def fromArray(values: Array2D[Long])(using IsInteger[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) - def apply[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2]): Tensor2.Factory[L1, L2] = Tensor2.Factory(axis1, axis2) + // --- Floating --- + def fromArray(values: Array2D[Float])(using IsFloating[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) + def fromArray(values: Array2D[Double])(using IsFloating[V]): Tensor2[L1, L2, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten) - private def eyeImpl[L: Label, V](dim: AxisExtent[L], dtype: DType): Tensor2[L, Prime[L], V] = Tensor(Jax.jnp.eye(dim.size, dtype = dtype.jaxType)) - def eye[L: Label](dim: AxisExtent[L])(using et: ExecutionType[Float]): Tensor2[L, Prime[L], Float] = eyeImpl(dim, et.dtype) - def eye[L: Label, V](dim: AxisExtent[L], vtype: VType[V]): Tensor2[L, Prime[L], V] = eyeImpl(dim, vtype.dtype) + def apply[L1: Label, L2: Label](axis1: Axis[L1], axis2: Axis[L2]): DefaultsFactory[L1, L2] = DefaultsFactory(axis1, axis2) + def apply[L1: Label, L2: Label, V](axis1: Axis[L1], axis2: Axis[L2], vtype: VType[V]): TypedFactory[L1, L2, V] = TypedFactory(axis1, axis2, vtype) + + private def eyeImpl[L: Label, V](dim: AxisExtent[L], vtype: VType[V]): Tensor2[L, Prime[L], V] = Tensor(Jax.jnp.eye(dim.size, dtype = vtype.dtype.jaxType)) + def eye[L: Label](dim: AxisExtent[L]): Tensor2[L, Prime[L], Float32] = eyeImpl(dim, VType[Float32]) + def eye[L: Label, V](dim: AxisExtent[L], vtype: VType[V]): Tensor2[L, Prime[L], V] = eyeImpl(dim, vtype) def diag[L: Label, V](diag: Tensor1[L, V]): Tensor2[L, Prime[L], V] = Tensor(Jax.jnp.diag(diag.jaxValue)) object Tensor3: - case class Factory[L1: Label, L2: Label, L3: Label](val axis1: Axis[L1], val axis2: Axis[L2], val axis3: Axis[L3]): - private def createShape[V](values: Array[Array[Array[V]]]): Shape3[L1, L2, L3] = - Shape3(AxisExtent(axis1, values.length), AxisExtent(axis2, values.head.length), AxisExtent(axis3, values.head.head.length)) - def fromArray[A: ExecutionType: ClassTag, V](values: Array[Array[Array[A]]])(using t2a: ArrayWriter.Aux[A, V]): Tensor3[L1, L2, L3, V] = - Tensor(createShape(values)).fromArray(values.flatten.flatten) + type Array3D[V] = Array[Array[Array[V]]] + + case class DefaultsFactory[L1: Label, L2: Label, L3: Label](axis1: Axis[L1], axis2: Axis[L2], axis3: Axis[L3]): + + def fromArray(values: Array3D[Boolean]): Tensor3[L1, L2, L3, Bool] = Tensor3(axis1, axis2, axis3, VType[Bool]).fromArray(values) + + def fromArray(values: Array3D[Byte]): Tensor3[L1, L2, L3, Int8] = Tensor3(axis1, axis2, axis3, VType[Int8]).fromArray(values) + def fromArray(values: Array3D[Short]): Tensor3[L1, L2, L3, Int16] = Tensor3(axis1, axis2, axis3, VType[Int16]).fromArray(values) + def fromArray(values: Array3D[Int]): Tensor3[L1, L2, L3, Int32] = Tensor3(axis1, axis2, axis3, VType[Int32]).fromArray(values) + def fromArray(values: Array3D[Long]): Tensor3[L1, L2, L3, Int64] = Tensor3(axis1, axis2, axis3, VType[Int64]).fromArray(values) + + def fromArray(values: Array3D[Float]): Tensor3[L1, L2, L3, Float32] = Tensor3(axis1, axis2, axis3, VType[Float32]).fromArray(values) + def fromArray(values: Array3D[Double]): Tensor3[L1, L2, L3, Float64] = Tensor3(axis1, axis2, axis3, VType[Float64]).fromArray(values) + + case class TypedFactory[L1: Label, L2: Label, L3: Label, V](axis1: Axis[L1], axis2: Axis[L2], axis3: Axis[L3], vtype: VType[V]): + + private def createShape[V](values: Array3D[V]): Shape3[L1, L2, L3] = Shape3(AxisExtent(axis1, values.length), AxisExtent(axis2, values.head.length), AxisExtent(axis3, values.head.head.length)) + + // --- Boolean --- + def fromArray(values: Array3D[Boolean])(using IsBoolean[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) + + // --- Integer --- + def fromArray(values: Array3D[Byte])(using IsInteger[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) + def fromArray(values: Array3D[Short])(using IsInteger[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) + def fromArray(values: Array3D[Int])(using IsInteger[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) + def fromArray(values: Array3D[Long])(using IsInteger[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) + + // --- Floating --- + def fromArray(values: Array3D[Float])(using IsFloating[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) + def fromArray(values: Array3D[Double])(using IsFloating[V]): Tensor3[L1, L2, L3, V] = Tensor(createShape(values), VType[V]).fromArray(values.flatten.flatten) + + def apply[L1: Label, L2: Label, L3: Label](axis1: Axis[L1], axis2: Axis[L2], axis3: Axis[L3]): DefaultsFactory[L1, L2, L3] = DefaultsFactory(axis1, axis2, axis3) - def apply[L1: Label, L2: Label, L3: Label](axis1: Axis[L1], axis2: Axis[L2], axis3: Axis[L3]): Tensor3.Factory[L1, L2, L3] = Tensor3.Factory(axis1, axis2, axis3) + def apply[L1: Label, L2: Label, L3: Label, V](axis1: Axis[L1], axis2: Axis[L2], axis3: Axis[L3], vtype: VType[V]): TypedFactory[L1, L2, L3, V] = TypedFactory(axis1, axis2, axis3, vtype) diff --git a/core/src/main/scala/dimwit/tensor/TensorOps.scala b/core/src/main/scala/dimwit/tensor/TensorOps.scala index d12f32ec..5d09abe1 100644 --- a/core/src/main/scala/dimwit/tensor/TensorOps.scala +++ b/core/src/main/scala/dimwit/tensor/TensorOps.scala @@ -25,6 +25,8 @@ import dimwit.tensor.ShapeTypeHelpers.AxisIndex import dimwit.tensor.ShapeTypeHelpers.AxisIndices import dimwit.tensor.ShapeTypeHelpers.AxesMerger import dimwit.OnError +import dimwit.DType.* +import dimwit.DType.given import Tuple.:* import Tuple.++ @@ -36,6 +38,9 @@ object TensorOps: import TensorOpsUtil.* + sealed trait HasDType[V]: + def dtype: DType + @implicitNotFound("Operation only valid for Numeric (Int or Float) tensors.") sealed trait IsNumber[V] @@ -49,19 +54,25 @@ object TensorOps: given [V](using ev2: IsInteger[V]): IsNumber[V] = ev2 @implicitNotFound("Operation only valid for Floating tensors.") - trait IsFloating[V] extends IsNumber[V] + trait IsFloating[V] extends IsNumber[V], HasDType[V]: + def dtype: DType + object IsFloating: - given IsFloating[Float] with {} + def apply[V](using ev: IsFloating[V]): IsFloating[V] = ev - @implicitNotFound("Operation only valid for Integer tensors.") - trait IsInteger[V] extends IsNumber[V] object IsInteger: - given IsInteger[Int] with {} + def apply[V](using ev: IsInteger[V]): IsInteger[V] = ev - @implicitNotFound("Operation only valid for Boolean tensors.") - sealed trait IsBoolean[V] object IsBoolean: - given IsBoolean[Boolean] with {} + def apply[V](using ev: IsBoolean[V]): IsBoolean[V] = ev + + @implicitNotFound("Operation only valid for Integer tensors.") + trait IsInteger[V] extends IsNumber[V], HasDType[V]: + def dtype: DType + + @implicitNotFound("Operation only valid for Boolean tensors.") + trait IsBoolean[V] extends HasDType[V]: + def dtype: DType // ----------------------------------------------------------- // 1. Elementwise Operations (The Field) @@ -79,19 +90,22 @@ object TensorOps: extension [T <: Tuple: Labels, V](t: Tensor[T, V]) // --- Comparison --- - def <(other: Tensor[T, V]): Tensor[T, Boolean] = Tensor(Jax.jnp.less(t.jaxValue, other.jaxValue)) - def <=(other: Tensor[T, V]): Tensor[T, Boolean] = Tensor(Jax.jnp.less_equal(t.jaxValue, other.jaxValue)) - def >(other: Tensor[T, V]): Tensor[T, Boolean] = Tensor(Jax.jnp.greater(t.jaxValue, other.jaxValue)) - def >=(other: Tensor[T, V]): Tensor[T, Boolean] = Tensor(Jax.jnp.greater_equal(t.jaxValue, other.jaxValue)) - def ===(other: Tensor[T, V]): Tensor0[Boolean] = Tensor0(Jax.jnp.array_equal(t.jaxValue, other.jaxValue)) + def <(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less(t.jaxValue, other.jaxValue)) + def <=(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.less_equal(t.jaxValue, other.jaxValue)) + def >(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.greater(t.jaxValue, other.jaxValue)) + def >=(other: Tensor[T, V]): Tensor[T, Bool] = Tensor(Jax.jnp.greater_equal(t.jaxValue, other.jaxValue)) + def ===(other: Tensor[T, V]): Tensor0[Bool] = Tensor0(Jax.jnp.array_equal(t.jaxValue, other.jaxValue)) - def elementEquals(other: Tensor[T, V]): Tensor[T, Boolean] = + def elementEquals(other: Tensor[T, V]): Tensor[T, Bool] = require(t.shape.dimensions == other.shape.dimensions, s"Shape mismatch: ${t.shape.dimensions} vs ${other.shape.dimensions}") Tensor(jaxValue = Jax.jnp.equal(t.jaxValue, other.jaxValue)) - def asBoolean: Tensor[T, Boolean] = t.asType(VType[Boolean]) - def asInt: Tensor[T, Int] = t.asType(VType[Int]) - def asFloat: Tensor[T, Float] = t.asType(VType[Float]) + def asBool: Tensor[T, Bool] = t.asType(VType[Bool]) + def asBoolean[NewV: IsBoolean](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) + def asInt32: Tensor[T, Int32] = t.asType(VType[Int32]) + def asInt[NewV: IsInteger](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) + def asFloat32: Tensor[T, Float32] = t.asType(VType[Float32]) + def asFloat[NewV: IsFloating](vtype: VType[NewV]): Tensor[T, NewV] = t.asType(vtype) // --------------------------------------------------------- // IsNumber operations (IsFloat or IsInt) @@ -147,8 +161,8 @@ object TensorOps: def cos: Tensor[T, V] = Tensor(Jax.jnp.cos(t.jaxValue)) def tanh: Tensor[T, V] = Tensor(Jax.jnp.tanh(t.jaxValue)) - def approxEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Tensor0[Boolean] = approxElementEquals(other, tolerance).all - def approxElementEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Tensor[T, Boolean] = + def approxEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Tensor0[Bool] = approxElementEquals(other, tolerance).all + def approxElementEquals(other: Tensor[T, V], tolerance: Float = 1e-6f): Tensor[T, Bool] = Tensor( Jax.jnp.allclose( t.jaxValue, @@ -164,10 +178,10 @@ object TensorOps: extension [T <: Tuple: Labels, V: IsBoolean](t: Tensor[T, V]) - def all: Tensor0[Boolean] = Tensor0(Jax.jnp.all(t.jaxValue)) - def any: Tensor0[Boolean] = Tensor0(Jax.jnp.any(t.jaxValue)) + def all: Tensor0[V] = Tensor0(Jax.jnp.all(t.jaxValue)) + def any: Tensor0[V] = Tensor0(Jax.jnp.any(t.jaxValue)) - def unary_! : Tensor[T, Boolean] = Tensor(Jax.jnp.logical_not(t.jaxValue)) + def unary_! : Tensor[T, V] = Tensor(Jax.jnp.logical_not(t.jaxValue)) end Elementwise @@ -199,19 +213,19 @@ object TensorOps: def min[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, V] = Tensor(Jax.jnp.min(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Argmax --- - def argmax: Tensor0[Int] = Tensor0(Jax.jnp.argmax(t.jaxValue)) - def argmax[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, Int] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.index)) - def argmax[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, Int] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.indices.toPythonProxy)) + def argmax: Tensor0[Int32] = Tensor0(Jax.jnp.argmax(t.jaxValue)) + def argmax[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.index)) + def argmax[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, Int32] = Tensor(Jax.jnp.argmax(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Argmin --- - def argmin: Tensor0[Int] = Tensor0(Jax.jnp.argmin(t.jaxValue)) - def argmin[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, Int] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.index)) - def argmin[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, Int] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.indices.toPythonProxy)) + def argmin: Tensor0[Int32] = Tensor0(Jax.jnp.argmin(t.jaxValue)) + def argmin[L: Label, R <: Tuple](axis: Axis[L])(using ev: AxisRemover[T, L, R], l: Labels[R]): Tensor[R, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.index)) + def argmin[Inputs <: Tuple, R <: Tuple](axes: Inputs)(using ev: AxesRemover[T, UnwrapAxes[Inputs], R], l: Labels[R]): Tensor[R, Int32] = Tensor(Jax.jnp.argmin(t.jaxValue, axis = ev.indices.toPythonProxy)) // --- Argsort --- - def argsort: Tensor[T, Int] = Tensor(Jax.jnp.argsort(t.jaxValue)) - def argsort[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): Tensor[T, Int] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.index)) - def argsort[Inputs <: Tuple](axes: Inputs)(using ev: AxisIndices[T, UnwrapAxes[Inputs]]): Tensor[T, Int] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.indices.toPythonProxy)) + def argsort: Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue)) + def argsort[L: Label](axis: Axis[L])(using ev: AxisIndex[T, L]): Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.index)) + def argsort[Inputs <: Tuple](axes: Inputs)(using ev: AxisIndices[T, UnwrapAxes[Inputs]]): Tensor[T, Int32] = Tensor(Jax.jnp.argsort(t.jaxValue, axis = ev.indices.toPythonProxy)) // --------------------------------------------------------- // IsFloat operations (IsFloat or IsInt) @@ -307,7 +321,7 @@ object TensorOps: extension [S1: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: InChannel *: EmptyTuple, V]) def conv1d[OutChannel: Label]( - kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, Float], + kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride1[S1] | Int = 1, padding: Padding = Padding.SAME ): Tensor[S1 *: OutChannel *: EmptyTuple, V] = @@ -333,7 +347,7 @@ object TensorOps: extension [S1: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: OutChannel *: EmptyTuple, V]) def transposeConv1d[InChannel: Label]( - kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, Float], + kernel: Tensor[S1 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride1[S1] | Int = 1, padding: Padding = Padding.SAME ): Tensor[S1 *: InChannel *: EmptyTuple, V] = @@ -365,7 +379,7 @@ object TensorOps: extension [S1: Label, S2: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V]) def conv2d[OutChannel: Label]( - kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, Float], + kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride2[S1, S2] | Int = 1, padding: Padding = Padding.SAME ): Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V] = @@ -391,7 +405,7 @@ object TensorOps: extension [S1: Label, S2: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V]) def transposeConv2d[InChannel: Label]( - kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, Float], + kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride2[S1, S2] | Int = 1, padding: Padding = Padding.SAME ): Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V] = @@ -426,7 +440,7 @@ object TensorOps: extension [S1: Label, S2: Label, S3: Label, InChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V]) def conv3d[OutChannel: Label]( - kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, Float], + kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride3[S1, S2, S3] | Int = 1, padding: Padding = Padding.SAME ): Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V] = @@ -454,7 +468,7 @@ object TensorOps: extension [S1: Label, S2: Label, S3: Label, OutChannel: Label, V: IsFloating](input: Tensor[S1 *: S2 *: S3 *: OutChannel *: EmptyTuple, V]) def transposeConv3d[InChannel: Label]( - kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, Float], + kernel: Tensor[S1 *: S2 *: S3 *: InChannel *: OutChannel *: EmptyTuple, V], stride: Stride3[S1, S2, S3] | Int = 1, padding: Padding = Padding.SAME ): Tensor[S1 *: S2 *: S3 *: InChannel *: EmptyTuple, V] = @@ -564,7 +578,7 @@ object TensorOps: case A *: tail => A *: B *: tail case h *: tail => h *: InsertAfter[tail, A, B] - type SliceIndex = Int | List[Int] | Range | Tensor0[Int] + type SliceIndex = Int | List[Int] | Range | Tensor0[Int32] type ExtractLabel[X] = X match case AxisAtIndex[l] => l case AxisAtRange[l] => l @@ -614,8 +628,8 @@ object TensorOps: given consTensor0Int[L, Tail <: Tuple, TailOut <: Tuple](using tailExt: SliceLabelExtractor[Tail, TailOut] - ): SliceLabelExtractor[(Axis[L], Tensor0[Int]) *: Tail, L *: TailOut] = - new SliceLabelExtractor[(Axis[L], Tensor0[Int]) *: Tail, L *: TailOut] {} + ): SliceLabelExtractor[(Axis[L], Tensor0[Int32]) *: Tail, L *: TailOut] = + new SliceLabelExtractor[(Axis[L], Tensor0[Int32]) *: Tail, L *: TailOut] {} given consSeq[L, SeqT <: Seq[Int], Tail <: Tuple, TailOut <: Tuple](using tailExt: SliceLabelExtractor[Tail, TailOut] @@ -637,7 +651,7 @@ object TensorOps: object TensorWhere: def where[T <: Tuple: Labels, V]( - condition: Tensor[T, Boolean], + condition: Tensor[T, Bool], x: Tensor[T, V], y: Tensor[T, V] ): Tensor[T, V] = @@ -873,7 +887,7 @@ object TensorOps: indicesBuffer(dimIndex) = PySlice(range.head, range.last + 1, range.step) case idx: Int => indicesBuffer(dimIndex) = py.Any.from(idx) - case tensorId: Tensor0[Int] @unchecked => + case tensorId: Tensor0[Int32] @unchecked => indicesBuffer(dimIndex) = tensorId.jaxValue } @@ -1039,7 +1053,7 @@ object TensorOps: def take[L1, L2: Label, R <: Tuple]( axis: Axis[L1] )( - indices: Tensor1[L2, Int] + indices: Tensor1[L2, Int32] )(using ev: AxisRemover[T, L1, R], labels: Labels[R] @@ -1058,6 +1072,18 @@ object TensorOps: val result = tensor.jaxValue.at.bracketAccess(pyIndices).set(value.jaxValue) Tensor(result) + // Convenience overload for Float + def set[Inputs <: Tuple, LabelsToRemove <: Tuple]( + inputs: Inputs + )(using + sliceExtractor: SliceLabelExtractor[Inputs, LabelsToRemove], + ev: AxesConditionalRemover[T, LabelsToRemove, ExtractLabels[Inputs], EmptyTuple], + labels: Labels[T] + )(value: Float): Tensor[T, V] = + val pyIndices = tensor.calcPyIndices(inputs, ev.indices) + val result = tensor.jaxValue.at.bracketAccess(pyIndices).set(value) + Tensor(result) + // Convenience overload for AxisAtIndex def set[L, LabelsToRemove <: Tuple, R <: Tuple]( selector: AxisAtIndex[L] @@ -1411,40 +1437,80 @@ object TensorOps: object Tensor0Ops: - extension [V: Reader](scalar: Tensor0[V]) - - def item: V = - require( - !scalar.isTracer, - """ - | Cannot convert a JAX Tracer to a scalar value. Tensor0 is part of a JAX computation graph (e.g., inside vmap or a jitted function). - | Common mistakes leading to this error: - | - calling .slice(t0.item) rather than .slice(t0); breaking the computation graph unintentionally. - |""".stripMargin - ) - scalar.jaxValue.item().as[V] + private inline def checkTracer[V, R](scalar: Tensor0[V]): Unit = + require( + !scalar.isTracer, + """ + | Cannot convert a JAX Tracer to a scalar value. Tensor0 is part of a JAX computation graph (e.g., inside vmap or a jitted function). + | Common mistakes leading to this error: + | - calling .slice(t0.item) rather than .slice(t0); breaking the computation graph unintentionally. + |""".stripMargin + ) + + extension (scalar: Tensor0[Bool]) + def item: Boolean = + checkTracer(scalar) + scalar.jaxValue.item().as[Boolean] + + extension (scalar: Tensor0[Int8]) + def item: Byte = + checkTracer(scalar) + scalar.jaxValue.item().as[Byte] + + extension (scalar: Tensor0[Int16]) + def item: Short = + checkTracer(scalar) + scalar.jaxValue.item().as[Int].toShort + + extension (scalar: Tensor0[Int32]) + def item: Int = + checkTracer(scalar) + scalar.jaxValue.item().as[Int] + + extension (scalar: Tensor0[Int64]) + def item: Long = + checkTracer(scalar) + scalar.jaxValue.item().as[Long] + + extension (scalar: Tensor0[Float32]) + def item: Float = + checkTracer(scalar) + scalar.jaxValue.item().as[Float] + + extension (scalar: Tensor0[Float64]) + def item: Double = + checkTracer(scalar) + scalar.jaxValue.item().as[Double] object ValueOps: import Elementwise.+! - extension [V: IsNumber: Writer](scalar: V) + extension [V: IsNumber](t: Tensor0[V]) + + def +(t2: Tensor0[V]): Tensor0[V] = TensorOps.add(t, t2) + def -(t2: Tensor0[V]): Tensor0[V] = TensorOps.subtract(t, t2) + def *(t2: Tensor0[V]): Tensor0[V] = TensorOps.multiply(t, t2) + + extension [V: IsFloating](t: Tensor0[V]) + + def /(scalar: Tensor0[V]): Tensor0[V] = TensorOps.divide(t, scalar) + + extension (scalar: Float) + + def +[V: IsNumber](t: Tensor0[V]): Tensor0[V] = add(Tensor0.likeDType(t)(scalar), t) + def +![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(add) + + def -[V: IsNumber](t: Tensor0[V]): Tensor0[V] = subtract(Tensor0.likeDType(t)(scalar), t) + def -![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(subtract) - def +![T <: Tuple: Labels](t: Tensor[T, V]): Tensor[T, V] = - given ExecutionType[V] = ExecutionTypeFor[V](t.dtype) - Tensor0(scalar).broadcastTo(t.shape) + t - def -![T <: Tuple: Labels](t: Tensor[T, V]): Tensor[T, V] = - given ExecutionType[V] = ExecutionTypeFor[V](t.dtype) - Tensor0(scalar).broadcastTo(t.shape) - t - def *![T <: Tuple: Labels](t: Tensor[T, V]): Tensor[T, V] = - given ExecutionType[V] = ExecutionTypeFor[V](t.dtype) - Tensor0(scalar).broadcastTo(t.shape) * t + def *[V: IsNumber](t: Tensor0[V]): Tensor0[V] = multiply(Tensor0.likeDType(t)(scalar), t) + def *![T <: Tuple: Labels, V: IsNumber](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(multiply) - extension [V: IsFloating: Writer](scalar: V) + extension (scalar: Float) - def /![T <: Tuple: Labels](t: Tensor[T, V]): Tensor[T, V] = - given ExecutionType[V] = ExecutionTypeFor[V](t.dtype) - Tensor0(scalar).broadcastTo(t.shape) / t + def /[V: IsFloating](t: Tensor0[V]): Tensor0[V] = divide(Tensor0.likeDType(t)(scalar), t) + def /![T <: Tuple: Labels, V: IsFloating](t: Tensor[T, V])(using bc: Broadcast[EmptyTuple, T, V]): Tensor[bc.Out, V] = bc.applyTo(Tensor0.likeDType(t)(scalar), t)(divide) object Tensor1Ops: @@ -1454,7 +1520,7 @@ object TensorOps: // TODO generalize to TensorN (like slice) def dynamicSlice( - dynamicStart: Tensor0[Int], + dynamicStart: Tensor0[Int32], staticSize: Int )(using label: Label[L] diff --git a/core/src/main/scala/dimwit/tensor/VType.scala b/core/src/main/scala/dimwit/tensor/VType.scala new file mode 100644 index 00000000..78782b70 --- /dev/null +++ b/core/src/main/scala/dimwit/tensor/VType.scala @@ -0,0 +1,16 @@ +package dimwit.tensor + +import dimwit.stats.Prob +import dimwit.stats.LogProb +import scala.compiletime.ops.double +import java.nio.ByteBuffer +import dimwit.tensor.TensorOps.HasDType + +object VType: + def apply[V](tensor: Tensor[?, V]): VType[V] = VTypeImpl[V](tensor.dtype) + def apply[A: HasDType]: VType[A] = VTypeImpl[A](summon[HasDType[A]].dtype) + +sealed trait VType[A]: + def dtype: DType + +private case class VTypeImpl[A](override val dtype: DType) extends VType[A] diff --git a/core/src/main/scala/dimwit/tensor/Value.scala b/core/src/main/scala/dimwit/tensor/Value.scala deleted file mode 100644 index ab2e5b67..00000000 --- a/core/src/main/scala/dimwit/tensor/Value.scala +++ /dev/null @@ -1,45 +0,0 @@ -package dimwit.tensor - -import dimwit.stats.Prob -import dimwit.stats.LogProb -import scala.compiletime.ops.double -import java.nio.ByteBuffer - -trait ExecutionType[V]: - def dtype: DType - -object ExecutionType: - - def apply[V](using executionType: ExecutionType[V]): ExecutionType[V] = executionType - - given floatValue: ExecutionType[Float] with - def dtype: DType = DType.Float32 - - given intValue: ExecutionType[Int] with - def dtype: DType = DType.Int32 - - given booleanValue: ExecutionType[Boolean] with - def dtype: DType = DType.Bool - - given byteValue: ExecutionType[Byte] with - def dtype: DType = DType.Int8 - - given doubleValue: ExecutionType[Double] with - def dtype: DType = DType.Float64 - - given prob: ExecutionType[Prob] with - def dtype: DType = summon[ExecutionType[Float]].dtype - - given logProb: ExecutionType[LogProb] with - def dtype: DType = summon[ExecutionType[Float]].dtype - -object VType: - def apply[V](tensor: Tensor[?, V]): VType[V] = new OfImpl[V](tensor.dtype) - def apply[A: ExecutionType]: VType[A] = new OfImpl[A](summon[ExecutionType[A]].dtype) - -sealed trait VType[A]: - def dtype: DType - -class OfImpl[A](val dtype: DType) extends VType[A] - -case class ExecutionTypeFor[V](dtype: DType) extends ExecutionType[V] diff --git a/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala b/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala index 5f585353..af60d772 100644 --- a/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/AutodiffSuite.scala @@ -1,6 +1,5 @@ package dimwit.autodiff -import dimwit.* import dimwit.* import dimwit.Conversions.given import dimwit.autodiff.Autodiff.Gradient @@ -10,10 +9,10 @@ class AutodiffSuite extends DimwitTest: describe("grad"): describe("single parameter function"): it("d¹, d², d³ of x²"): - def f(x: Tensor0[Float]) = x * x + def f(x: Tensor0[Float32]) = x * x val df = Autodiff.grad(f) - val ddf = Autodiff.grad((x: Tensor0[Float]) => df(x).value) - val dddf = Autodiff.grad((x: Tensor0[Float]) => ddf(x).value) + val ddf = Autodiff.grad((x: Tensor0[Float32]) => df(x).value) + val dddf = Autodiff.grad((x: Tensor0[Float32]) => ddf(x).value) val x = Tensor0(3.0f) df(x) shouldEqual Tensor0(6.0f) @@ -21,14 +20,14 @@ class AutodiffSuite extends DimwitTest: dddf(x) shouldEqual Tensor0(0.0f) it("d¹ sum(x²)"): - def f(x: Tensor1[A, Float]) = (x * x).sum + def f(x: Tensor1[A, Float32]) = (x * x).sum val df = Autodiff.grad(f) val x = Tensor1(Axis[A]).fromArray(Array(1.0f, 5.0f)) df(x) shouldEqual Tensor1(Axis[A]).fromArray(Array(2.0f, 10.0f)) it("d¹ function using vmap"): - def f(x: Tensor2[A, B, Float]) = x.vmap(Axis[A])(_.sum).sum + def f(x: Tensor2[A, B, Float32]) = x.vmap(Axis[A])(_.sum).sum val df = Autodiff.grad(f) val x = Tensor(Shape(Axis[A] -> 2, Axis[B] -> 2)).fill(1f) @@ -36,7 +35,7 @@ class AutodiffSuite extends DimwitTest: describe("two parameter function"): it("d¹/dx and d¹/dy of (x + 2y)²"): - def f(x: Tensor1[A, Float], y: Tensor1[A, Float]) = ((x + (y *! 2.0f)).pow(Tensor0(2.0f))).sum + def f(x: Tensor1[A, Float32], y: Tensor1[A, Float32]) = ((x + (y *! 2.0f)).pow(Tensor0(2.0f))).sum val df = Autodiff.grad(f) val x = Tensor1(Axis[A]).fromArray(Array(1.0f)) @@ -50,7 +49,7 @@ class AutodiffSuite extends DimwitTest: describe("two parameter function"): it("d¹/dx and d¹/dy of (x + 2y)²"): - def f(x: Tensor1[A, Float], y: Tensor1[A, Float]) = ((x + (y *! 2.0f)).pow(Tensor0(2.0f))).sum + def f(x: Tensor1[A, Float32], y: Tensor1[A, Float32]) = ((x + (y *! 2.0f)).pow(Tensor0(2.0f))).sum val df = Autodiff.grad(f) val x = Tensor1(Axis[A]).fromArray(Array(1.0f)) @@ -65,7 +64,7 @@ class AutodiffSuite extends DimwitTest: describe("jacobian"): describe("single parameter function"): it("Jacobian of f: R² -> R², f(x) = 2x"): - def f(x: Tensor1[A, Float]) = x *! 2.0f + def f(x: Tensor1[A, Float32]) = x *! 2.0f val jf = Autodiff.jacobian(f) val x = Tensor1(Axis[A]).fromArray(Array(1.0f, 1.0f)) @@ -82,7 +81,7 @@ class AutodiffSuite extends DimwitTest: engines.foreach: case (modeName, jacMode) => it(s"$modeName d¹ on f: R² -> R², f(x) = swap(x)"): - def f(x1: Tensor1[A, Float], x2: Tensor1[A, Float]): (Tensor1[A, Float], Tensor1[A, Float]) = (x2, x1) + def f(x1: Tensor1[A, Float32], x2: Tensor1[A, Float32]): (Tensor1[A, Float32], Tensor1[A, Float32]) = (x2, x1) val df = jacMode(f.tupled) val x1 = Tensor1(Axis[A]).fromArray(Array(1.0f, 0.0f)) val x2 = Tensor1(Axis[A]).fromArray(Array(0.0f, 1.0f)) @@ -95,7 +94,7 @@ class AutodiffSuite extends DimwitTest: x2_dx2 should approxEqual(Tensor.like(x2_dx2).fill(0f)) it(s"$modeName d² on f: R² -> R, f(x1, x2) = sum(x1 * x2)"): - def f(x1: Tensor1[A, Float], x2: Tensor1[A, Float]): Tensor0[Float] = (x1 * x2).sum + def f(x1: Tensor1[A, Float32], x2: Tensor1[A, Float32]): Tensor0[Float32] = (x1 * x2).sum val df = jacMode(f.tupled) val ddf = jacMode(df) val x1 = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) @@ -110,8 +109,8 @@ class AutodiffSuite extends DimwitTest: describe("Complex application"): it("case class support"): - case class Params(w: Tensor1[A, Float], b: Tensor0[Float]) - def loss(data: Tensor1[A, Float])(params: Params): Tensor0[Float] = + case class Params(w: Tensor1[A, Float32], b: Tensor0[Float32]) + def loss(data: Tensor1[A, Float32])(params: Params): Tensor0[Float32] = ((data * params.w).sum + params.b).pow(Tensor0(2.0f)) val trainData = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) val dloss = Autodiff.grad(loss(trainData)) diff --git a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala index cd77f5be..e7b43414 100644 --- a/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/FloatTensorTreeSuite.scala @@ -10,10 +10,10 @@ class FloatTensorTreeSuite extends DimwitTest: describe("map"): it("1-level case class"): case class Params( - val w1: Tensor1[A, Float], - val b1: Tensor0[Float], - val w2: Tensor2[A, B, Float], - val b2: Tensor0[Float] + val w1: Tensor1[A, Float32], + val b1: Tensor0[Float32], + val w2: Tensor2[A, B, Float32], + val b2: Tensor0[Float32] ) val params = Params( Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), @@ -21,7 +21,7 @@ class FloatTensorTreeSuite extends DimwitTest: Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(0.1f, 0.2f), Array(0.3f, 0.4f), Array(0.5f, 0.6f))), Tensor0(0.25f) ) - val res = params.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float]) => x +! 0.5f) + val res = params.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float32]) => x +! 0.5f) res.w1 should approxEqual(params.w1 +! 0.5f) res.b1 should approxEqual(params.b1 + 0.5f) res.w2 should approxEqual(params.w2 +! 0.5f) @@ -29,8 +29,8 @@ class FloatTensorTreeSuite extends DimwitTest: it("2-level case class"): case class LayerParams( - val w: Tensor2[A, B, Float], - val b: Tensor0[Float] + val w: Tensor2[A, B, Float32], + val b: Tensor0[Float32] ) case class ModelParams( val layer1: LayerParams, @@ -45,7 +45,7 @@ class FloatTensorTreeSuite extends DimwitTest: Tensor0(0.75f) ) val params = ModelParams(layer1Params, layer2Params) - val res = params.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float]) => x +! 0.5f) + val res = params.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float32]) => x +! 0.5f) res.layer1.w should approxEqual(params.layer1.w +! 0.5f) res.layer1.b should approxEqual(params.layer1.b + 0.5f) @@ -54,20 +54,57 @@ class FloatTensorTreeSuite extends DimwitTest: it("case class with tuple"): case class LayerParams( - val weightBias: (Tensor2[A, B, Float], Tensor0[Float]) + val weightBias: (Tensor2[A, B, Float32], Tensor0[Float32]) ) val layerParams = LayerParams( Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(0.1f, 0.2f), Array(0.3f, 0.4f), Array(0.5f, 0.6f))), Tensor0(0.25f) ) - val res = layerParams.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float]) => x +! 0.5f) + val res = layerParams.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float32]) => x +! 0.5f) res.weightBias._1 should approxEqual(layerParams.weightBias._1 +! 0.5f) res.weightBias._2 should approxEqual(layerParams.weightBias._2 + 0.5f) + it("example for Float16"): + case class LayerParams( + val weightBias: (Tensor2[A, B, Float16], Tensor0[Float16]) + ) + val layerParams = LayerParams( + Tensor2(Axis[A], Axis[B], VType[Float16]).fromArray(Array(Array(0.1f, 0.2f), Array(0.3f, 0.4f), Array(0.5f, 0.6f))), + Tensor0(VType[Float16])(0.25f) + ) + val res = layerParams.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float16]) => x +! 0.5f) + + res.weightBias._1.asFloat32 should approxEqual((layerParams.weightBias._1 +! 0.5f).asFloat32) + res.weightBias._2.asFloat32 should approxEqual((layerParams.weightBias._2 + 0.5f).asFloat32) + + it("example for V"): + case class LayerParams[V]( + val weightBias: (Tensor2[A, B, V], Tensor0[V]) + ) + val layerParams = LayerParams( + Tensor2(Axis[A], Axis[B], VType[Float16]).fromArray(Array(Array(0.1f, 0.2f), Array(0.3f, 0.4f), Array(0.5f, 0.6f))), + Tensor0(VType[Float16])(0.25f) + ) + val res = layerParams.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float16]) => x +! 0.5f) + + res.weightBias._1.asFloat32 should approxEqual((layerParams.weightBias._1 +! 0.5f).asFloat32) + res.weightBias._2.asFloat32 should approxEqual((layerParams.weightBias._2 + 0.5f).asFloat32) + + it("example for Float32 to Float16"): + case class LayerParams[V]( + val weightBias: (Tensor2[A, B, V], Tensor0[V]) + ) + val layerParams = LayerParams( + Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(0.1f, 0.2f), Array(0.3f, 0.4f), Array(0.5f, 0.6f))), + Tensor0(0.25f) + ) + val layerParamsFloat16: LayerParams[Float16] = layerParams.asFloats(VType[Float16]) + layerParamsFloat16.weightBias._1.dtype.name shouldBe "float16" + it("case class with list"): case class Params( - val layerWeights: List[Tensor2[A, B, Float]] + val layerWeights: List[Tensor2[A, B, Float32]] ) val layerParams = Params( List( @@ -75,7 +112,7 @@ class FloatTensorTreeSuite extends DimwitTest: Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.1f, 1.2f), Array(1.3f, 1.4f), Array(1.5f, 1.6f))) ) ) - val res = layerParams.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float]) => x +! 0.5f) + val res = layerParams.map([T <: Tuple] => (labels: Labels[T]) ?=> (x: Tensor[T, Float32]) => x +! 0.5f) res.layerWeights(0) should approxEqual(layerParams.layerWeights(0) +! 0.5f) res.layerWeights(1) should approxEqual(layerParams.layerWeights(1) +! 0.5f) @@ -83,8 +120,8 @@ class FloatTensorTreeSuite extends DimwitTest: describe("zipmap"): it("1-level case class"): case class Params( - val w1: Tensor1[A, Float], - val b1: Tensor0[Float] + val w1: Tensor1[A, Float32], + val b1: Tensor0[Float32] ) val params1 = Params( Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), @@ -94,15 +131,15 @@ class FloatTensorTreeSuite extends DimwitTest: Tensor1(Axis[A]).fromArray(Array(0.4f, 0.5f, 0.6f)), Tensor0(1.5f) ) - def addTensors[T <: Tuple: Labels](t1: Tensor[T, Float], t2: Tensor[T, Float]): Tensor[T, Float] = t1 + t2 - val res = params1.zipMap(params2, [T <: Tuple] => (labels: Labels[T]) ?=> (x1: Tensor[T, Float], x2: Tensor[T, Float]) => addTensors[T](x1, x2)) + def addTensors[T <: Tuple: Labels](t1: Tensor[T, Float32], t2: Tensor[T, Float32]): Tensor[T, Float32] = t1 + t2 + val res = params1.zipMap(params2, [T <: Tuple] => (labels: Labels[T]) ?=> (x1: Tensor[T, Float32], x2: Tensor[T, Float32]) => addTensors[T](x1, x2)) res.w1 should approxEqual(params1.w1 + params2.w1) res.b1 should approxEqual(params1.b1 + params2.b1) describe("Extension methods"): case class Params( - w: Tensor1[A, Float], - b: Tensor0[Float] + w: Tensor1[A, Float32], + b: Tensor0[Float32] ) val params = Params( diff --git a/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala index 497ccd47..b8c67259 100644 --- a/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/PyTreeSuite.scala @@ -2,7 +2,6 @@ package dimwit.autodiff import dimwit.* import dimwit.jax.Jax -import dimwit.Conversions.given import me.shadaj.scalapy.py class ToPyTreeSuite extends DimwitTest: @@ -10,8 +9,8 @@ class ToPyTreeSuite extends DimwitTest: it("1-level case class"): case class Params( - val w: Tensor1[A, Float], - val b: Tensor0[Float] + val w: Tensor1[A, Float32], + val b: Tensor0[Float32] ) val params = Params( Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), @@ -26,8 +25,8 @@ class ToPyTreeSuite extends DimwitTest: it("2-level case class"): case class LayerParams( - val w: Tensor2[A, B, Float], - val b: Tensor0[Float] + val w: Tensor2[A, B, Float32], + val b: Tensor0[Float32] ) case class ModelParams( val layer1: LayerParams, @@ -59,7 +58,7 @@ class ToPyTreeSuite extends DimwitTest: Tensor0(0.5f) ) - val tc = TensorTree[(Tensor1[A, Float], Tensor0[Float])] + val tc = TensorTree[(Tensor1[A, Float32], Tensor0[Float32])] val reconstructed = tc.fromPyTree(tc.toPyTree(myTuple)) reconstructed._1 should approxEqual(myTuple._1) @@ -67,7 +66,7 @@ class ToPyTreeSuite extends DimwitTest: it("case class with list"): case class Params( - val layerWeights: List[Tensor2[A, B, Float]] + val layerWeights: List[Tensor2[A, B, Float32]] ) val params = Params( List( diff --git a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala index dd1ef92c..134f850c 100644 --- a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala @@ -1,7 +1,5 @@ package dimwit.autodiff -import dimwit.* -import dimwit.Conversions.given import dimwit.* class TensorTreeSuite extends DimwitTest: @@ -9,9 +7,9 @@ class TensorTreeSuite extends DimwitTest: describe("map"): it("1-level case class"): case class Data( - val numbers: Tensor1[A, Float], - val counts: Tensor1[A, Int], - val flags: Tensor1[A, Boolean] + val numbers: Tensor1[A, Float32], + val counts: Tensor1[A, Int32], + val flags: Tensor1[A, Bool] ) val params = Data( Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), @@ -27,8 +25,8 @@ class TensorTreeSuite extends DimwitTest: describe("zipmap"): it("1-level case class"): case class Params( - val w1: Tensor1[A, Float], - val b1: Tensor0[Int] + val w1: Tensor1[A, Float32], + val b1: Tensor0[Int32] ) val params1 = Params( Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), diff --git a/core/src/test/scala/dimwit/jax/JitSuite.scala b/core/src/test/scala/dimwit/jax/JitSuite.scala index c57228dd..c244077b 100644 --- a/core/src/test/scala/dimwit/jax/JitSuite.scala +++ b/core/src/test/scala/dimwit/jax/JitSuite.scala @@ -7,7 +7,7 @@ import me.shadaj.scalapy.py class JitSuite extends DimwitTest: it("jit compilation works correctly"): - def f(t: Tensor1[A, Float]): Tensor1[A, Float] = + def f(t: Tensor1[A, Float32]): Tensor1[A, Float32] = t * ((t +! 1f) /! 2f) val jitF = jit(f) @@ -19,7 +19,7 @@ class JitSuite extends DimwitTest: res should approxEqual(jittedRes) it("jitDonating compilation works correctly"): - def f(t: Tensor1[A, Float]): Tensor1[A, Float] = + def f(t: Tensor1[A, Float32]): Tensor1[A, Float32] = t * ((t +! 1f) /! 2f) val (jitDonate, jitF, jitReclaim) = jitDonating(f) @@ -32,7 +32,7 @@ class JitSuite extends DimwitTest: res should approxEqual(jittedRes) it("jitDonatingUnsafe compilation works correctly"): - def f(t: Tensor1[A, Float]): Tensor1[A, Float] = + def f(t: Tensor1[A, Float32]): Tensor1[A, Float32] = t * ((t +! 1f) /! 2f) val jitF = jitDonatingUnsafe(f) @@ -55,7 +55,7 @@ class JitSuite extends DimwitTest: val tensor = Tensor(Shape1(Axis[A] -> 5)).fill(1f) - def complexFn(t: Tensor1[A, Float]): Tensor1[A, Float] = + def complexFn(t: Tensor1[A, Float32]): Tensor1[A, Float32] = (0 until 50).foldLeft(t) { (acc, _) => acc * ((acc +! 1f) /! 2f) } val jitComplexFn = jit(complexFn) @@ -93,19 +93,19 @@ class JitSuite extends DimwitTest: // Prepare functions to test // One Param - def fi0r1(r1: Tensor1[A, Float]): Tensor1[A, Float] = r1 +! 1f + def fi0r1(r1: Tensor1[A, Float32]): Tensor1[A, Float32] = r1 +! 1f // Two Params - def fi1r1(p1: Tensor2[A, B, Float], r1: Tensor1[A, Float]): Tensor1[A, Float] = r1 + p1.sum(Axis[B]) - def fi0r2(r1: Tensor1[A, Float], r2: Tensor1[A, Float]): (Tensor1[A, Float], Tensor1[A, Float]) = (r1 +! 1f, r2 *! 2f) + def fi1r1(p1: Tensor2[A, B, Float32], r1: Tensor1[A, Float32]): Tensor1[A, Float32] = r1 + p1.sum(Axis[B]) + def fi0r2(r1: Tensor1[A, Float32], r2: Tensor1[A, Float32]): (Tensor1[A, Float32], Tensor1[A, Float32]) = (r1 +! 1f, r2 *! 2f) // Three Params - def fi2r1(p1: Tensor2[A, B, Float], p2: Tensor2[A, C, Float], r1: Tensor1[A, Float]): Tensor1[A, Float] = r1 + p1.sum(Axis[B]) + p2.sum(Axis[C]) - def fi1r2(p1: Tensor2[A, B, Float], r1: Tensor1[A, Float], r2: Tensor1[A, Float]): (Tensor1[A, Float], Tensor1[A, Float]) = (r1 + p1.sum(Axis[B]), r2 *! 2f) - def fi0r3(r1: Tensor1[A, Float], r2: Tensor1[A, Float], r3: Tensor1[A, Float]): (Tensor1[A, Float], Tensor1[A, Float], Tensor1[A, Float]) = (r1 +! 1f, r2 *! 2f, r3 -! 3f) + def fi2r1(p1: Tensor2[A, B, Float32], p2: Tensor2[A, C, Float32], r1: Tensor1[A, Float32]): Tensor1[A, Float32] = r1 + p1.sum(Axis[B]) + p2.sum(Axis[C]) + def fi1r2(p1: Tensor2[A, B, Float32], r1: Tensor1[A, Float32], r2: Tensor1[A, Float32]): (Tensor1[A, Float32], Tensor1[A, Float32]) = (r1 + p1.sum(Axis[B]), r2 *! 2f) + def fi0r3(r1: Tensor1[A, Float32], r2: Tensor1[A, Float32], r3: Tensor1[A, Float32]): (Tensor1[A, Float32], Tensor1[A, Float32], Tensor1[A, Float32]) = (r1 +! 1f, r2 *! 2f, r3 -! 3f) // Four Params - def fi3r1(p1: Tensor2[A, B, Float], p2: Tensor2[A, C, Float], p3: Tensor2[A, D, Float], r1: Tensor1[A, Float]): Tensor1[A, Float] = r1 + p1.sum(Axis[B]) + p2.sum(Axis[C]) + p3.sum(Axis[D]) - def fi2r2(p1: Tensor2[A, B, Float], p2: Tensor2[A, C, Float], r1: Tensor1[A, Float], r2: Tensor1[A, Float]): (Tensor1[A, Float], Tensor1[A, Float]) = (r1 + p1.sum(Axis[B]), r2 * p2.sum(Axis[C])) - def fi1r3(p1: Tensor2[A, B, Float], r1: Tensor1[A, Float], r2: Tensor1[A, Float], r3: Tensor1[A, Float]): (Tensor1[A, Float], Tensor1[A, Float], Tensor1[A, Float]) = (r1 + p1.sum(Axis[B]), r2 *! 2f, r3 -! 3f) - def fi0r4(r1: Tensor1[A, Float], r2: Tensor1[A, Float], r3: Tensor1[A, Float], r4: Tensor1[A, Float]): (Tensor1[A, Float], Tensor1[A, Float], Tensor1[A, Float], Tensor1[A, Float]) = (r1 +! 1f, r2 *! 2f, r3 -! 3f, r4 /! 4f) + def fi3r1(p1: Tensor2[A, B, Float32], p2: Tensor2[A, C, Float32], p3: Tensor2[A, D, Float32], r1: Tensor1[A, Float32]): Tensor1[A, Float32] = r1 + p1.sum(Axis[B]) + p2.sum(Axis[C]) + p3.sum(Axis[D]) + def fi2r2(p1: Tensor2[A, B, Float32], p2: Tensor2[A, C, Float32], r1: Tensor1[A, Float32], r2: Tensor1[A, Float32]): (Tensor1[A, Float32], Tensor1[A, Float32]) = (r1 + p1.sum(Axis[B]), r2 * p2.sum(Axis[C])) + def fi1r3(p1: Tensor2[A, B, Float32], r1: Tensor1[A, Float32], r2: Tensor1[A, Float32], r3: Tensor1[A, Float32]): (Tensor1[A, Float32], Tensor1[A, Float32], Tensor1[A, Float32]) = (r1 + p1.sum(Axis[B]), r2 *! 2f, r3 -! 3f) + def fi0r4(r1: Tensor1[A, Float32], r2: Tensor1[A, Float32], r3: Tensor1[A, Float32], r4: Tensor1[A, Float32]): (Tensor1[A, Float32], Tensor1[A, Float32], Tensor1[A, Float32], Tensor1[A, Float32]) = (r1 +! 1f, r2 *! 2f, r3 -! 3f, r4 /! 4f) // Prepare test data (def so tests are independent as donating can destroy internal data) diff --git a/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala b/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala index 114aba68..4b4b110a 100644 --- a/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala +++ b/core/src/test/scala/dimwit/memory/DimWitMemorySuite.scala @@ -13,7 +13,7 @@ class DimWitMemorySuite extends DimwitTest: val exampleT = Tensor(Shape(Axis[A] -> 1000, Axis[B] -> 1000)).fill(0f) - def complexF(in: Tensor2[A, B, Float]): Tensor2[A, B, Float] = + def complexF(in: Tensor2[A, B, Float32]): Tensor2[A, B, Float32] = var x = in for i <- 0 until 10 do val a = in +! 5 @@ -35,7 +35,7 @@ class DimWitMemorySuite extends DimwitTest: lazy val oomBarrier = oomAt * 2 - def testF(in: Tensor2[A, B, Float]): Tensor2[A, B, Float] = + def testF(in: Tensor2[A, B, Float32]): Tensor2[A, B, Float32] = var t = in for _ <- 0 until oomBarrier do t = complexF(t) @@ -47,7 +47,7 @@ class DimWitMemorySuite extends DimwitTest: exception.getMessage should include("Out of memory") it("GC should fix (not guaranteed)"): - def testFWithGC(in: Tensor2[A, B, Float]): Tensor2[A, B, Float] = + def testFWithGC(in: Tensor2[A, B, Float32]): Tensor2[A, B, Float32] = var t = in for _ <- 0 until oomBarrier do dimwit.gc() // trigger GC (suggestion) @@ -57,7 +57,7 @@ class DimWitMemorySuite extends DimwitTest: testFWithGC(exampleT) it("eager should fix"): - def testFWithEager(in: Tensor2[A, B, Float]): Tensor2[A, B, Float] = + def testFWithEager(in: Tensor2[A, B, Float32]): Tensor2[A, B, Float32] = var t = in val complexFEager = dimwit.eagerCleanup(complexF) for _ <- 0 until oomBarrier do @@ -67,7 +67,7 @@ class DimWitMemorySuite extends DimwitTest: testFWithEager(exampleT) it("jit should fix"): - def testFWithJit(in: Tensor2[A, B, Float]): Tensor2[A, B, Float] = + def testFWithJit(in: Tensor2[A, B, Float32]): Tensor2[A, B, Float32] = var t = in val complexFJit = dimwit.jit(complexF) for _ <- 0 until oomBarrier do diff --git a/core/src/test/scala/dimwit/package.scala b/core/src/test/scala/dimwit/package.scala index 192772e3..4bc580bb 100644 --- a/core/src/test/scala/dimwit/package.scala +++ b/core/src/test/scala/dimwit/package.scala @@ -3,7 +3,6 @@ package dimwit /** Global test utility definitions */ import dimwit.* -import dimwit.Conversions.given import org.scalacheck.Prop.* import org.scalacheck.{Arbitrary, Gen} import me.shadaj.scalapy.py @@ -22,14 +21,12 @@ trait C derives Label trait D derives Label trait E derives Label -def approxEqual[T <: Tuple: Labels, V](right: Tensor[T, V], tolerance: Float = 1e-6f)(using ev: MustBeFloat[V]): Matcher[Tensor[T, V]] = - new Matcher[Tensor[T, V]]: - def apply(left: Tensor[T, V]): MatchResult = - val leftF = left.asInstanceOf[Tensor[T, Float]] - val rightF = right.asInstanceOf[Tensor[T, Float]] +def approxEqual[T <: Tuple: Labels](right: Tensor[T, Float32], tolerance: Float = 1e-6f): Matcher[Tensor[T, Float32]] = + new Matcher[Tensor[T, Float32]]: + def apply(left: Tensor[T, Float32]): MatchResult = - val areEqual = (leftF `approxEquals` (rightF, tolerance)).item - lazy val diffMsg = if areEqual then "" else s"Max diff: ${(leftF - rightF).abs.max}" + val areEqual = (left `approxEquals` (right, tolerance)).item + lazy val diffMsg = if areEqual then "" else s"Max diff: ${(left - right).abs.max}" MatchResult( areEqual, @@ -37,13 +34,6 @@ def approxEqual[T <: Tuple: Labels, V](right: Tensor[T, V], tolerance: Float = 1 s"Tensors matched, but they shouldn't have." ) -trait MustBeFloat[V] -object MustBeFloat: - given MustBeFloat[Float] with {} - - transparent inline given [V]: MustBeFloat[V] = - error("approxEqual can only be used with Float tensors. For Int tensors, use 'equal(...)'.") - private lazy val _dimwitTestInit: Unit = dimwit.initialize() trait DimwitTest extends AnyFunSpec with Matchers: diff --git a/core/src/test/scala/dimwit/python/PyWrapSuite.scala b/core/src/test/scala/dimwit/python/PyWrapSuite.scala index 8fc40929..432584a8 100644 --- a/core/src/test/scala/dimwit/python/PyWrapSuite.scala +++ b/core/src/test/scala/dimwit/python/PyWrapSuite.scala @@ -16,20 +16,20 @@ class PyWrapSuite extends DimwitTest: describe("1-input function"): it("wraps an identity Python function"): - val f = PyBridge.liftPyFn[Tensor1[A, Float], Tensor1[A, Float]](identity1d) + val f = PyBridge.liftPyFn[Tensor1[A, Float32], Tensor1[A, Float32]](identity1d) val input = Tensor1(Axis[A]).fromArray(Array(1f, 2f, 3f)) f(input) should approxEqual(input) it("wraps a Python function that doubles output"): - val f = PyBridge.liftPyFn[Tensor1[A, Float], Tensor1[A, Float]](double1d) + val f = PyBridge.liftPyFn[Tensor1[A, Float32], Tensor1[A, Float32]](double1d) val input = Tensor1(Axis[A]).fromArray(Array(1f, 2f, 3f)) f(input) should approxEqual(Tensor1(Axis[A]).fromArray(Array(2f, 4f, 6f))) describe("2-input function"): it("wraps a Python function that adds two tensors"): - val f = PyBridge.liftPyFn[(Tensor1[A, Float], Tensor1[A, Float]), Tensor1[A, Float]](addTupled) + val f = PyBridge.liftPyFn[(Tensor1[A, Float32], Tensor1[A, Float32]), Tensor1[A, Float32]](addTupled) val t1 = Tensor1(Axis[A]).fromArray(Array(1f, 2f, 3f)) val t2 = Tensor1(Axis[A]).fromArray(Array(4f, 5f, 6f)) @@ -37,47 +37,47 @@ class PyWrapSuite extends DimwitTest: describe("scalar function"): it("wraps a Python function that squares a scalar"): - val f = PyBridge.liftPyFn[Tensor0[Float], Tensor0[Float]](squareScalar) + val f = PyBridge.liftPyFn[Tensor0[Float32], Tensor0[Float32]](squareScalar) val x = Tensor0(5.0f) f(x) shouldEqual Tensor0(25.0f) describe("PyBridge.toJax"): it("applies jax.jit to a Scala function"): - def f(t: Tensor1[A, Float]): Tensor1[A, Float] = t *! 3f + def f(t: Tensor1[A, Float32]): Tensor1[A, Float32] = t *! 3f - val jitted = PyBridge.toJax[Tensor1[A, Float], Tensor1[A, Float]](Jax.jax.jit)(f) + val jitted = PyBridge.toJax[Tensor1[A, Float32], Tensor1[A, Float32]](Jax.jax.jit)(f) val input = Tensor1(Axis[A]).fromArray(Array(1f, 2f, 3f)) jitted(input) should approxEqual(f(input)) describe("PyBridge.toPyFn"): it("wraps a Scala function as a Python-callable that doubles a tensor"): - def double(t: Tensor1[A, Float]): Tensor1[A, Float] = t *! 2f + def double(t: Tensor1[A, Float32]): Tensor1[A, Float32] = t *! 2f - val pyFn = PyBridge.toPyFn[Tensor1[A, Float], Tensor1[A, Float]](double) + val pyFn = PyBridge.toPyFn[Tensor1[A, Float32], Tensor1[A, Float32]](double) val input = Tensor1(Axis[A]).fromArray(Array(1f, 2f, 3f)) - val result = PyBridge.liftPyFn[Tensor1[A, Float], Tensor1[A, Float]](pyFn)(input) + val result = PyBridge.liftPyFn[Tensor1[A, Float32], Tensor1[A, Float32]](pyFn)(input) result should approxEqual(Tensor1(Axis[A]).fromArray(Array(2f, 4f, 6f))) it("wraps a Scala function that returns a scalar"): - def sumAll(t: Tensor1[A, Float]): Tensor0[Float] = t.sum + def sumAll(t: Tensor1[A, Float32]): Tensor0[Float32] = t.sum - val pyFn = PyBridge.toPyFn[Tensor1[A, Float], Tensor0[Float]](sumAll) + val pyFn = PyBridge.toPyFn[Tensor1[A, Float32], Tensor0[Float32]](sumAll) val input = Tensor1(Axis[A]).fromArray(Array(1f, 2f, 3f)) - val result = PyBridge.liftPyFn[Tensor1[A, Float], Tensor0[Float]](pyFn)(input) + val result = PyBridge.liftPyFn[Tensor1[A, Float32], Tensor0[Float32]](pyFn)(input) result shouldEqual Tensor0(6.0f) it("produces a py.Dynamic usable from Python code"): - def triple(t: Tensor1[A, Float]): Tensor1[A, Float] = t *! 3f + def triple(t: Tensor1[A, Float32]): Tensor1[A, Float32] = t *! 3f - val pyFn = PyBridge.toPyFn[Tensor1[A, Float], Tensor1[A, Float]](triple) + val pyFn = PyBridge.toPyFn[Tensor1[A, Float32], Tensor1[A, Float32]](triple) // Call through a Python lambda that just forwards to pyFn val caller = py.eval("lambda fn, x: fn(x)") val input = Tensor1(Axis[A]).fromArray(Array(2f, 3f, 4f)) - val pyResult = caller(pyFn, dimwit.autodiff.TensorTree[Tensor1[A, Float]].toPyTree(input)) - val result = dimwit.autodiff.TensorTree[Tensor1[A, Float]].fromPyTree(pyResult) + val pyResult = caller(pyFn, dimwit.autodiff.TensorTree[Tensor1[A, Float32]].toPyTree(input)) + val result = dimwit.autodiff.TensorTree[Tensor1[A, Float32]].fromPyTree(pyResult) result should approxEqual(Tensor1(Axis[A]).fromArray(Array(6f, 9f, 12f))) diff --git a/core/src/test/scala/dimwit/random/RandomSuite.scala b/core/src/test/scala/dimwit/random/RandomSuite.scala index 13b4e3c3..813bda2a 100644 --- a/core/src/test/scala/dimwit/random/RandomSuite.scala +++ b/core/src/test/scala/dimwit/random/RandomSuite.scala @@ -1,7 +1,6 @@ package dimwit.random import dimwit.* -import dimwit.Conversions.given import dimwit.jax.Jax import me.shadaj.scalapy.py diff --git a/core/src/test/scala/dimwit/stats/DistributionSuite.scala b/core/src/test/scala/dimwit/stats/DistributionSuite.scala index 9959b2c5..66095632 100644 --- a/core/src/test/scala/dimwit/stats/DistributionSuite.scala +++ b/core/src/test/scala/dimwit/stats/DistributionSuite.scala @@ -24,7 +24,7 @@ class DistributionSuite extends DimwitTest: val dist = Normal(loc, scale) val scalaLogProbs = dist.elementWiseLogProb(x) - val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float])( + val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float32])( jstats.norm.logpdf(x.jaxValue, loc = loc.jaxValue, scale = scale.jaxValue) ) scalaLogProbs.asFloat should approxEqual(jaxLogProbs) @@ -48,7 +48,7 @@ class DistributionSuite extends DimwitTest: val dist = Uniform(low, high) val scalaLogProbs = dist.elementWiseLogProb(x) - val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float])( + val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float32])( jstats.uniform.logpdf(x.jaxValue, loc = low.jaxValue, scale = (high - low).jaxValue) ) scalaLogProbs.asFloat should approxEqual(jaxLogProbs) @@ -71,7 +71,7 @@ class DistributionSuite extends DimwitTest: val dist = Bernoulli(Prob(probs)) val scalaLogProbs = dist.elementWiseLogProb(x) - val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float])( + val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float32])( jstats.bernoulli.logpmf(x.jaxValue, p = probs.jaxValue) ) scalaLogProbs.asFloat should approxEqual(jaxLogProbs) @@ -82,7 +82,7 @@ class DistributionSuite extends DimwitTest: ) val key = Random.Key(42) val samples = key.splitvmap(Axis[Samples] -> 1000)(k => bernoulli.sample(k)) - val sampleMeans = samples.asFloat.mean(Axis[Samples]) + val sampleMeans = samples.asFloat32.mean(Axis[Samples]) val expectedMeans = bernoulli.probs.asFloat sampleMeans should approxEqual(expectedMeans, 0.1f) @@ -94,7 +94,7 @@ class DistributionSuite extends DimwitTest: val dist = Binomial(n, Prob(probs)) val scalaLogProbs = dist.elementWiseLogProb(x) - val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float])( + val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float32])( jstats.binom.logpmf(x.jaxValue, n = n.jaxValue, p = probs.jaxValue) ) scalaLogProbs.asFloat should approxEqual(jaxLogProbs) @@ -107,7 +107,7 @@ class DistributionSuite extends DimwitTest: ) val key = Random.Key(42) val samples = key.splitvmap(Axis[Samples] -> 10000)(k => binomial.sample(k)) - val sampleMeans = samples.asFloat.mean(Axis[Samples]) + val sampleMeans = samples.asFloat32.mean(Axis[Samples]) val expectedMeans = binomial.probs.asFloat *! n.item.toFloat sampleMeans should approxEqual(expectedMeans, 0.5f) @@ -118,7 +118,7 @@ class DistributionSuite extends DimwitTest: val binomial = Binomial(n, Prob(probs)) val key = Random.Key(123) val samples = key.splitvmap(Axis[Samples] -> 5000)(k => binomial.sample(k)) - val sampleMeans = samples.asFloat.mean(Axis[Samples]) + val sampleMeans = samples.asFloat32.mean(Axis[Samples]) sampleMeans should approxEqual(probs, 0.1f) it("handles edge cases p=0 and p=1"): @@ -128,7 +128,7 @@ class DistributionSuite extends DimwitTest: val binomial = Binomial(n, Prob(probsEdge)) val key = Random.Key(456) val samples = key.splitvmap(Axis[Samples] -> 100)(k => binomial.sample(k)) - val sampleMeans = samples.asFloat.mean(Axis[Samples]) + val sampleMeans = samples.asFloat32.mean(Axis[Samples]) val expectedMeans = Tensor(Shape(Axis[A] -> 2)).fromArray(Array(0.0f, n.item.toFloat)) sampleMeans should approxEqual(expectedMeans, 0.1f) @@ -140,7 +140,7 @@ class DistributionSuite extends DimwitTest: val dist = Cauchy(loc, scale) val scalaLogProbs = dist.elementWiseLogProb(x) - val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float])( + val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float32])( jstats.cauchy.logpdf(x.jaxValue, loc = loc.jaxValue, scale = scale.jaxValue) ) scalaLogProbs.asFloat should approxEqual(jaxLogProbs) @@ -165,7 +165,7 @@ class DistributionSuite extends DimwitTest: val dist = HalfNormal(loc, scale) val scalaLogProbs = dist.elementWiseLogProb(x) // Compute expected manually: log(2) + norm.logpdf for x >= loc - val expectedLogProbs = liftPyTensor1(Axis[A], VType[Float])( + val expectedLogProbs = liftPyTensor1(Axis[A], VType[Float32])( jstats.norm.logpdf(x.jaxValue, loc = loc.jaxValue, scale = scale.jaxValue) ) +! math.log(2.0).toFloat scalaLogProbs.asFloat should approxEqual(expectedLogProbs) @@ -192,7 +192,7 @@ class DistributionSuite extends DimwitTest: val dist = StudentT(df, loc, scale) val scalaLogProbs = dist.elementWiseLogProb(x) - val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float])( + val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float32])( jstats.t.logpdf(x.jaxValue, df = df.jaxValue, loc = loc.jaxValue, scale = scale.jaxValue) ) scalaLogProbs.asFloat should approxEqual(jaxLogProbs) @@ -223,7 +223,7 @@ class DistributionSuite extends DimwitTest: val dist = MVNormal(mean, cov) val scalaLogProb = dist.logProb(x) - val jaxLogProb = liftPyTensor0(VType[Float])( + val jaxLogProb = liftPyTensor0(VType[Float32])( jstats.multivariate_normal.logpdf(x.jaxValue, mean = mean.jaxValue, cov = cov.jaxValue) ) scalaLogProb.asFloat should approxEqual(jaxLogProb) @@ -247,7 +247,7 @@ class DistributionSuite extends DimwitTest: val dist = Dirichlet(concentration) val scalaLogProb = dist.logProb(x) - val jaxLogProb = liftPyTensor0(VType[Float])( + val jaxLogProb = liftPyTensor0(VType[Float32])( jstats.dirichlet.logpdf(x.jaxValue, alpha = concentration.jaxValue) ) scalaLogProb.asFloat should approxEqual(jaxLogProb) @@ -272,7 +272,7 @@ class DistributionSuite extends DimwitTest: val dist = Multinomial[A](n, probs) val scalaLogProb = dist.logProb(x) - val jaxLogProb = liftPyTensor0(VType[Float])( + val jaxLogProb = liftPyTensor0(VType[Float32])( jstats.multinomial.logpmf(x.jaxValue, n = n.jaxValue, p = probs.jaxValue) ) scalaLogProb.asFloat should approxEqual(jaxLogProb) @@ -284,7 +284,7 @@ class DistributionSuite extends DimwitTest: val multinomial = Multinomial[A](n, probs) val key = Random.Key(42) val samples = key.splitvmap(Axis[Samples] -> 10000)(k => multinomial.sample(k)) - val sampleMean = samples.asFloat.mean(Axis[Samples]) + val sampleMean = samples.asFloat32.mean(Axis[Samples]) // Expected mean counts are n * probs val expectedMean = multinomial.probs.asFloat *! n.item.toFloat sampleMean should approxEqual(expectedMean, 2.0f) @@ -305,11 +305,11 @@ class DistributionSuite extends DimwitTest: val key = Random.Key(42) val numSamples = 10000 val samples = key.splitvmap(Axis[Samples] -> numSamples)(k => categorical.sample(k)) - val counts = liftPyTensor1(Axis[A], VType[Float])( + val counts = liftPyTensor1(Axis[A], VType[Float32])( Jax.jnp.bincount(samples.jaxValue, minlength = 4).astype(Jax.jnp.float32) ) val frequencies = counts *! (1.0f / numSamples.toFloat) - frequencies should approxEqual(probs.asFloat, 0.02f) + frequencies should approxEqual(probs.asFloat32, 0.02f) describe("Beta"): it("logProbs matches JAX"): @@ -319,7 +319,7 @@ class DistributionSuite extends DimwitTest: val dist = Beta(alpha, beta) val scalaLogProbs = dist.elementWiseLogProb(x) - val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float])( + val jaxLogProbs = liftPyTensor1(Axis[A], VType[Float32])( jstats.beta.logpdf(x.jaxValue, a = alpha.jaxValue, b = beta.jaxValue) ) scalaLogProbs.asFloat should approxEqual(jaxLogProbs) diff --git a/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala b/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala index fdeb1812..398518f3 100644 --- a/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorCovarianceSuite.scala @@ -1,7 +1,6 @@ package dimwit.tensor import dimwit.* -import dimwit.Conversions.given import scala.collection.View.Empty class TensorCovarianceSuite extends DimwitTest: @@ -10,9 +9,9 @@ class TensorCovarianceSuite extends DimwitTest: trait Parent derives Label trait Child1 extends Parent derives Label trait Child2 extends Parent derives Label - def genericFunction[T <: Parent: Label](t: Tensor1[T, Float]): Tensor1[T, Float] = t + t - val child1: Tensor1[Child1, Float] = Tensor(Shape1(Axis[Child1] -> 4)).fill(1f) - val child2: Tensor1[Child2, Float] = Tensor(Shape1(Axis[Child2] -> 4)).fill(1f) + def genericFunction[T <: Parent: Label](t: Tensor1[T, Float32]): Tensor1[T, Float32] = t + t + val child1: Tensor1[Child1, Float32] = Tensor(Shape1(Axis[Child1] -> 4)).fill(1f) + val child2: Tensor1[Child2, Float32] = Tensor(Shape1(Axis[Child2] -> 4)).fill(1f) "genericFunction(child1)" should compile "genericFunction(child2)" should compile @@ -22,8 +21,8 @@ class TensorCovarianceSuite extends DimwitTest: trait Classes derives Label object MLContext: - opaque type Logit = Float - opaque type Prob = Float + opaque type Logit = Float32 + opaque type Prob = Float32 def createLogits[L: Label](s: Shape1[L]): Tensor1[L, Logit] = Tensor(s).fill(0f) def createProbs[L: Label](s: Shape1[L]): Tensor1[L, Prob] = Tensor(s).fill(0f) diff --git a/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala b/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala index 48957fec..b0341b97 100644 --- a/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorCreationSuite.scala @@ -48,12 +48,21 @@ class TensorCreationSuite extends DimwitTest: floatTensorFromDouble.dtype shouldBe DType.Float64 describe("Overwrite default setings"): + it("Change float default dtype from Float32 to Float64"): + // Check fill + withJaxX64Support: // Enable float64 support in JAX + val t64 = Tensor(Shape3(Axis[A] -> 2, Axis[B] -> 3, Axis[C] -> 4), VType[Float64]).fill(3.14f) + t64.dtype shouldBe DType.Float64 + // Check fromArray + withJaxX64Support: // Enable float64 support in JAX + val t64 = Tensor(Shape2(Axis[A] -> 2, Axis[B] -> 2), VType[Float64]).fromArray(Array(1.0f, 2.0f, 3.0f, 4.0f)) + t64.dtype shouldBe DType.Float64 + it("Change double default dtype from Float64 to Float32"): - given ExecutionType[Double] = ExecutionTypeFor[Double](DType.Float32) // Check fill - val floatTensorFromDouble = Tensor(Shape3(Axis[A] -> 2, Axis[B] -> 3, Axis[C] -> 4)).fill(3.14) + val floatTensorFromDouble = Tensor(Shape3(Axis[A] -> 2, Axis[B] -> 3, Axis[C] -> 4), VType[Float32]).fill(3.14) floatTensorFromDouble.dtype shouldBe DType.Float32 // Check fromArray withJaxX64Support: // Enable float64 support in JAX - val floatTensorFromDouble2 = Tensor(Shape2(Axis[A] -> 2, Axis[B] -> 2)).fromArray(Array(1.0, 2.0, 3.0, 4.0)) + val floatTensorFromDouble2 = Tensor(Shape2(Axis[A] -> 2, Axis[B] -> 2), VType[Float32]).fromArray(Array(1.0, 2.0, 3.0, 4.0)) floatTensorFromDouble2.dtype shouldBe DType.Float32 diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala index 55fa6447..7d38c18b 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsBinarySuite.scala @@ -1,7 +1,7 @@ package dimwit.tensor import dimwit.* -import dimwit.Conversions.given + class TensorOpsBinarySuite extends DimwitTest: val t2 = Tensor2(Axis[A], Axis[B]).fromArray( @@ -32,8 +32,8 @@ class TensorOpsBinarySuite extends DimwitTest: (t2 / t2_2) shouldEqual Tensor.like(t2).fromArray(Array(5.0f, 5.0f, 6.0f, 5.0f)) it("Comparisons (<, <=, >, >=)"): - (t2 < t2_2).asBoolean shouldEqual Tensor(t2.shape).fromArray(Array(false, false, false, false)) - (t2 > t2_2).asBoolean shouldEqual Tensor(t2.shape).fromArray(Array(true, true, true, true)) + (t2 < t2_2).asBool shouldEqual Tensor(t2.shape).fromArray(Array(false, false, false, false)) + (t2 > t2_2).asBool shouldEqual Tensor(t2.shape).fromArray(Array(true, true, true, true)) it("elementEquals"): (t2 `elementEquals` t2) shouldEqual Tensor(t2.shape).fromArray(Array(true, true, true, true)) @@ -50,5 +50,5 @@ class TensorOpsBinarySuite extends DimwitTest: (i2 * i2_2) shouldEqual Tensor.like(i2).fromArray(Array(20, 80, 150, 320)) it("Comparisons (<, <=, >, >=)"): - (i2 < i2_2).asBoolean shouldEqual Tensor(i2.shape).fromArray(Array(false, false, false, false)) - (i2 >= i2_2).asBoolean shouldEqual Tensor(i2.shape).fromArray(Array(true, true, true, true)) + (i2 < i2_2).asBool shouldEqual Tensor(i2.shape).fromArray(Array(false, false, false, false)) + (i2 >= i2_2).asBool shouldEqual Tensor(i2.shape).fromArray(Array(true, true, true, true)) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala index 2e5b8280..6b5f5962 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsBroadcastSuite.scala @@ -2,6 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given + class TensorOpsBroadcastSuite extends DimwitTest: val tA = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala index 635e327e..89f33e57 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsContractionSuite.scala @@ -1,7 +1,7 @@ package dimwit.tensor import dimwit.* -import dimwit.Conversions.given + class TensorOpsContractionSuite extends DimwitTest: val v1 = Tensor1(Axis[A]).fromArray( diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala index cb827809..26a25f58 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsConvolutionSuite.scala @@ -1,7 +1,6 @@ package dimwit.tensor import dimwit.* -import dimwit.Conversions.given import dimwit.tensor.TensorOps.Convolution.Padding import dimwit.stats.Normal diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala index eec0551b..ad634b68 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsElementwiseSuite.scala @@ -2,6 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given + class TensorOpsElementwiseSuite extends DimwitTest: val t2 = Tensor2(Axis[A], Axis[B]).fromArray( @@ -90,17 +91,17 @@ class TensorOpsElementwiseSuite extends DimwitTest: describe("Casting Ops (Tensor2)"): it("boolean casting"): - b2.asBoolean shouldEqual b2 - b2.asInt shouldEqual Tensor(b2.shape).fromArray(Array(1, 0, 0, 1)) - b2.asFloat should approxEqual(Tensor(b2.shape).fromArray(Array(1.0f, 0.0f, 0.0f, 1.0f))) + b2.asBool shouldEqual b2 + b2.asInt32 shouldEqual Tensor(b2.shape).fromArray(Array(1, 0, 0, 1)) + b2.asFloat32 should approxEqual(Tensor(b2.shape).fromArray(Array(1.0f, 0.0f, 0.0f, 1.0f))) it("int casting"): - i2.asBoolean shouldEqual Tensor(i2.shape).fromArray(Array(true, false, true, true)) - i2.asInt shouldEqual i2 - i2.asFloat should approxEqual(Tensor(i2.shape).fromArray(Array(-1.0f, 0.0f, 1.0f, 2.0f))) + i2.asBool shouldEqual Tensor(i2.shape).fromArray(Array(true, false, true, true)) + i2.asInt32 shouldEqual i2 + i2.asFloat32 should approxEqual(Tensor(i2.shape).fromArray(Array(-1.0f, 0.0f, 1.0f, 2.0f))) it("float casting"): val f2 = Tensor.like(t2).fromArray(Array(-1.1f, 0.0f, 0.9f, 2.5f)) - f2.asBoolean shouldEqual Tensor(f2.shape).fromArray(Array(true, false, true, true)) - f2.asInt shouldEqual Tensor(f2.shape).fromArray(Array(-1, 0, 0, 2)) - f2.asFloat shouldEqual f2 + f2.asBool shouldEqual Tensor(f2.shape).fromArray(Array(true, false, true, true)) + f2.asInt32 shouldEqual Tensor(f2.shape).fromArray(Array(-1, 0, 0, 2)) + f2.asFloat32 shouldEqual f2 diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala index 52a28b76..73e370f3 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsFunctionalSuite.scala @@ -2,6 +2,7 @@ package dimwit.tensor import dimwit.* import dimwit.Conversions.given + class TensorOpsFunctionalSuite extends DimwitTest: val t2 = Tensor2(Axis[A], Axis[B]).fromArray( @@ -26,18 +27,18 @@ class TensorOpsFunctionalSuite extends DimwitTest: res shouldEqual Tensor1(Axis[B]).fromArray(Array(4.0f, 6.0f)) it("nested vmap"): - val res = t2.vmap(Axis[A])(_.vmap(Axis[B])(_ => 0.0f)) + val res = t2.vmap(Axis[A])(_.vmap(Axis[B])(_ => Tensor0(0.0f))) res shouldEqual Tensor.like(t2).fill(0.0f) describe("zipvmap (Parallel Mapping)"): - def l2[L: Label](v1: Tensor1[L, Float], v2: Tensor1[L, Float]): Tensor0[Float] = (v1 - v2).pow(2.0f).sum.sqrt + def l2[L: Label](v1: Tensor1[L, Float32], v2: Tensor1[L, Float32]): Tensor0[Float32] = (v1 - v2).pow(2.0f).sum.sqrt it("zipvmap f should get correct runtime shape."): val t1 = Tensor(Shape(Axis[A] -> 2, Axis[B] -> 3)).fill(0f) val t2 = Tensor(Shape(Axis[A] -> 2, Axis[B] -> 3)).fill(0f) val shapesCorrect = zipvmap(Axis[A])(t1, t2): (v1, v2) => - v1.shape == Shape(Axis[B] -> 3) && v2.shape == Shape(Axis[B] -> 3) + Tensor0(v1.shape == Shape(Axis[B] -> 3) && v2.shape == Shape(Axis[B] -> 3)) shapesCorrect.all.item shouldBe true it("zipvmap2 adds two tensors"): @@ -60,7 +61,7 @@ class TensorOpsFunctionalSuite extends DimwitTest: describe("vapply (Axis-wise application)"): - def l2[L: Label](v1: Tensor1[L, Float], v2: Tensor1[L, Float]): Tensor0[Float] = (v1 - v2).pow(2.0f).sum.sqrt + def l2[L: Label](v1: Tensor1[L, Float32], v2: Tensor1[L, Float32]): Tensor0[Float32] = (v1 - v2).pow(2.0f).sum.sqrt it("vapply(identity) is identity"): t2.vapply(Axis[A])(identity) shouldEqual t2 diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala index ff5c31a5..9d4cdbd3 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsReductionSuite.scala @@ -1,7 +1,7 @@ package dimwit.tensor import dimwit.* -import dimwit.Conversions.given + class TensorOpsReductionSuite extends DimwitTest: val t2 = Tensor2( @@ -179,10 +179,9 @@ class TensorOpsReductionSuite extends DimwitTest: val allFalse = Tensor.like(b2).fill(false) allFalse.any shouldEqual Tensor0(false) - describe("Approximate Equality") { + describe("Approximate Equality"): it("approxEquals"): val t2Near = t2 *! Tensor0(1.0000001f) t2.approxEquals(t2Near).item shouldBe true val t2Far = t2 *! Tensor0(1.1f) t2.approxEquals(t2Far).item shouldBe false - } diff --git a/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala b/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala index daaee223..fbf0c4c7 100644 --- a/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorOpsStructureSuite.scala @@ -1,7 +1,6 @@ package dimwit.tensor import dimwit.* -import dimwit.Conversions.given import dimwit.tensor.Labels.concat import scala.compiletime.testing.typeCheckErrors @@ -178,7 +177,7 @@ class TensorOpsStructureSuite extends DimwitTest: ab should approxEqual(ab2) it("generic shape in function"): - def f[T <: Tuple: Labels](t: Tensor[T, Float]): Tensor[T, Float] = + def f[T <: Tuple: Labels](t: Tensor[T, Float32]): Tensor[T, Float32] = val unflattened = t.flatten unflattened.unflatten(t.shape) diff --git a/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala b/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala index 6134e31c..0c85dd8f 100644 --- a/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala +++ b/core/src/test/scala/dimwit/tensor/TensorWithValueClassSuite.scala @@ -7,15 +7,15 @@ class TensorWithValueClassSuite extends DimwitTest: it("Value class support for more specific types in tensors"): object ValueClassScope: - opaque type V1 = Float - opaque type V2 = Float + opaque type V1 = Float32 + opaque type V2 = Float32 object V1: - def apply[T <: Tuple](t: Tensor[T, Float]): Tensor[T, V1] = t // lift - given IsFloating[V1] with {} // make all IsFloating ops available + def apply[T <: Tuple](t: Tensor[T, Float32]): Tensor[T, V1] = t // lift + given IsFloating[V1] = summon[IsFloating[Float32]] // make all IsFloating ops available object V2: - def apply[T <: Tuple](t: Tensor[T, Float]): Tensor[T, V2] = t // lift - given IsFloating[V2] with {} // make all IsFloating ops available + def apply[T <: Tuple](t: Tensor[T, Float32]): Tensor[T, V2] = t // lift + given IsFloating[V2] = summon[IsFloating[Float32]] // make all IsFloating ops available import ValueClassScope.* val t = Tensor(Shape(Axis[A] -> 1, Axis[B] -> 2)).fill(0f) diff --git a/docs/quickstart.md b/docs/quickstart.md index 878a93d1..d22a7ba1 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -19,19 +19,19 @@ trait Batch derives Label trait Feature derives Label // parameters are explicitly defined and usually bundled in a case class -case class Params(w: Tensor1[Feature, Float], b: Tensor0[Float]) derives TensorTree +case class Params(w: Tensor1[Feature, Float32], b: Tensor0[Float32]) derives TensorTree // the model as a function of data and parametesrs -def model(x: Tensor2[Batch, Feature, Float], y: Tensor1[Batch, Float])(params: Params): Tensor1[Batch, Float] = +def model(x: Tensor2[Batch, Feature, Float32], y: Tensor1[Batch, Float32])(params: Params): Tensor1[Batch, Float32] = x.dot(Axis[Feature])(params.w) +! params.b // the loss function as a function of data and parameters -def loss(x: Tensor2[Batch, Feature, Float], y: Tensor1[Batch, Float])(params: Params): Tensor0[Float] = +def loss(x: Tensor2[Batch, Feature, Float32], y: Tensor1[Batch, Float32])(params: Params): Tensor0[Float32] = val pred = model(x, y)(params) (pred - y).pow(Tensor0(2.0f)).mean // the training loop, which produces an iterator of parameters -def fit(x: Tensor2[Batch, Feature, Float], y: Tensor1[Batch, Float]): Iterator[Params] = +def fit(x: Tensor2[Batch, Feature, Float32], y: Tensor1[Batch, Float32]): Iterator[Params] = // initialize parameters val p0 = Params( @@ -146,11 +146,11 @@ Let's inspect the tensor more closely: ```scala tensor -// res2: Tensor[Tuple2[Batch, Feature], Float] = [[1. 2.] +// res2: Tensor[Tuple2[Batch, Feature], Float32] = [[1. 2.] // [3. 4.] // [5. 6.]] ``` -We see that the full type of the tensor is `Tensor[(Batch, Feature), Float]`, which indicates that the tensor has two axes, `Batch` and `Feature`, and that the data type of the tensor is `Float`. As this type is rather bulky to write, we can also use the convenient type aliases `Tensor2`, `Tensor1` and `Tensor0` for tensors of rank 2, rank 1 and rank 0 respectively. In this case, the type of the tensor could also be written as `Tensor2[Batch, Feature, Float]`. +We see that the full type of the tensor is `Tensor[(Batch, Feature), Float32]`, which indicates that the tensor has two axes, `Batch` and `Feature`, and that the data type of the tensor is `Float`. As this type is rather bulky to write, we can also use the convenient type aliases `Tensor2`, `Tensor1` and `Tensor0` for tensors of rank 2, rank 1 and rank 0 respectively. In this case, the type of the tensor could also be written as `Tensor2[Batch, Feature, Float32]`. In addition to the type aliases, we also have convenient factory methods for tensors of rank 0 to 2. These allow us to create tensors without having to explicitly create the shape. @@ -183,8 +183,8 @@ val matrix = Tensor2(Axis[Feature], Axis[Batch]).fromArray( ##### A note on type annotations for tensors -In DimWit, the type of a tensor is represented at the type level as `Tensor[ShapeTuple, VType]`, where `ShapeTuple` is a tuple of the labels of the axes and `DataType` is the type of the data. Hence a tensor with shape `Shape(Axis[Batch] -> 3, Axis[Feature] -> 2)` and data type `Float` has the type `Tensor[(Batch, Feature), Float]`. To make type annotations more convenient, we have the type aliases `Tensor0`, `Tensor1` and `Tensor2`, etc. to -refer to Tensors of a specific rank. For example, a `Tensor[(Batch, Feature), Float]` can be referred to as `Tensor2[Batch, Feature, Float]`, a Tensor `Tensor[Tuple1[Batch], Float]` can be referred to as `Tensor1[Batch, Float]` and a `Tensor[EmptyTuple, Float]` can be referred to as `Tensor0[Float]`. +In DimWit, the type of a tensor is represented at the type level as `Tensor[ShapeTuple, VType]`, where `ShapeTuple` is a tuple of the labels of the axes and `DataType` is the type of the data. Hence a tensor with shape `Shape(Axis[Batch] -> 3, Axis[Feature] -> 2)` and data type `Float` has the type `Tensor[(Batch, Feature), Float32]`. To make type annotations more convenient, we have the type aliases `Tensor0`, `Tensor1` and `Tensor2`, etc. to +refer to Tensors of a specific rank. For example, a `Tensor[(Batch, Feature), Float32]` can be referred to as `Tensor2[Batch, Feature, Float32]`, a Tensor `Tensor[Tuple1[Batch], Float32]` can be referred to as `Tensor1[Batch, Float32]` and a `Tensor[EmptyTuple, Float32]` can be referred to as `Tensor0[Float32]`. ### Arithmetic Operations on Tensors and broadcasting @@ -230,8 +230,12 @@ tensor1 + tensor3 // ^^^^^^^ // error: // Conflicting definitions: -// val tensor1: dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B), Float] in class MdocApp1 at line 63 and -// val tensor1: dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B), Float] in class MdocApp1 at line 67 +// val tensor1: +// dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B), +// dimwit.tensor.DType.Float32] in class MdocApp1 at line 63 and +// val tensor1: +// dimwit.tensor.Tensor[(MdocApp1.this.A, MdocApp1.this.B), +// dimwit.tensor.DType.Float32] in class MdocApp1 at line 67 // // val tensor1 = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2)).fill(1.0f) // ^ @@ -275,10 +279,10 @@ axis to sum over using the labels of the axes, which ensures that we are summing The resulting tensor has the correct shape, which is inferred from the labels of the axes: ```scala -val sumOverB : Tensor1[A, Float] = tensor1.sum(Axis[B]) -// sumOverB: Tensor[Tuple1[A], Float] = [2. 2. 2.] -val sumOverA : Tensor1[B, Float] = tensor1.sum(Axis[A]) -// sumOverA: Tensor[Tuple1[B], Float] = [3. 3.] +val sumOverB : Tensor1[A, Float32] = tensor1.sum(Axis[B]) +// sumOverB: Tensor[Tuple1[A], Float32] = [2. 2. 2.] +val sumOverA : Tensor1[B, Float32] = tensor1.sum(Axis[A]) +// sumOverA: Tensor[Tuple1[B], Float32] = [3. 3.] ``` ### Transforming the shape of tensors @@ -296,12 +300,12 @@ val tensor = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2, Axis[C] -> 4)).fill(1.0f) #### Flattening and unflattening axes The first operation we will discuss is `flatten` which flattens part of the tensor into a single axis. Invoked without arguments, `flatten` will flatten all axes into a single axis, resulting in a Tensor1. ```scala -val flattened : Tensor1[A |*| B |*| C, Float] = tensor.flatten +val flattened : Tensor1[A |*| B |*| C, Float32] = tensor.flatten ``` Note that the resulting axis has a label that is a combination of the labels of the original axes. When flattening, we can also specify which axes to flatten, and the resulting axis will have a label that is a combination of the labels of the flattened axes. For example, we can flatten only the last two axes as follows: ```scala - val partiallyFlattened: Tensor2[A, B |*| C, Float] = tensor.flatten((Axis[B], Axis[C])) + val partiallyFlattened: Tensor2[A, B |*| C, Float32] = tensor.flatten((Axis[B], Axis[C])) ``` The counterpart of flatten is `unflatten`, which takes an axis that was previously flattened and restores the original axes. Since in the process of flattening we lost the information about the original shape, @@ -325,7 +329,7 @@ Given two tensors with the same shape except for one axis, we can concatenate th val part1 = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2)).fill(1.0f) val part2 = Tensor(Shape(Axis[A] -> 7, Axis[B] -> 2)).fill(1.0f) - val concatenated: Tensor[(A, B), Float] = concatenate(Seq(part1, part2), Axis[A]) + val concatenated: Tensor[(A, B), Float32] = concatenate(Seq(part1, part2), Axis[A]) ``` The concatenated tensor can be split back into the original tensors using the split method @@ -342,13 +346,13 @@ it returns a single tensor that is a slice of the original tensor. For example, at index 1 along axis A as follows: ```scala - val sliced1 : Tensor1[B, Float]= concatenated.slice(Axis[A].at(1)) + val sliced1 : Tensor1[B, Float32]= concatenated.slice(Axis[A].at(1)) ``` As for split, we can also provide multiple a tuple (or sequence) of indices, which will then select all slices in the tuple. ```scala - val slicedMultiple : Tensor2[A, B, Float] = concatenated.slice(Axis[A].at((0, 2))) + val slicedMultiple : Tensor2[A, B, Float32] = concatenated.slice(Axis[A].at((0, 2))) ``` #### Squeezing, Expanding and transposing axes @@ -358,17 +362,17 @@ method. ```scala val squeezableTensor = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 1, Axis[C] -> 4)).fill(1.0f) -val squeezedTensor : Tensor[(A, C), Float] = squeezableTensor.squeeze(Axis[B]) +val squeezedTensor : Tensor[(A, C), Float32] = squeezableTensor.squeeze(Axis[B]) ``` Similarly, we can add a new axis with extent 1 using the method `appendAxis`: ```scala -val appendedTensor : Tensor[(A, C, B), Float] = squeezedTensor.appendAxis(Axis[B]) +val appendedTensor : Tensor[(A, C, B), Float32] = squeezedTensor.appendAxis(Axis[B]) ``` Finally, we can permute the axes of a tensor using the `transpose` method. The order of the axes is the order of the labels in the tuple that we pass as an argument to the method. For example, we can reorder the above tensor to have the order of axes `A, B, C` as follows: ```scala -val restoredTensor : Tensor[(A, B, C), Float] = appendedTensor.transpose((Axis[A], Axis[B], Axis[C])) +val restoredTensor : Tensor[(A, B, C), Float32] = appendedTensor.transpose((Axis[A], Axis[B], Axis[C])) ``` ### Mapping over axes @@ -387,13 +391,13 @@ val tensor = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2, Axis[C] -> 4)).fill(1.0f) The simplest method is `vapply`. `vapply` applies a function from a `Tensor1` to a `Tensor1` to each slice of the tensor along the specified axis. For example, we can apply the function that multiplies each element by 2 to each slice along axis A as follows: ```scala -val doubled : Tensor3[A, B, C, Float] = tensor.vapply(Axis[A])((slice : Tensor1[A, Float]) => slice *! Tensor0(2.0f)) +val doubled : Tensor3[A, B, C, Float32] = tensor.vapply(Axis[A])((slice : Tensor1[A, Float32]) => slice *! Tensor0(2.0f)) ``` Similar to `vapply`is `vreduce`. `vreduce` applies a function that reduces a `Tensor1` to a `Tensor0` to each slice of the tensor along the specified axis. It effectively reduces the specified axis to a scalar. ```scala -val summedA : Tensor2[B, C, Float] = tensor.vreduce(Axis[A])((slice : Tensor1[A, Float]) => slice.sum) +val summedA : Tensor2[B, C, Float32] = tensor.vreduce(Axis[A])((slice : Tensor1[A, Float32]) => slice.sum) ``` A more general method is `vmap`, which applies a function to the slice of the tensor along the specified axis. The function can return a tensor of any shape, not just a `Tensor1`. The resulting tensor will have the same shape as the original tensor, except that the specified axis will be replaced by the shape of the output of the function. The following example takes a Tensor2 as input and computes the @@ -401,7 +405,7 @@ mean of each slice along axis C ```scala - val res : Tensor[(A, B), Float] = tensor.vmap(Axis[A])((slice : Tensor2[B, C, Float]) => slice.mean(Axis[C])) + val res : Tensor[(A, B), Float32] = tensor.vmap(Axis[A])((slice : Tensor2[B, C, Float32]) => slice.mean(Axis[C])) ``` `zipmap` is a variant of `vmap` that applies a function to the slices of multiple tensors along the specified axis. The function takes as input a tuple of slices, one from each tensor, and, as `vmap` returns a tensor of any shape. The resulting tensor will have the same shape as the original tensors, except that the specified axis will be replaced by the shape of the output of the function. For example, we can use `zipmap` to add two tensors along axis A as follows: @@ -409,7 +413,7 @@ mean of each slice along axis C val t1 = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2, Axis[C] -> 4)).fill(1.0f) val t2 = Tensor(Shape(Axis[A] -> 3, Axis[C] -> 3)).fill(2.0f) -val sumAlongA : Tensor1[A, Float] = zipvmap(Axis[A])(t1, t2)((s1: Tensor2[B, C, Float], s2: Tensor1[C, Float]) => s1.sum + s2.sum) +val sumAlongA : Tensor1[A, Float32] = zipvmap(Axis[A])(t1, t2)((s1: Tensor2[B, C, Float32], s2: Tensor1[C, Float32]) => s1.sum + s2.sum) ``` @@ -426,20 +430,20 @@ import Autodiff.grad Let's take a simple quadratic function as an example. ```scala -def f(x: Tensor1[A, Float]): Tensor0[Float] = x.dot(Axis[A])(x) +def f(x: Tensor1[A, Float32]): Tensor0[Float32] = x.dot(Axis[A])(x) ``` To compute the gradient of this function with respect to its input, we can use the `grad` method as follows: ```scala val x = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) -val gradient : Tensor1[A, Float] => Grad[Tensor1[A, Float]] = grad(f) +val gradient : Tensor1[A, Float32] => Grad[Tensor1[A, Float32]] = grad(f) ``` Note that the result of `grad` is a function that takes as input a tensor and returns the gradient of the function with respect to that tensor. The gradient is a normal tensor, but wrapped in a `Grad` object, to make sure that we don't accidentally use it in a computation without realizing that it is a gradient. To get the actual tensor from the `Grad` object, we can use the `value` method as follows: ```scala -val gradValue : Tensor1[A, Float] = gradient(x).value +val gradValue : Tensor1[A, Float32] = gradient(x).value ``` #### Tensor trees and gradients of multiple parameters @@ -450,13 +454,13 @@ For larger models the most convenient representation of the parameters is usuall ```scala -case class Params(w: Tensor1[Feature, Float], b: Tensor0[Float]) derives TensorTree +case class Params(w: Tensor1[Feature, Float32], b: Tensor0[Float32]) derives TensorTree ``` To compute the gradient of a function that takes as input a tensor tree, we can use the same `grad` method, as long as the input type of the function is a tensor tree. The resulting gradient will then be a tensor tree of the same shape as the input tensor tree, as illustrated in the following example: ```scala -def f(params: Params): Tensor0[Float] = params.w.dot(Axis[Feature])(params.w) + params.b.pow(Tensor0(2.0f)) +def f(params: Params): Tensor0[Float32] = params.w.dot(Axis[Feature])(params.w) + params.b.pow(Tensor0(2.0f)) val params = Params( w = Tensor1(Axis[Feature]).fromArray(Array(1.0f, 2.0f)), @@ -499,7 +503,7 @@ val keys = key.split(5) Often we need to create a sequence of random numbers. In this case, we can use the `splitvmap` method, which splits a key into a tensor of new keys and applies a function to each of the new keys. For example, we can create a vector of random numbers drawn from the normal distribution as follows: ```scala -val sampleVec: Tensor1[A, Float] = key.splitvmap(Axis[A] -> 3)((k: Key) => normalDist.sample(k)) +val sampleVec: Tensor1[A, Float32] = key.splitvmap(Axis[A] -> 3)((k: Key) => normalDist.sample(k)) ``` A more flexible, but less performant way to create a tensor of keys and to use it together with the `vmap` or `zipvmap` method. For example, we can create a tensor of keys and use it to create a tensor of random numbers as follows: diff --git a/examples/src/main/scala/basic/Autoencoder.scala b/examples/src/main/scala/basic/Autoencoder.scala index 4cd5cd8b..152f33a4 100644 --- a/examples/src/main/scala/basic/Autoencoder.scala +++ b/examples/src/main/scala/basic/Autoencoder.scala @@ -40,7 +40,7 @@ class Encoder(p: Encoder.EncoderParams): val layer2 = LinearLayer(p.layer2) val latentLayer = LinearLayer(p.latentLayer) - def apply(v: Tensor1[Pixel, Float]): Tensor1[Latent, Float] = + def apply(v: Tensor1[Pixel, Float32]): Tensor1[Latent, Float32] = val h1 = relu(layer1(v)) val h2 = relu(layer2(h1)) latentLayer(h2) @@ -58,7 +58,7 @@ class Decoder(p: Decoder.DecoderParams): val layer2 = LinearLayer(p.layer2) val outputLayer = LinearLayer(p.outputLayer) - def apply(v: Tensor1[Latent, Float]): Tensor1[ReconstructedPixel, Float] = + def apply(v: Tensor1[Latent, Float32]): Tensor1[ReconstructedPixel, Float32] = val h1 = relu(layer1(v)) val h2 = relu(layer2(h1)) sigmoid(outputLayer(h2)) @@ -75,12 +75,12 @@ case class Autoencoder(params: Autoencoder.Params): val encoder = Encoder(params.encoderParams) val decoder = Decoder(params.decoderParams) - def apply(v: Tensor1[Pixel, Float]): (Tensor1[ReconstructedPixel, Float], Tensor1[Latent, Float]) = + def apply(v: Tensor1[Pixel, Float32]): (Tensor1[ReconstructedPixel, Float32], Tensor1[Latent, Float32]) = val latent = encoder(v) val reconstructed = decoder(latent) (reconstructed, latent) - def loss(original: Tensor1[Pixel, Float]): Tensor0[Float] = + def loss(original: Tensor1[Pixel, Float32]): Tensor0[Float32] = val (reconstructed, _) = apply(original) val eps = 1e-5f val reconstructionLoss = -((original * (reconstructed +! eps).log) + ((Tensor0(1f) -! original) * (1f -! reconstructed +! eps).log)).sum @@ -155,13 +155,13 @@ object AutoencoderExample: // TODO linear layer et al. should support custom initializers // or xavier initialization val initialParams = Autoencoder.Params(encoderParams, decoderParams) - val scaledInitialParams = initialParams.map([T <: Tuple] => (n: Labels[T]) ?=> (t: Tensor[T, Float]) => t *! Tensor0(0.1f)) + val scaledInitialParams = initialParams.map([T <: Tuple] => (n: Labels[T]) ?=> (t: Tensor[T, Float32]) => t *! Tensor0(0.1f)) /* * Training loop * */ - def loss[S <: Sample: Label](trainData: Tensor3[S, Height, Width, Float])(params: Autoencoder.Params): Tensor0[Float] = + def loss[S <: Sample: Label](trainData: Tensor3[S, Height, Width, Float32])(params: Autoencoder.Params): Tensor0[Float32] = val ae = Autoencoder(params) trainData .vmap(Axis[S])(sample => ae.loss(sample.flatten)) @@ -171,7 +171,7 @@ object AutoencoderExample: val optimizer = GradientDescent(learningRate = Tensor0(learningRate)) - def gradientStep(batch: Tensor3[TrainSample, Height, Width, Float], params: Autoencoder.Params): Autoencoder.Params = + def gradientStep(batch: Tensor3[TrainSample, Height, Width, Float32], params: Autoencoder.Params): Autoencoder.Params = val grads = grad(loss(batch))(params) val (newParams, _) = optimizer.update(grads, params, ()) newParams diff --git a/examples/src/main/scala/basic/LogisticRegression.scala b/examples/src/main/scala/basic/LogisticRegression.scala index 075e6f27..85070310 100644 --- a/examples/src/main/scala/basic/LogisticRegression.scala +++ b/examples/src/main/scala/basic/LogisticRegression.scala @@ -1,8 +1,8 @@ package examples.basic import dimwit.* -import dimwit.autodiff.* import dimwit.Conversions.given +import dimwit.autodiff.* import nn.* import nn.ActivationFunctions.{sigmoid, relu} import dimwit.random.Random @@ -17,33 +17,33 @@ object LogisticRegression: // Define a binary logistic regression model case class BinaryLogisticRegression( params: BinaryLogisticRegression.Params - ) extends Function[Tensor1[Feature, Float], Tensor0[Boolean]]: + ) extends Function[Tensor1[Feature, Float32], Tensor0[Bool]]: - def logits(input: Tensor1[Feature, Float]): Tensor0[Float] = + def logits(input: Tensor1[Feature, Float32]): Tensor0[Float32] = params.weights.dot(Axis[Feature])(input) + params.bias - def probits(input: Tensor1[Feature, Float]): Tensor0[Float] = + def probits(input: Tensor1[Feature, Float32]): Tensor0[Float32] = sigmoid(logits(input)) - def apply(input: Tensor1[Feature, Float]): Tensor0[Boolean] = + def apply(input: Tensor1[Feature, Float32]): Tensor0[Bool] = logits(input) >= Tensor0(0f) // Parameters are, by convention, defined in the companion object object BinaryLogisticRegression: case class Params( - weights: Tensor1[Feature, Float], - bias: Tensor0[Float] + weights: Tensor1[Feature, Float32], + bias: Tensor0[Float32] ) derives TensorTree // The loss is a simple binary cross-entropy loss - def loss(data: Tensor2[Sample, Feature, Float], labels: Tensor1[Sample, Boolean])(params: BinaryLogisticRegression.Params) - : Tensor0[Float] = + def loss(data: Tensor2[Sample, Feature, Float32], labels: Tensor1[Sample, Bool])(params: BinaryLogisticRegression.Params) + : Tensor0[Float32] = // Create the model with the given parameters val model = BinaryLogisticRegression(params) // Compute the logistic loss for the model over the dataset - val losses = zipvmap(Axis[Sample])(data, labels.asFloat): + val losses = zipvmap(Axis[Sample])(data, labels.asFloat32): case (sample, label) => val logits = model.logits(sample) relu(logits) - logits * label + ((-logits.abs).exp + 1f).log @@ -97,7 +97,7 @@ object LogisticRegression: val trainLabels = labelsInitial.take(Axis[Sample])(trainPerm) val valLabels = labelsInitial.take(Axis[Sample])(testPerm) - def calcMeanAndStd(t: Tensor2[Sample, Feature, Float]): (Tensor1[Feature, Float], Tensor1[Feature, Float]) = + def calcMeanAndStd(t: Tensor2[Sample, Feature, Float32]): (Tensor1[Feature, Float32], Tensor1[Feature, Float32]) = val mean = t.vmap(Axis[Feature])(_.mean) val std = zipvmap(Axis[Feature])(t, mean): case (x, m) => @@ -105,8 +105,8 @@ object LogisticRegression: (x -! m).pow(2f).mean.sqrt + epsilon (mean, std) - def standardizeData(mean: Tensor1[Feature, Float], std: Tensor1[Feature, Float])(data: Tensor2[Sample, Feature, Float]) - : Tensor2[Sample, Feature, Float] = + def standardizeData(mean: Tensor1[Feature, Float32], std: Tensor1[Feature, Float32])(data: Tensor2[Sample, Feature, Float32]) + : Tensor2[Sample, Feature, Float32] = (data -! mean) /! std // Standardize the training and validation data @@ -124,7 +124,7 @@ object LogisticRegression: val trainLoss = jit(BinaryLogisticRegression.loss(trainingData, trainLabels)) val valLoss = jit(BinaryLogisticRegression.loss(valData, valLabels)) val learningRate = 5e-1f - val gd = GradientDescent(learningRate) + val gd = GradientDescent(Tensor0(learningRate)) // Training loop val numiterations = 1000 @@ -138,8 +138,8 @@ object LogisticRegression: println( List( "epoch: " + index, - "trainAcc: " + (1f - (trainPreds.asFloat - trainLabels.asFloat).abs.mean), - "valAcc: " + (1f - (valPreds.asFloat - valLabels.asFloat).abs.mean) + "trainAcc: " + (1f - (trainPreds.asFloat32 - trainLabels.asFloat32).abs.mean), + "valAcc: " + (1f - (valPreds.asFloat32 - valLabels.asFloat32).abs.mean) ).mkString(", ") ) .map((params, _) => params) diff --git a/examples/src/main/scala/basic/MLClassifierMNist.scala b/examples/src/main/scala/basic/MLClassifierMNist.scala index 40d8a2d3..8fe12d54 100644 --- a/examples/src/main/scala/basic/MLClassifierMNist.scala +++ b/examples/src/main/scala/basic/MLClassifierMNist.scala @@ -2,7 +2,6 @@ package examples.basic import dimwit.* import dimwit.autodiff.* -import dimwit.Conversions.given import nn.* import nn.ActivationFunctions.{relu, sigmoid} import dimwit.random.Random @@ -11,9 +10,9 @@ import examples.timed import examples.dataset.MNISTLoader def binaryCrossEntropy[L: Label]( - logits: Tensor1[L, Float], - label: Tensor0[Int] -): Tensor0[Float] = + logits: Tensor1[L, Float32], + label: Tensor0[Int32] +): Tensor0[Float32] = val maxLogit = logits.max val stableExp = (logits -! maxLogit).exp val logSumExp = stableExp.sum.log + maxLogit @@ -47,18 +46,18 @@ object MLPClassifierMNist: layer2 = LinearLayer.Params(key2)(layer2Dim, outputDim) ) - case class MLP(params: MLP.Params) extends Function[Tensor2[Height, Width, Float], Tensor0[Int]]: + case class MLP(params: MLP.Params) extends Function[Tensor2[Height, Width, Float32], Tensor0[Int32]]: private val layer1 = LinearLayer(params.layer1) private val layer2 = LinearLayer(params.layer2) def logits( - image: Tensor2[Height, Width, Float] - ): Tensor1[Output, Float] = + image: Tensor2[Height, Width, Float32] + ): Tensor1[Output, Float32] = val hidden = relu(layer1(image.flatten)) layer2(hidden) - override def apply(image: Tensor2[Height, Width, Float]): Tensor0[Int] = logits(image).argmax(Axis[Output]) + override def apply(image: Tensor2[Height, Width, Float32]): Tensor0[Int32] = logits(image).argmax(Axis[Output]) def main(args: Array[String]): Unit = @@ -74,9 +73,9 @@ object MLPClassifierMNist: val (trainX, trainY) = MNISTLoader.createTrainingDataset(maxSamples = Some(numSamples)).get val (testX, testY) = MNISTLoader.createTestDataset(maxSamples = Some(numTestSamples)).get - def batchLoss(batchImages: Tensor[(TrainSample, Height, Width), Float], batchLabels: Tensor1[TrainSample, Int])( + def batchLoss(batchImages: Tensor[(TrainSample, Height, Width), Float32], batchLabels: Tensor1[TrainSample, Int32])( params: MLP.Params - ): Tensor0[Float] = + ): Tensor0[Float32] = val model = MLP(params) val losses = zipvmap(Axis[TrainSample])(batchImages, batchLabels): case (image, label) => @@ -90,11 +89,11 @@ object MLPClassifierMNist: )(initKey) def accuracy[S: Label]( - predictions: Tensor1[S, Int], - targets: Tensor1[S, Int] - ): Tensor0[Float] = + predictions: Tensor1[S, Int32], + targets: Tensor1[S, Int32] + ): Tensor0[Float32] = val matches = zipvmap(Axis[S])(predictions, targets)(_ === _) - matches.asFloat.mean + matches.asFloat32.mean // val optimizer = GradientDescent(learningRate = Tensor0(1e-4f)) // type OptState = Unit @@ -106,8 +105,8 @@ object MLPClassifierMNist: type OptState = AdamState[MLP.Params] def gradientStep( - imageBatch: Tensor[(TrainSample, Height, Width), Float], - labelBatch: Tensor1[TrainSample, Int], + imageBatch: Tensor[(TrainSample, Height, Width), Float32], + labelBatch: Tensor1[TrainSample, Int32], params: MLP.Params, state: OptState ): (MLP.Params, OptState) = @@ -117,8 +116,8 @@ object MLPClassifierMNist: val (jitDonate, jitStep, jitReclaim) = jitDonating(gradientStep) def miniBatchGradientDescent( - imageBatches: Seq[Tensor[(TrainSample, Height, Width), Float]], - labelBatches: Seq[Tensor1[TrainSample, Int]] + imageBatches: Seq[Tensor[(TrainSample, Height, Width), Float32]], + labelBatches: Seq[Tensor1[TrainSample, Int32]] )( params: MLP.Params, initialState: OptState @@ -127,12 +126,12 @@ object MLPClassifierMNist: .zip(labelBatches) .foldLeft(jitDonate(params, initialState)): case ((currentParams, state), (imageBatch, labelBatch)) => - jitStep(imageBatch, labelBatch, currentParams, state) + jitStep(imageBatch, labelBatch.asInt32, currentParams, state) jitReclaim(res) val trainMiniBatchGradientDescent = miniBatchGradientDescent( trainX.chunk(Axis[TrainSample], numSamples / batchSize), - trainY.chunk(Axis[TrainSample], numSamples / batchSize) + trainY.asInt32.chunk(Axis[TrainSample], numSamples / batchSize) ) val trainTrajectory = Iterator.iterate((initParams, optimizer.init(initParams))): (currentParams, state) => timed("Training"): @@ -140,19 +139,19 @@ object MLPClassifierMNist: trainMiniBatchGradientDescent(currentParams, state) def evaluate[S <: Sample: Label]( params: MLP.Params, - dataX: Tensor3[S, Height, Width, Float], - dataY: Tensor1[S, Int] - ): Tensor0[Float] = + dataX: Tensor3[S, Height, Width, Float32], + dataY: Tensor1[S, Int32] + ): Tensor0[Float32] = val model = MLP(params) val predictions = dataX.vmap(Axis[S])(model) - accuracy(predictions, dataY) + accuracy(predictions, dataY.asInt32) val jitEvaluate = evaluate val (finalParams, finalState) = trainTrajectory.zipWithIndex .tapEach: case ((params, state), epoch) => timed("Evaluation"): - val testAccuracy = evaluate(params, testX, testY) - val trainAccuracy = evaluate(params, trainX, trainY) + val testAccuracy = evaluate(params, testX, testY.asInt32) + val trainAccuracy = evaluate(params, trainX, trainY.asInt32) println( List( s"Epoch $epoch", diff --git a/examples/src/main/scala/basic/MLClassifierMNistCNN.scala b/examples/src/main/scala/basic/MLClassifierMNistCNN.scala index bf1e9cde..605d4fee 100644 --- a/examples/src/main/scala/basic/MLClassifierMNistCNN.scala +++ b/examples/src/main/scala/basic/MLClassifierMNistCNN.scala @@ -12,9 +12,9 @@ import examples.basic.MLPClassifierMNist.MLP // Logits-based Cross Entropy (same as yours) def binaryCrossEntropy[L: Label]( - logits: Tensor1[L, Float], - label: Tensor0[Int] -): Tensor0[Float] = + logits: Tensor1[L, Float32], + label: Tensor0[Int32] +): Tensor0[Float32] = val maxLogit = logits.max val logSumExp = ((logits -! maxLogit).exp.sum + 1e-7f).log + maxLogit val targetLogit = logits.slice(Axis[L].at(label)) @@ -32,8 +32,8 @@ object MNistCNN: object CNN: case class Params( - conv1: Conv2DLayer.Params[Height, Width, Channel, Hidden], - conv2: Conv2DLayer.Params[Height, Width, Hidden, PixelEmbedding], + conv1: Conv2DLayer.Params[Height, Width, Channel, Hidden, Float32], + conv2: Conv2DLayer.Params[Height, Width, Hidden, PixelEmbedding, Float32], output: LinearLayer.Params[ImageEmbedding, Output] ) @@ -56,18 +56,18 @@ object MNistCNN: output = LinearLayer.Params(keys(2))(embeddingDim, outputDim) ) - case class CNN(params: CNN.Params) extends Function[Tensor2[Height, Width, Float], Tensor0[Int]]: + case class CNN(params: CNN.Params) extends Function[Tensor2[Height, Width, Float32], Tensor0[Int32]]: private val conv1 = Conv2DLayer(params.conv1, stride = 2, padding = Padding.SAME) private val conv2 = Conv2DLayer(params.conv2, stride = 2, padding = Padding.SAME) private val output = LinearLayer(params.output) - def logits(image: Tensor2[Height, Width, Float]): Tensor1[Output, Float] = + def logits(image: Tensor2[Height, Width, Float32]): Tensor1[Output, Float32] = val input = image.appendAxis(Axis[Channel]) val hidden = relu(conv1(input)) val features = relu(conv2(hidden)) output(features.flatten) - override def apply(image: Tensor2[Height, Width, Float]): Tensor0[Int] = + override def apply(image: Tensor2[Height, Width, Float32]): Tensor0[Int32] = logits(image).argmax(Axis[Output]) def main(args: Array[String]): Unit = @@ -86,9 +86,9 @@ object MNistCNN: val initParams = CNN.Params(trainKey)(16, 32) val scaledInitialParams = initParams **! Tensor0(0.1f) - def batchLoss(batchImages: Tensor[(TrainSample, Height, Width), Float], batchLabels: Tensor1[TrainSample, Int])( + def batchLoss(batchImages: Tensor[(TrainSample, Height, Width), Float32], batchLabels: Tensor1[TrainSample, Int32])( params: CNN.Params - ): Tensor0[Float] = + ): Tensor0[Float32] = val model = CNN(params) val batchLosses = zipvmap(Axis[TrainSample])(batchImages, batchLabels): case (img, lbl) => @@ -98,8 +98,8 @@ object MNistCNN: val optimizer = GradientDescent(learningRate = Tensor0(learningRate)) def gradientStep( - imageBatch: Tensor[(TrainSample, Height, Width), Float], - labelBatch: Tensor1[TrainSample, Int], + imageBatch: Tensor[(TrainSample, Height, Width), Float32], + labelBatch: Tensor1[TrainSample, Int32], params: CNN.Params ): CNN.Params = val grads = Autodiff.grad(batchLoss(imageBatch, labelBatch))(params) @@ -115,19 +115,19 @@ object MNistCNN: val lblBatches = trainY.chunk(Axis[TrainSample], numSamples / batchSize) val newParams = imgBatches.zip(lblBatches).foldLeft(jitDonate(params)): case (params, (imgB, lblB)) => - jitStep(imgB, lblB, params) + jitStep(imgB, lblB.asInt32, params) jitReclaim(newParams) // Evaluation - def evaluate[S <: Sample: Label](params: CNN.Params, dataX: Tensor[(S, Height, Width), Float], dataY: Tensor1[S, Int]): Tensor0[Float] = + def evaluate[S <: Sample: Label](params: CNN.Params, dataX: Tensor[(S, Height, Width), Float32], dataY: Tensor1[S, Int32]): Tensor0[Float32] = val model = CNN(params) val predictions = dataX.vmap(Axis[S])(model) val matches = zipvmap(Axis[S])(predictions, dataY)(_ === _) - matches.asFloat.mean + matches.asFloat32.mean trainTrajectory.drop(1).zipWithIndex.foreach: case (params, epoch) => if epoch % 1 == 0 then dimwit.gc() - val acc = evaluate(params, testX, testY) + val acc = evaluate(params, testX, testY.asInt32) println(f"Epoch $epoch | Test Accuracy: ${acc.item * 100}%.2f%%") diff --git a/examples/src/main/scala/basic/Playground.scala b/examples/src/main/scala/basic/Playground.scala index 0ff57434..db0ba479 100644 --- a/examples/src/main/scala/basic/Playground.scala +++ b/examples/src/main/scala/basic/Playground.scala @@ -9,7 +9,7 @@ object Playground extends App: trait A derives Label trait B derives Label - def f(x: Tensor1[A, Float]): Tensor0[Float] = + def f(x: Tensor1[A, Float32]): Tensor0[Float32] = x.sum grad(f) diff --git a/examples/src/main/scala/complex/GPT2.scala b/examples/src/main/scala/complex/GPT2.scala index af567da2..6caec9fd 100644 --- a/examples/src/main/scala/complex/GPT2.scala +++ b/examples/src/main/scala/complex/GPT2.scala @@ -15,17 +15,17 @@ trait EmbeddingMixed derives Label // 3072 trait Batch derives Label case class LayerNormalizationParams( - weight: Tensor1[Embedding, Float], - bias: Tensor1[Embedding, Float] + weight: Tensor1[Embedding, Float32], + bias: Tensor1[Embedding, Float32] ) case class LinearLayerParams[In, Out]( - weight: Tensor2[In, Out, Float], - bias: Tensor1[Out, Float] + weight: Tensor2[In, Out, Float32], + bias: Tensor1[Out, Float32] ) case class ProjectionLayerParams[In, Out]( - weight: Tensor2[In, Out, Float] + weight: Tensor2[In, Out, Float32] ) trait Head derives Label @@ -33,7 +33,7 @@ trait HeadKey derives Label trait HeadQuery derives Label trait HeadValue derives Label -case class HeadsParams[Kind](val weights: Tensor3[Head, Embedding, Kind, Float], val bias: Tensor2[Head, Kind, Float]) +case class HeadsParams[Kind](val weights: Tensor3[Head, Embedding, Kind, Float32], val bias: Tensor2[Head, Kind, Float32]) case class MultiHeadAttentionParams( wq: HeadsParams[HeadQuery], @@ -55,8 +55,8 @@ case class TransformerLayerParams( ) case class GPT2Params( - vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float], - positionalEmbeddings: Tensor2[Context, Embedding, Float], + vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], + positionalEmbeddings: Tensor2[Context, Embedding, Float32], layers: List[TransformerLayerParams], outputNormalization: LayerNormalizationParams, output: ProjectionLayerParams[Embedding, Vocab] @@ -64,8 +64,8 @@ case class GPT2Params( object GPT2Params: def apply( - vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float], - positionalEmbeddings: Tensor2[Context, Embedding, Float], + vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], + positionalEmbeddings: Tensor2[Context, Embedding, Float32], layers: List[TransformerLayerParams], outputNormalization: LayerNormalizationParams ): GPT2Params = @@ -74,48 +74,48 @@ object GPT2Params: ) GPT2Params(vocabularyEmbeddings, positionalEmbeddings, layers, outputNormalization, outputParams) -case class GPT2(params: GPT2Params) extends (Tensor2[Batch, Context, Int] => Tensor2[Batch, Context, Int]): +case class GPT2(params: GPT2Params) extends (Tensor2[Batch, Context, Int32] => Tensor2[Batch, Context, Int32]): - private case class LinearLayer[In: Label, Out: Label](params: LinearLayerParams[In, Out]) extends (Tensor1[In, Float] => Tensor1[Out, Float]): - override def apply(x: Tensor1[In, Float]): Tensor1[Out, Float] = + private case class LinearLayer[In: Label, Out: Label](params: LinearLayerParams[In, Out]) extends (Tensor1[In, Float32] => Tensor1[Out, Float32]): + override def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = x.dot(Axis[In])(params.weight) + params.bias - private case class EmbeddingMixer(params: EmbeddingMixerParams) extends (Tensor2[Context, Embedding, Float] => Tensor2[Context, Embedding, Float]): + private case class EmbeddingMixer(params: EmbeddingMixerParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): private val hiddenLayer = LinearLayer(params.c_fc) private val outputLayer = LinearLayer(params.c_proj) // TODO add dropout - def apply(in: Tensor2[Context, Embedding, Float]): Tensor2[Context, Embedding, Float] = + def apply(in: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = in.vmap(Axis[Context])(x => val hidden = gelu(hiddenLayer(x)) outputLayer(hidden) ) - private case class ProjectionLayer[In: Label, Out: Label](params: ProjectionLayerParams[In, Out]) extends (Tensor1[In, Float] => Tensor1[Out, Float]): - def apply(x: Tensor1[In, Float]): Tensor1[Out, Float] = + private case class ProjectionLayer[In: Label, Out: Label](params: ProjectionLayerParams[In, Out]) extends (Tensor1[In, Float32] => Tensor1[Out, Float32]): + def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = x.dot(Axis[In])(params.weight) - private case class MultiHeadAttention(params: MultiHeadAttentionParams) extends (Tensor2[Context, Embedding, Float] => Tensor2[Context, Embedding, Float]): + private case class MultiHeadAttention(params: MultiHeadAttentionParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): private val projection = LinearLayer(params.proj) - def apply(x: Tensor2[Context, Embedding, Float]): Tensor2[Context, Embedding, Float] = + def apply(x: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = val heads = zipvmap(Axis[Head])(params.wq.weights, params.wq.bias, params.wk.weights, params.wk.bias, params.wv.weights, params.wv.bias): attention.tupled(_)(x) heads.vmap(Axis[Context])(heads => projection(heads.flatten)) private def attention( - wq: Tensor2[Embedding, HeadQuery, Float], - wqBias: Tensor1[HeadQuery, Float], - wk: Tensor2[Embedding, HeadKey, Float], - wkBias: Tensor1[HeadKey, Float], - wv: Tensor2[Embedding, HeadValue, Float], - wvBias: Tensor1[HeadValue, Float] - )(x: Tensor2[Context, Embedding, Float]): Tensor2[Context, HeadValue, Float] = + wq: Tensor2[Embedding, HeadQuery, Float32], + wqBias: Tensor1[HeadQuery, Float32], + wk: Tensor2[Embedding, HeadKey, Float32], + wkBias: Tensor1[HeadKey, Float32], + wv: Tensor2[Embedding, HeadValue, Float32], + wvBias: Tensor1[HeadValue, Float32] + )(x: Tensor2[Context, Embedding, Float32]): Tensor2[Context, HeadValue, Float32] = trait AttnWeights derives Label - def causalMasking(attnScores: Tensor2[Context, Prime[Context], Float]): Tensor2[Context, Prime[Context], Float] = + def causalMasking(attnScores: Tensor2[Context, Prime[Context], Float32]): Tensor2[Context, Prime[Context], Float32] = val ctxLength = attnScores.shape(Axis[Context]) val causalMask = tril(Tensor(Shape((Axis[Context] -> ctxLength, Axis[Prime[Context]] -> ctxLength))).fill(true)) where(causalMask, attnScores, Tensor.like(attnScores).fill(Float.NegativeInfinity)) @@ -130,63 +130,63 @@ case class GPT2(params: GPT2Params) extends (Tensor2[Batch, Context, Int] => Ten val res = attnWeights.dot(Axis[AttnWeights ~ Context])(values) res - private case class LayerNorm(params: LayerNormalizationParams) extends (Tensor1[Embedding, Float] => Tensor1[Embedding, Float]): + private case class LayerNorm(params: LayerNormalizationParams) extends (Tensor1[Embedding, Float32] => Tensor1[Embedding, Float32]): - private def standardize(x: Tensor1[Embedding, Float]): Tensor1[Embedding, Float] = + private def standardize(x: Tensor1[Embedding, Float32]): Tensor1[Embedding, Float32] = val x0 = x -! x.mean val variance = x0.pow(2).mean val epsilon = 1e-6f x0 /! (variance + epsilon).sqrt - def apply(x: Tensor1[Embedding, Float]): Tensor1[Embedding, Float] = + def apply(x: Tensor1[Embedding, Float32]): Tensor1[Embedding, Float32] = standardize(x) * params.weight + params.bias - private case class TransformerLayer(params: TransformerLayerParams) extends (Tensor2[Context, Embedding, Float] => Tensor2[Context, Embedding, Float]): + private case class TransformerLayer(params: TransformerLayerParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): private val embeddingMixer = EmbeddingMixer(params.embeddingMixer) private val multiHeadAttention = MultiHeadAttention(params.attn) private val preNormalization = LayerNorm(params.ln1) private val postNormalization = LayerNorm(params.ln2) - def apply(t: Tensor2[Context, Embedding, Float]): Tensor2[Context, Embedding, Float] = + def apply(t: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = var x = t x = x + multiHeadAttention(x.vmap(Axis[Context])(preNormalization)) x = x + embeddingMixer(x.vmap(Axis[Context])(postNormalization)) x - private case class TransformerBlock(layers: List[TransformerLayer]) extends (Tensor2[Context, Embedding, Float] => Tensor2[Context, Embedding, Float]): - override def apply(t: Tensor2[Context, Embedding, Float]): Tensor2[Context, Embedding, Float] = + private case class TransformerBlock(layers: List[TransformerLayer]) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): + override def apply(t: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = layers.foldLeft(t): case (t, layer) => layer(t) - case class Embedder(vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float], positionalEmbeddings: Tensor2[Context, Embedding, Float]): + case class Embedder(vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], positionalEmbeddings: Tensor2[Context, Embedding, Float32]): - def apply(tokens: Tensor1[Context, Int]): Tensor2[Context, Embedding, Float] = + def apply(tokens: Tensor1[Context, Int32]): Tensor2[Context, Embedding, Float32] = val embeddings = vocabularyEmbeddings.take(Axis[Vocab])(tokens) embeddings + positionalEmbeddings - case class OutputLayer(normalization: LayerNormalizationParams, projectionParams: ProjectionLayerParams[Embedding, Vocab]) extends (Tensor1[Embedding, Float] => Tensor1[Vocab, Float]): + case class OutputLayer(normalization: LayerNormalizationParams, projectionParams: ProjectionLayerParams[Embedding, Vocab]) extends (Tensor1[Embedding, Float32] => Tensor1[Vocab, Float32]): private val normalizationLayer = LayerNorm(normalization) private val projection = ProjectionLayer(projectionParams) - override def apply(x: Tensor1[Embedding, Float]): Tensor1[Vocab, Float] = + override def apply(x: Tensor1[Embedding, Float32]): Tensor1[Vocab, Float32] = projection(normalizationLayer(x)) private val embedder = Embedder(params.vocabularyEmbeddings, params.positionalEmbeddings) private val transformerBlock = TransformerBlock(params.layers.map(TransformerLayer(_))) private val outputLayer = OutputLayer(params.outputNormalization, params.output) - def logits(inputTokens: Tensor2[Batch, Context, Int]): Tensor3[Batch, Context, Vocab, Float] = + def logits(inputTokens: Tensor2[Batch, Context, Int32]): Tensor3[Batch, Context, Vocab, Float32] = inputTokens.vmap(Axis[Batch]): case tokens => val startEmbeddings = embedder(tokens) val endEmbeddings = transformerBlock(startEmbeddings) endEmbeddings.vmap(Axis[Context])(x => outputLayer(x)) - def probits(inputTokens: Tensor2[Batch, Context, Int]): Tensor3[Batch, Context, Vocab, Float] = + def probits(inputTokens: Tensor2[Batch, Context, Int32]): Tensor3[Batch, Context, Vocab, Float32] = val x = logits(inputTokens) val res = x.vapply(Axis[Vocab])(softmax) return res - def apply(inputTokens: Tensor2[Batch, Context, Int]): Tensor2[Batch, Context, Int] = + def apply(inputTokens: Tensor2[Batch, Context, Int32]): Tensor2[Batch, Context, Int32] = val x = probits(inputTokens) val res = x.argmax(Axis[Vocab]) return res @@ -315,7 +315,7 @@ object GPT2Inference: val cAttn = loadLinear(cAttnName, Axis[Embedding], Axis[QKV]) val cProj = loadLinear(cProjName, Axis[Head |*| HeadValue], Axis[Embedding]) - def splitWeightToHeads[L](t: Tensor2[Embedding, Head |*| L, Float], numHeads: Int)(using label: Label[L]): Tensor3[Head, Embedding, L, Float] = + def splitWeightToHeads[L](t: Tensor2[Embedding, Head |*| L, Float32], numHeads: Int)(using label: Label[L]): Tensor3[Head, Embedding, L, Float32] = val tLength = t.shape(Axis[Head |*| L]) require(tLength % numHeads == 0, s"T length $tLength not divisible by numHeads $numHeads") t.rearrange( @@ -323,7 +323,7 @@ object GPT2Inference: Axis[Head] -> numHeads, Axis[L] -> (tLength / numHeads) ) - def splitBiasToHeads[L](t: Tensor1[Head |*| L, Float], numHeads: Int)(using label: Label[L]): Tensor2[Head, L, Float] = + def splitBiasToHeads[L](t: Tensor1[Head |*| L, Float32], numHeads: Int)(using label: Label[L]): Tensor2[Head, L, Float32] = val tLength = t.shape(Axis[Head |*| L]) require(tLength % numHeads == 0, s"T length $tLength not divisible by numHeads $numHeads") t.rearrange( @@ -351,12 +351,12 @@ object GPT2Inference: proj = cProj ) - def load1[L](name: String, axis: Axis[L])(using Label[L]): Tensor1[L, Float] = + def load1[L](name: String, axis: Axis[L])(using Label[L]): Tensor1[L, Float32] = val info = tensorMap(name) val jaxArray = SafeTensorsReader.loadTensor(filePath, info, dataStartPos) liftPyTensor(jaxArray) - def load2[L1, L2](name: String, axis1: Axis[L1], axis2: Axis[L2])(using Label[L1], Label[L2]): Tensor2[L1, L2, Float] = + def load2[L1, L2](name: String, axis1: Axis[L1], axis2: Axis[L2])(using Label[L1], Label[L2]): Tensor2[L1, L2, Float32] = val info = tensorMap(name) val jaxArray = SafeTensorsReader.loadTensor(filePath, info, dataStartPos) liftPyTensor(jaxArray) diff --git a/examples/src/main/scala/complex/GPT2Train.scala b/examples/src/main/scala/complex/GPT2Train.scala index dee88c54..1e16115d 100644 --- a/examples/src/main/scala/complex/GPT2Train.scala +++ b/examples/src/main/scala/complex/GPT2Train.scala @@ -72,17 +72,17 @@ trait EmbeddingMixed derives Label trait Batch derives Label case class LayerNormalizationParams( - weight: Tensor1[Embedding, Float], - bias: Tensor1[Embedding, Float] + weight: Tensor1[Embedding, Float32], + bias: Tensor1[Embedding, Float32] ) case class LinearLayerParams[In, Out]( - weight: Tensor2[In, Out, Float], - bias: Tensor1[Out, Float] + weight: Tensor2[In, Out, Float32], + bias: Tensor1[Out, Float32] ) case class ProjectionLayerParams[In, Out]( - weight: Tensor2[In, Out, Float] + weight: Tensor2[In, Out, Float32] ) trait Head derives Label @@ -90,7 +90,7 @@ trait HeadKey derives Label trait HeadQuery derives Label trait HeadValue derives Label -case class HeadsParams[Kind](val weights: Tensor3[Head, Embedding, Kind, Float], val bias: Tensor2[Head, Kind, Float]) +case class HeadsParams[Kind](val weights: Tensor3[Head, Embedding, Kind, Float32], val bias: Tensor2[Head, Kind, Float32]) case class MultiHeadAttentionParams( wq: HeadsParams[HeadQuery], @@ -112,8 +112,8 @@ case class TransformerLayerParams( ) derives TensorTree case class GPT2Params( - vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float], - positionalEmbeddings: Tensor2[Context, Embedding, Float], + vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], + positionalEmbeddings: Tensor2[Context, Embedding, Float32], layers: List[TransformerLayerParams], outputNormalization: LayerNormalizationParams ) derives TensorTree @@ -174,48 +174,48 @@ object GPT2Params: outputNormalization = initLayerNormalizationParams() ) -case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Context, Int]): +case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int32] => Tensor1[Context, Int32]): - private case class LinearLayer[In: Label, Out: Label](params: LinearLayerParams[In, Out]) extends (Tensor1[In, Float] => Tensor1[Out, Float]): - override def apply(x: Tensor1[In, Float]): Tensor1[Out, Float] = + private case class LinearLayer[In: Label, Out: Label](params: LinearLayerParams[In, Out]) extends (Tensor1[In, Float32] => Tensor1[Out, Float32]): + override def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = x.dot(Axis[In])(params.weight) + params.bias - private case class EmbeddingMixer(params: EmbeddingMixerParams) extends (Tensor2[Context, Embedding, Float] => Tensor2[Context, Embedding, Float]): + private case class EmbeddingMixer(params: EmbeddingMixerParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): private val hiddenLayer = LinearLayer(params.c_fc) private val outputLayer = LinearLayer(params.c_proj) // TODO add dropout - def apply(in: Tensor2[Context, Embedding, Float]): Tensor2[Context, Embedding, Float] = + def apply(in: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = in.vmap(Axis[Context])(x => val hidden = gelu(hiddenLayer(x)) outputLayer(hidden) ) - private case class ProjectionLayer[In: Label, Out: Label](params: ProjectionLayerParams[In, Out]) extends (Tensor1[In, Float] => Tensor1[Out, Float]): - def apply(x: Tensor1[In, Float]): Tensor1[Out, Float] = + private case class ProjectionLayer[In: Label, Out: Label](params: ProjectionLayerParams[In, Out]) extends (Tensor1[In, Float32] => Tensor1[Out, Float32]): + def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = x.dot(Axis[In])(params.weight) - private case class MultiHeadAttention(params: MultiHeadAttentionParams) extends (Tensor2[Context, Embedding, Float] => Tensor2[Context, Embedding, Float]): + private case class MultiHeadAttention(params: MultiHeadAttentionParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): private val projection = LinearLayer(params.proj) - def apply(x: Tensor2[Context, Embedding, Float]): Tensor2[Context, Embedding, Float] = + def apply(x: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = val heads = zipvmap(Axis[Head])(params.wq.weights, params.wq.bias, params.wk.weights, params.wk.bias, params.wv.weights, params.wv.bias): attention.tupled(_)(x) heads.vmap(Axis[Context])(heads => projection(heads.flatten)) private def attention( - wq: Tensor2[Embedding, HeadQuery, Float], - wqBias: Tensor1[HeadQuery, Float], - wk: Tensor2[Embedding, HeadKey, Float], - wkBias: Tensor1[HeadKey, Float], - wv: Tensor2[Embedding, HeadValue, Float], - wvBias: Tensor1[HeadValue, Float] - )(x: Tensor2[Context, Embedding, Float]): Tensor2[Context, HeadValue, Float] = + wq: Tensor2[Embedding, HeadQuery, Float32], + wqBias: Tensor1[HeadQuery, Float32], + wk: Tensor2[Embedding, HeadKey, Float32], + wkBias: Tensor1[HeadKey, Float32], + wv: Tensor2[Embedding, HeadValue, Float32], + wvBias: Tensor1[HeadValue, Float32] + )(x: Tensor2[Context, Embedding, Float32]): Tensor2[Context, HeadValue, Float32] = trait AttnWeights derives Label - def causalMasking(attnScores: Tensor2[Context, Prime[Context], Float]): Tensor2[Context, Prime[Context], Float] = + def causalMasking(attnScores: Tensor2[Context, Prime[Context], Float32]): Tensor2[Context, Prime[Context], Float32] = val ctxLength = attnScores.shape(Axis[Context]) val causalMask = tril(Tensor(Shape((Axis[Context] -> ctxLength, Axis[Prime[Context]] -> ctxLength))).fill(true)) where(causalMask, attnScores, Tensor.like(attnScores).fill(Float.NegativeInfinity)) @@ -230,44 +230,44 @@ case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Co val res = attnWeights.dot(Axis[AttnWeights ~ Context])(values) res - private case class LayerNorm(params: LayerNormalizationParams) extends (Tensor1[Embedding, Float] => Tensor1[Embedding, Float]): + private case class LayerNorm(params: LayerNormalizationParams) extends (Tensor1[Embedding, Float32] => Tensor1[Embedding, Float32]): - private def standardize(x: Tensor1[Embedding, Float]): Tensor1[Embedding, Float] = + private def standardize(x: Tensor1[Embedding, Float32]): Tensor1[Embedding, Float32] = val x0 = x -! x.mean val variance = x0.pow(2).mean val epsilon = 1e-6f x0 /! (variance + epsilon).sqrt - def apply(x: Tensor1[Embedding, Float]): Tensor1[Embedding, Float] = + def apply(x: Tensor1[Embedding, Float32]): Tensor1[Embedding, Float32] = standardize(x) * params.weight + params.bias - private case class TransformerLayer(params: TransformerLayerParams) extends (Tensor2[Context, Embedding, Float] => Tensor2[Context, Embedding, Float]): + private case class TransformerLayer(params: TransformerLayerParams) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): private val embeddingMixer = EmbeddingMixer(params.embeddingMixer) private val multiHeadAttention = MultiHeadAttention(params.attn) private val preNormalization = LayerNorm(params.ln1) private val postNormalization = LayerNorm(params.ln2) - def apply(t: Tensor2[Context, Embedding, Float]): Tensor2[Context, Embedding, Float] = + def apply(t: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = var x = t x = x + multiHeadAttention(x.vmap(Axis[Context])(preNormalization)) x = x + embeddingMixer(x.vmap(Axis[Context])(postNormalization)) x - private case class TransformerBlock(layers: List[TransformerLayer]) extends (Tensor2[Context, Embedding, Float] => Tensor2[Context, Embedding, Float]): - override def apply(t: Tensor2[Context, Embedding, Float]): Tensor2[Context, Embedding, Float] = + private case class TransformerBlock(layers: List[TransformerLayer]) extends (Tensor2[Context, Embedding, Float32] => Tensor2[Context, Embedding, Float32]): + override def apply(t: Tensor2[Context, Embedding, Float32]): Tensor2[Context, Embedding, Float32] = layers.foldLeft(t): case (t, layer) => layer(t) - case class Embedder(vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float], positionalEmbeddings: Tensor2[Context, Embedding, Float]): + case class Embedder(vocabularyEmbeddings: Tensor2[Vocab, Embedding, Float32], positionalEmbeddings: Tensor2[Context, Embedding, Float32]): - def apply(tokens: Tensor1[Context, Int]): Tensor2[Context, Embedding, Float] = + def apply(tokens: Tensor1[Context, Int32]): Tensor2[Context, Embedding, Float32] = val embeddings = vocabularyEmbeddings.take(Axis[Vocab])(tokens) embeddings + positionalEmbeddings - case class OutputLayer(normalization: LayerNormalizationParams, projectionParams: ProjectionLayerParams[Embedding, Vocab]) extends (Tensor1[Embedding, Float] => Tensor1[Vocab, Float]): + case class OutputLayer(normalization: LayerNormalizationParams, projectionParams: ProjectionLayerParams[Embedding, Vocab]) extends (Tensor1[Embedding, Float32] => Tensor1[Vocab, Float32]): private val normalizationLayer = LayerNorm(normalization) private val projection = ProjectionLayer(projectionParams) - override def apply(x: Tensor1[Embedding, Float]): Tensor1[Vocab, Float] = + override def apply(x: Tensor1[Embedding, Float32]): Tensor1[Vocab, Float32] = projection(normalizationLayer(x)) private val embedder = Embedder(params.vocabularyEmbeddings, params.positionalEmbeddings) @@ -277,23 +277,25 @@ case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Co ProjectionLayerParams(params.vocabularyEmbeddings.transpose) // Tying output weights with input embeddings ) - def logits(inputTokens: Tensor1[Context, Int]): Tensor2[Context, Vocab, Float] = + def logits(inputTokens: Tensor1[Context, Int32]): Tensor2[Context, Vocab, Float32] = val startEmbeddings = embedder(inputTokens) val endEmbeddings = transformerBlock(startEmbeddings) endEmbeddings.vmap(Axis[Context])(x => outputLayer(x)) - def probits(inputTokens: Tensor1[Context, Int]): Tensor2[Context, Vocab, Float] = + def probits(inputTokens: Tensor1[Context, Int32]): Tensor2[Context, Vocab, Float32] = val x = logits(inputTokens) val res = x.vapply(Axis[Vocab])(softmax) return res - def apply(inputTokens: Tensor1[Context, Int]): Tensor1[Context, Int] = + def apply(inputTokens: Tensor1[Context, Int32]): Tensor1[Context, Int32] = val x = probits(inputTokens) val res = x.argmax(Axis[Vocab]) return res @main def train(): Unit = + import Tensor0.given + dimwit.initialize() trait Data derives Label @@ -301,12 +303,12 @@ case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Co PythonSetup.initialize lazy val np = py.module("numpy") case class Sample( - input: Tensor2[Batch, Context, Int], - labels: Tensor2[Batch, Context, Int] + input: Tensor2[Batch, Context, Int32], + labels: Tensor2[Batch, Context, Int32] ) def createDataset(key: Random.Key, pathToBinaryFile: String): Iterator[Sample] = - val data = Tensor1.fromPy(Axis[Data], VType[Int])(Jax.jnp.asarray(np.memmap(pathToBinaryFile, dtype = np.uint16, mode = "r"))) - def sliceContextBlockAt(idx: Tensor0[Int]): Tensor1[Context, Int] = + val data = dimwit.python.PyBridge.liftPyTensor1(Axis[Data], VType[Int32])(Jax.jnp.asarray(np.memmap(pathToBinaryFile, dtype = np.uint16, mode = "r"))) + def sliceContextBlockAt(idx: Tensor0[Int32]): Tensor1[Context, Int32] = data .dynamicSlice(idx, contextLength) .relabelTo(Axis[Context]) @@ -329,6 +331,7 @@ case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Co val initParams = GPT2Params.init(Random.Key(42)) + import Tensor0.given val adam = Adam(learningRate = learningRate, b1 = beta1, b2 = beta2, epsilon = 1e-8f) val adamW = AdamW(adam, weightDecayFactor = 1e-1f) type AdamWState = adamW.State[GPT2Params] @@ -336,10 +339,10 @@ case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Co case class TrainingState( params: GPT2Params, adamWState: AdamWState, - loss: Tensor0[Float] + loss: Tensor0[Float32] ) - def batchLoss(input: Tensor2[Batch, Context, Int], labels: Tensor2[Batch, Context, Int])(params: GPT2Params): Tensor0[Float] = + def batchLoss(input: Tensor2[Batch, Context, Int32], labels: Tensor2[Batch, Context, Int32])(params: GPT2Params): Tensor0[Float32] = val model = GPT2(params) val logits = input.vmap(Axis[Batch])(model.logits) val lossPerSample = zipvmap(Axis[Batch])(labels, logits): (labels, logits) => @@ -349,8 +352,8 @@ case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Co lossPerSample.mean def gradientStep( - input: Tensor2[Batch, Context, Int], - labels: Tensor2[Batch, Context, Int], + input: Tensor2[Batch, Context, Int32], + labels: Tensor2[Batch, Context, Int32], state: TrainingState ): TrainingState = val lossBatch = batchLoss(input, labels) @@ -360,7 +363,7 @@ case class GPT2(params: GPT2Params) extends (Tensor1[Context, Int] => Tensor1[Co TrainingState(params = params, adamWState = adamWState, loss = loss) val jitStep = jitDonatingUnsafe(gradientStep) - def evaluate(input: Tensor2[Batch, Context, Int], labels: Tensor2[Batch, Context, Int], params: GPT2Params): Tensor0[Float] = + def evaluate(input: Tensor2[Batch, Context, Int32], labels: Tensor2[Batch, Context, Int32], params: GPT2Params): Tensor0[Float32] = batchLoss(input, labels)(params) // val evalF = jit(evaluate) val evalF = eagerCleanup(evaluate) diff --git a/examples/src/main/scala/complex/VariationalAutoencoder.scala b/examples/src/main/scala/complex/VariationalAutoencoder.scala index a5c53ca9..3f5995fd 100644 --- a/examples/src/main/scala/complex/VariationalAutoencoder.scala +++ b/examples/src/main/scala/complex/VariationalAutoencoder.scala @@ -3,9 +3,9 @@ package examples.complex.vae import examples.timed import dimwit.* +import dimwit.Conversions.given import dimwit.autodiff.* import dimwit.autodiff.FloatTree.* -import dimwit.Conversions.given import dimwit.stats.Normal import dimwit.random.Random import examples.dataset.MNISTLoader @@ -36,7 +36,7 @@ class Encoder(p: Encoder.Params): val meanLayer = LinearLayer(p.meanLayer) val logVarLayer = LinearLayer(p.logVarLayer) - def apply(v: Tensor1[Pixel, Float]): (Tensor1[Latent, Float], Tensor1[Latent, Float]) = + def apply(v: Tensor1[Pixel, Float32]): (Tensor1[Latent, Float32], Tensor1[Latent, Float32]) = val h1 = relu(layer1(v)) val h2 = relu(layer2(h1)) val mean = meanLayer(h2) @@ -57,7 +57,7 @@ class Decoder(p: Decoder.Params): val layer2 = LinearLayer(p.layer2) val outputLayer = LinearLayer(p.outputLayer) - def apply(v: Tensor1[Latent, Float]): Tensor1[ReconstructedPixel, Float] = + def apply(v: Tensor1[Latent, Float32]): Tensor1[ReconstructedPixel, Float32] = val h1 = relu(layer1(v)) val h2 = relu(layer2(h1)) sigmoid(outputLayer(h2)) @@ -69,7 +69,7 @@ object Decoder: outputLayer: LinearLayer.Params[DHidden2, ReconstructedPixel] ) -def reparametrize(mean: Tensor1[Latent, Float], logVar: Tensor1[Latent, Float], key: Random.Key): Tensor1[Latent, Float] = +def reparametrize(mean: Tensor1[Latent, Float32], logVar: Tensor1[Latent, Float32], key: Random.Key): Tensor1[Latent, Float32] = val std = (logVar *! 0.5f).exp Normal(mean, std).sample(key) @@ -78,13 +78,13 @@ case class VariationalAutoencoder(params: VariationalAutoencoder.Params): val encoder = Encoder(params.encoderParams) val decoder = Decoder(params.decoderParams) - def apply(pixels: Tensor1[Pixel, Float], key: Random.Key): (Tensor1[ReconstructedPixel, Float], Tensor1[Latent, Float], Tensor1[Latent, Float]) = + def apply(pixels: Tensor1[Pixel, Float32], key: Random.Key): (Tensor1[ReconstructedPixel, Float32], Tensor1[Latent, Float32], Tensor1[Latent, Float32]) = val (mean, logVar) = encoder(pixels) val latent = reparametrize(mean, logVar, key) val reconstructedPixels = decoder(latent) (reconstructedPixels, mean, logVar) - def loss(original: Tensor1[Pixel, Float], key: Random.Key): Tensor0[Float] = + def loss(original: Tensor1[Pixel, Float32], key: Random.Key): Tensor0[Float32] = val (reconstructedPixels, mean, logVar) = apply(original, key) val eps = 1e-5f val reconstructionLoss = -((original * (reconstructedPixels +! eps).log) + ((1f -! original) * (1f -! reconstructedPixels +! eps).log)).sum @@ -172,7 +172,7 @@ object VariationalAutoencoderExample: /* * Training */ - def batchLoss[S <: Sample: Label](key: Random.Key, trainData: Tensor3[S, Height, Width, Float])(params: Params): Tensor0[Float] = + def batchLoss[S <: Sample: Label](key: Random.Key, trainData: Tensor3[S, Height, Width, Float32])(params: Params): Tensor0[Float32] = val vae = VariationalAutoencoder(params) val batchSize = trainData.shape.extent(Axis[S]).size val keys = key.splitToTensor(Axis[S] -> batchSize) @@ -183,7 +183,7 @@ object VariationalAutoencoderExample: val batches = trainImages.chunk(Axis[TrainSample], numSamples / batchSize) val optimizer = GradientDescent(learningRate = Tensor0(learningRate)) - def trainBatch(trainKey: Random.Key, batch: Tensor3[TrainSample, Height, Width, Float], params: Params): Params = + def trainBatch(trainKey: Random.Key, batch: Tensor3[TrainSample, Height, Width, Float32], params: Params): Params = val grads = grad(batchLoss(trainKey, batch))(params) val (newParams, _) = optimizer.update(grads, params, ()) newParams @@ -200,7 +200,7 @@ object VariationalAutoencoderExample: val keysForEpochs = dataKey.split(numEpochs) - val initialParams = Params(encoderParams, decoderParams).map([T <: Tuple] => (n: Labels[T]) ?=> (t: Tensor[T, Float]) => t *! 0.1f) + val initialParams = Params(encoderParams, decoderParams).map([T <: Tuple] => (n: Labels[T]) ?=> (t: Tensor[T, Float32]) => t *! 0.1f) val trainedParams = (0 until numEpochs).foldLeft(initialParams): case (params, epoch) => @@ -214,7 +214,7 @@ object VariationalAutoencoderExample: /* * Evaluation */ - def plotImg[H, W](img2d: Tensor2[H, W, Float]): Unit = + def plotImg[H, W](img2d: Tensor2[H, W, Float32]): Unit = import me.shadaj.scalapy.py val plt = py.module("matplotlib.pyplot") plt.imshow(toPyTensor(img2d), cmap = "gray") diff --git a/examples/src/main/scala/dataset/MNISTLoader.scala b/examples/src/main/scala/dataset/MNISTLoader.scala index 1acf50d4..2fc0112e 100644 --- a/examples/src/main/scala/dataset/MNISTLoader.scala +++ b/examples/src/main/scala/dataset/MNISTLoader.scala @@ -18,7 +18,7 @@ object MNISTLoader: private val pythonLoader = py.eval("lambda b64, shape: __import__('jax').numpy.array(__import__('numpy').frombuffer(__import__('base64').b64decode(b64), dtype=__import__('numpy').uint8).reshape(shape).astype(__import__('numpy').int32))") - def loadImages[S <: Sample: Label](filename: String, maxImages: Option[Int] = None): Tensor3[S, Height, Width, Int] = + def loadImages[S <: Sample: Label](filename: String, maxImages: Option[Int] = None): Tensor3[S, Height, Width, UInt8] = val file = new RandomAccessFile(filename, "r") try val magic = file.readInt() @@ -37,16 +37,12 @@ object MNISTLoader: file.readFully(pixels) val shape = Shape(Axis[S] -> numImages, Axis[Height] -> rows, Axis[Width] -> cols) - - // MNIST pixels are unsigned bytes - // So we read them as Byte and interpret as UInt8 when creating the Tensor - given ExecutionType[Byte] = ExecutionTypeFor[Byte](DType.UInt8) - Tensor(shape).fromArray(pixels) + Tensor(shape, VType[UInt8]).fromArray(pixels) finally file.close() - def loadLabels[S <: Sample: Label](filename: String, maxLabels: Option[Int] = None): Tensor1[S, Int] = + def loadLabels[S <: Sample: Label](filename: String, maxLabels: Option[Int] = None): Tensor1[S, Int8] = val file = new RandomAccessFile(filename, "r") try val magic = file.readInt() @@ -66,20 +62,20 @@ object MNISTLoader: finally file.close() - private def createDataset[S <: Sample: Label](imagesFile: String, labelsFile: String, maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(S, Height, Width), Float], Tensor1[S, Int]]] = + private def createDataset[S <: Sample: Label](imagesFile: String, labelsFile: String, maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(S, Height, Width), Float32], Tensor1[S, Int8]]] = Try: val images = loadImages[S](imagesFile, maxSamples) val labels = loadLabels[S](labelsFile, maxSamples) require(images.shape(Axis[S]) == labels.shape(Axis[S]), s"Number of images and labels must match") - val imagesFloat = images.asFloat /! 255.0f + val imagesFloat = images.asFloat32 /! 255.0f (imagesFloat, labels) - def createTrainingDataset(dataDir: String = "data", maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(TrainSample, Height, Width), Float], Tensor1[TrainSample, Int]]] = + def createTrainingDataset(dataDir: String = "data", maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(TrainSample, Height, Width), Float32], Tensor1[TrainSample, Int8]]] = val imagesFile = s"$dataDir/train-images-idx3-ubyte" val labelsFile = s"$dataDir/train-labels-idx1-ubyte" createDataset[TrainSample](imagesFile, labelsFile, maxSamples) - def createTestDataset(dataDir: String = "data", maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(TestSample, Height, Width), Float], Tensor1[TestSample, Int]]] = + def createTestDataset(dataDir: String = "data", maxSamples: Option[Int] = None): Try[Tuple2[Tensor[(TestSample, Height, Width), Float32], Tensor1[TestSample, Int8]]] = val imagesFile = s"$dataDir/t10k-images-idx3-ubyte" val labelsFile = s"$dataDir/t10k-labels-idx1-ubyte" createDataset[TestSample](imagesFile, labelsFile, maxSamples) diff --git a/mdocs/AGENTS.md b/mdocs/AGENTS.md index df84146f..178764f6 100644 --- a/mdocs/AGENTS.md +++ b/mdocs/AGENTS.md @@ -137,18 +137,18 @@ val t2dNested = Tensor2(Axis[A], Axis[B]).fromArray( ```scala mdoc:silent // Scalar (0D) -val scalar: Tensor0[Float] = Tensor0(42.0f) +val scalar: Tensor0[Float32] = Tensor0(42.0f) // Vector (1D) -val vector: Tensor1[Feature, Float] = Tensor1(Axis[Feature]).fromArray(Array(1.0f, 2.0f, 3.0f)) +val vector: Tensor1[Feature, Float32] = Tensor1(Axis[Feature]).fromArray(Array(1.0f, 2.0f, 3.0f)) // Matrix (2D) -val matrix: Tensor2[Batch, Feature, Float] = Tensor2(Axis[Batch], Axis[Feature]).fromArray( +val matrix: Tensor2[Batch, Feature, Float32] = Tensor2(Axis[Batch], Axis[Feature]).fromArray( Array(Array(1.0f, 2.0f), Array(3.0f, 4.0f)) ) // 3D Tensor -val tensor3d: Tensor3[Batch, Feature, Hidden, Float] = +val tensor3d: Tensor3[Batch, Feature, Hidden, Float32] = Tensor(Shape3(Axis[Batch] -> 2, Axis[Feature] -> 3, Axis[Hidden] -> 4)).fill(0.0f) ``` @@ -219,20 +219,20 @@ val data = Tensor2(Axis[A], Axis[B]).fromArray( ) // Reduce to scalar -val totalSum: Tensor0[Float] = data.sum -val totalMean: Tensor0[Float] = data.mean -val totalMax: Tensor0[Float] = data.max -val totalMin: Tensor0[Float] = data.min -val totalStd: Tensor0[Float] = data.std +val totalSum: Tensor0[Float32] = data.sum +val totalMean: Tensor0[Float32] = data.mean +val totalMax: Tensor0[Float32] = data.max +val totalMin: Tensor0[Float32] = data.min +val totalStd: Tensor0[Float32] = data.std println(s"Sum: ${totalSum.item}") // 21.0 // Reduce along axis A (across rows) -val sumA: Tensor1[B, Float] = data.sum(Axis[A]) +val sumA: Tensor1[B, Float32] = data.sum(Axis[A]) println(s"Sum along A: ${sumA}") // [5.0, 7.0, 9.0] // Reduce along axis B (across columns) -val sumB: Tensor1[A, Float] = data.sum(Axis[B]) +val sumB: Tensor1[A, Float32] = data.sum(Axis[B]) println(s"Sum along B: ${sumB}") // [6.0, 15.0] // Mean along axes @@ -240,8 +240,8 @@ val meanA = data.mean(Axis[A]) // [2.5, 3.5, 4.5] val meanB = data.mean(Axis[B]) // [2.0, 5.0] // Argmax / Argmin (returns indices) -val argmaxB: Tensor1[A, Int] = data.argmax(Axis[B]) -val argminB: Tensor1[A, Int] = data.argmin(Axis[B]) +val argmaxB: Tensor1[A, Int32] = data.argmax(Axis[B]) +val argminB: Tensor1[A, Int32] = data.argmin(Axis[B]) ``` **Error: Reducing on non-existent axis** @@ -299,7 +299,7 @@ trait D derives Label // Dot product (vector · vector) val v1 = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f, 3.0f)) val v2 = Tensor1(Axis[A]).fromArray(Array(4.0f, 5.0f, 6.0f)) -val dotProduct: Tensor0[Float] = v1.dot(Axis[A])(v2) +val dotProduct: Tensor0[Float32] = v1.dot(Axis[A])(v2) println(s"Dot product: ${dotProduct.item}") // 32.0 // Matrix-vector multiplication @@ -338,7 +338,7 @@ val original = Tensor2(Axis[A], Axis[B]).fromArray( ) // Transpose -val transposed: Tensor2[B, A, Float] = original.transpose +val transposed: Tensor2[B, A, Float32] = original.transpose println(s"Original shape: ${original.shape}") println(s"Transposed shape: ${transposed.shape}") @@ -386,20 +386,20 @@ val data = Tensor2(Axis[Batch], Axis[Feature]).fromArray( ) // Normalize each sample (row) independently -def normalize(x: Tensor1[Feature, Float]): Tensor1[Feature, Float] = +def normalize(x: Tensor1[Feature, Float32]): Tensor1[Feature, Float32] = val mean = x.mean val std = x.std + Tensor0(1e-6f) // Avoid division by zero (x -! mean) /! std -val normalized: Tensor2[Batch, Feature, Float] = data.vmap(Axis[Batch])(normalize) +val normalized: Tensor2[Batch, Feature, Float32] = data.vmap(Axis[Batch])(normalize) println(s"Normalized data: ${normalized}") // Sum each row -val rowSums: Tensor1[Batch, Float] = data.vmap(Axis[Batch])(_.sum) +val rowSums: Tensor1[Batch, Float32] = data.vmap(Axis[Batch])(_.sum) println(s"Row sums: ${rowSums}") // [6.0, 15.0, 24.0] // vmap over columns (note: axis moves to front) -val colSums: Tensor1[Feature, Float] = data.vmap(Axis[Feature])(_.sum) +val colSums: Tensor1[Feature, Float32] = data.vmap(Axis[Feature])(_.sum) println(s"Column sums: ${colSums}") // [12.0, 15.0, 18.0] // Identity vmap doesn't change data, only axis order @@ -412,7 +412,7 @@ val reordered = data.vmap(Axis[Feature])(x => x) // Same as data.transpose ```scala mdoc:fail val t = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f))) // ERROR: Function expects Tensor2, but vmap provides Tensor1 -def wrongFunc(x: Tensor2[A, B, Float]): Tensor0[Float] = x.sum +def wrongFunc(x: Tensor2[A, B, Float32]): Tensor0[Float32] = x.sum val wrong = t.vmap(Axis[A])(wrongFunc) ``` @@ -428,10 +428,10 @@ val t1 = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(1.0f, 2.0f), Array(3.0f val t2 = Tensor2(Axis[A], Axis[B]).fromArray(Array(Array(10.0f, 20.0f), Array(30.0f, 40.0f))) // Compute L2 distance between corresponding rows -def l2Distance(v1: Tensor1[B, Float], v2: Tensor1[B, Float]): Tensor0[Float] = +def l2Distance(v1: Tensor1[B, Float32], v2: Tensor1[B, Float32]): Tensor0[Float32] = (v1 - v2).pow(Tensor0(2.0f)).sum.sqrt -val distances: Tensor1[A, Float] = zipvmap(Axis[A])(t1, t2)(l2Distance) +val distances: Tensor1[A, Float32] = zipvmap(Axis[A])(t1, t2)(l2Distance) println(s"L2 distances: ${distances}") // zipvmap with 4 tensors @@ -495,7 +495,7 @@ trait A derives Label import dimwit.autodiff.Autodiff // Scalar function: f(x) = x² -def f(x: Tensor0[Float]): Tensor0[Float] = x * x +def f(x: Tensor0[Float32]): Tensor0[Float32] = x * x val df = Autodiff.grad(f) val x = Tensor0(3.0f) @@ -503,7 +503,7 @@ val gradient = df(x) println(s"df/dx at x=3: ${gradient.value.item}") // 6.0 // Vector function: f(x) = sum(x²) -def g(x: Tensor1[A, Float]): Tensor0[Float] = (x * x).sum +def g(x: Tensor1[A, Float32]): Tensor0[Float32] = (x * x).sum val dg = Autodiff.grad(g) val xVec = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f, 3.0f)) @@ -514,16 +514,16 @@ println(s"dg/dx: ${vecGradient}") // [2.0, 4.0, 6.0] ### Higher-Order Derivatives ```scala mdoc:silent -def f2(x: Tensor0[Float]): Tensor0[Float] = x * x +def f2(x: Tensor0[Float32]): Tensor0[Float32] = x * x // First derivative val df2 = Autodiff.grad(f2) // Second derivative -val ddf2 = Autodiff.grad((x: Tensor0[Float]) => df2(x).value) +val ddf2 = Autodiff.grad((x: Tensor0[Float32]) => df2(x).value) // Third derivative -val dddf2 = Autodiff.grad((x: Tensor0[Float]) => ddf2(x).value) +val dddf2 = Autodiff.grad((x: Tensor0[Float32]) => ddf2(x).value) val x2 = Tensor0(3.0f) println(s"f''(3) = ${ddf2(x2).value.item}") // 2.0 @@ -534,7 +534,7 @@ println(s"f'''(3) = ${dddf2(x2).value.item}") // 0.0 ```scala mdoc:silent // f(x, y) = (x + 2y)² -def twoParam(x: Tensor1[A, Float], y: Tensor1[A, Float]): Tensor0[Float] = +def twoParam(x: Tensor1[A, Float32], y: Tensor1[A, Float32]): Tensor0[Float32] = ((x + (y *! Tensor0(2.0f))).pow(Tensor0(2.0f))).sum val dtwoParam = Autodiff.grad(twoParam) @@ -557,7 +557,7 @@ trait Batch derives Label trait Feature derives Label // Gradient works seamlessly with vmap -def batched(x: Tensor2[Batch, Feature, Float]): Tensor0[Float] = +def batched(x: Tensor2[Batch, Feature, Float32]): Tensor0[Float32] = x.vmap(Axis[Batch])(_.sum).sum val dBatched = Autodiff.grad(batched) @@ -579,17 +579,17 @@ trait Feature derives Label trait Hidden derives Label case class LinearParams( - weight: Tensor2[Feature, Hidden, Float], - bias: Tensor1[Hidden, Float] + weight: Tensor2[Feature, Hidden, Float32], + bias: Tensor1[Hidden, Float32] ) // Define a model -def linear(params: LinearParams)(input: Tensor1[Feature, Float]): Tensor1[Hidden, Float] = +def linear(params: LinearParams)(input: Tensor1[Feature, Float32]): Tensor1[Hidden, Float32] = val weighted = params.weight.transpose.dot(Axis[Feature])(input) weighted + params.bias // Define loss function -def loss(data: Tensor1[Feature, Float], target: Tensor1[Hidden, Float])(params: LinearParams): Tensor0[Float] = +def loss(data: Tensor1[Feature, Float32], target: Tensor1[Hidden, Float32])(params: LinearParams): Tensor0[Float32] = val prediction = linear(params)(data) (prediction - target).pow(Tensor0(2.0f)).sum @@ -611,7 +611,7 @@ println(s"Bias gradient shape: ${paramGradients.bias.shape}") ```scala mdoc:fail // ERROR: Cannot differentiate with respect to Int tensors -def intFunc(x: Tensor1[A, Int]): Tensor0[Int] = x.sum +def intFunc(x: Tensor1[A, Int32]): Tensor0[Int32] = x.sum val wrong = Autodiff.grad(intFunc) ``` @@ -624,7 +624,7 @@ import dimwit.autodiff.* trait A derives Label // Jacobian of f: R² -> R², f(x) = 2x -def linearMap(x: Tensor1[A, Float]): Tensor1[A, Float] = x *! Tensor0(2.0f) +def linearMap(x: Tensor1[A, Float32]): Tensor1[A, Float32] = x *! Tensor0(2.0f) val jacobian = Autodiff.jacobian(linearMap) val xJac = Tensor1(Axis[A]).fromArray(Array(1.0f, 1.0f)) @@ -651,11 +651,11 @@ trait Feature derives Label trait Batch derives Label // Define model parameters -case class SimpleModelParams(w: Tensor1[Feature, Float], b: Tensor0[Float]) +case class SimpleModelParams(w: Tensor1[Feature, Float32], b: Tensor0[Float32]) // Define loss function -def mse(data: Tensor2[Batch, Feature, Float], labels: Tensor1[Batch, Float]) - (params: SimpleModelParams): Tensor0[Float] = +def mse(data: Tensor2[Batch, Feature, Float32], labels: Tensor1[Batch, Float32]) + (params: SimpleModelParams): Tensor0[Float32] = val predictions = data.vmap(Axis[Batch]) { sample => sample.dot(Axis[Feature])(params.w) + params.b } @@ -720,11 +720,11 @@ val yData = Tensor1(Axis[Sample]).fromArray( ) // Model parameters -case class RegressionParams(slope: Tensor1[InputDim, Float], intercept: Tensor0[Float]) +case class RegressionParams(slope: Tensor1[InputDim, Float32], intercept: Tensor0[Float32]) // Loss function (MSE) -def regressionLoss(x: Tensor2[Sample, InputDim, Float], y: Tensor1[Sample, Float]) - (params: RegressionParams): Tensor0[Float] = +def regressionLoss(x: Tensor2[Sample, InputDim, Float32], y: Tensor1[Sample, Float32]) + (params: RegressionParams): Tensor0[Float32] = val predictions = x.vmap(Axis[Sample]) { xi => xi.dot(Axis[InputDim])(params.slope) + params.intercept } @@ -764,7 +764,7 @@ import dimwit.jax.Jit trait A derives Label // Define a complex function -def complexComputation(x: Tensor1[A, Float]): Tensor1[A, Float] = +def complexComputation(x: Tensor1[A, Float32]): Tensor1[A, Float32] = val y = x.exp val z = y.log val w = z.sin @@ -791,7 +791,7 @@ import dimwit.jitDonating import dimwit.jitDonatingUnsafe // jitDonating allows reusing input memory -def inPlaceOp(x: Tensor1[A, Float]): Tensor1[A, Float] = x *! Tensor0(2.0f) +def inPlaceOp(x: Tensor1[A, Float32]): Tensor1[A, Float32] = x *! Tensor0(2.0f) val (jitDonate, jitF, jitReclaim) = jitDonating(inPlaceOp) val inputDonate = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) val donated = jitDonate(inputDonate) @@ -852,13 +852,16 @@ This section demonstrates **compile-time** and **runtime** errors to help coding ### Type Constraint Violations -```scala mdoc:reset:fail +```scala mdoc:reset import dimwit.* trait A derives Label trait B derives Label trait C derives Label trait D derives Label +``` + +```scala mdoc:fail // ERROR: Cannot perform floating-point operations on Int tensors val intTensor = Tensor1(Axis[A]).fromArray(Array(1, 2, 3)) val wrong = intTensor.exp // exp requires IsFloating constraint @@ -899,7 +902,7 @@ val wrong = t + 10.0f // Should use +! for scalar broadcast val t1 = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) val t2 = Tensor1(Axis[A]).fromArray(Array(3.0f, 4.0f)) // This works but is semantically wrong (use + instead) -// val result = t1 +! t2 // Compiles but misleading +val wrong = t1 +! t2 ``` ### Axis Errors @@ -920,13 +923,13 @@ val wrong = t.vmap(Axis[C])(_.sum) // Axis[C] doesn't exist ```scala mdoc:fail // ERROR: Cannot differentiate Integer functions -def intFunc(x: Tensor0[Int]): Tensor0[Int] = x + x +def intFunc(x: Tensor0[Int32]): Tensor0[Int32] = x + x val wrong = Autodiff.grad(intFunc) ``` ```scala mdoc:fail // ERROR: Function doesn't return scalar for grad -def nonScalar(x: Tensor1[A, Float]): Tensor1[A, Float] = x * x +def nonScalar(x: Tensor1[A, Float32]): Tensor1[A, Float32] = x * x val wrong = Autodiff.grad(nonScalar) // Use jacobian instead ``` @@ -964,7 +967,7 @@ val embeddings = Tensor(Shape3(Axis[Batch] -> 8, Axis[SeqLen] -> 128, Axis[Embed trait Feature derives Label // GOOD: Clear type signatures -def process(input: Tensor2[Batch, Feature, Float]): Tensor1[Batch, Float] = +def process(input: Tensor2[Batch, Feature, Float32]): Tensor1[Batch, Float32] = input.vmap(Axis[Batch])(_.sum) // AVOID: Opaque Tensor types without explicit parameters @@ -979,9 +982,9 @@ trait Output derives Label // GOOD: Structured parameters with TensorTree case class ModelParams( - encoder: Tensor2[InputDim, Hidden, Float], - decoder: Tensor2[Hidden, Output, Float], - bias: Tensor1[Output, Float] + encoder: Tensor2[InputDim, Hidden, Float32], + decoder: Tensor2[Hidden, Output, Float32], + bias: Tensor1[Output, Float32] ) // AVOID: Tuples or loose parameters @@ -1029,7 +1032,7 @@ import dimwit.jax.Jit.jit trait Input derives Label -val simpleFunc = (x: Tensor1[Input, Float]) => x *! Tensor0(2.0f) +val simpleFunc = (x: Tensor1[Input, Float32]) => x *! Tensor0(2.0f) // GOOD: JIT for repeated calls val jitFunc = jit(simpleFunc) diff --git a/mdocs/README.md b/mdocs/README.md index a702e019..a48e8a1b 100644 --- a/mdocs/README.md +++ b/mdocs/README.md @@ -45,11 +45,11 @@ val t = Tensor( ) // Function to normalize a single feature vector -def normalize(x: Tensor1[Feature, Float]) : Tensor1[Feature, Float] = +def normalize(x: Tensor1[Feature, Float32]) : Tensor1[Feature, Float32] = (x -! x.mean) /! x.std // Apply the normalization function across the Batch dimension -val normalized: Tensor2[Batch, Feature, Float] = +val normalized: Tensor2[Batch, Feature, Float32] = t.vmap(Axis[Batch])(normalize) ``` diff --git a/mdocs/docs/quickstart.md b/mdocs/docs/quickstart.md index 5ec9a6e0..2d248f9d 100644 --- a/mdocs/docs/quickstart.md +++ b/mdocs/docs/quickstart.md @@ -19,19 +19,19 @@ trait Batch derives Label trait Feature derives Label // parameters are explicitly defined and usually bundled in a case class -case class Params(w: Tensor1[Feature, Float], b: Tensor0[Float]) derives TensorTree +case class Params(w: Tensor1[Feature, Float32], b: Tensor0[Float32]) derives TensorTree // the model as a function of data and parametesrs -def model(x: Tensor2[Batch, Feature, Float], y: Tensor1[Batch, Float])(params: Params): Tensor1[Batch, Float] = +def model(x: Tensor2[Batch, Feature, Float32], y: Tensor1[Batch, Float32])(params: Params): Tensor1[Batch, Float32] = x.dot(Axis[Feature])(params.w) +! params.b // the loss function as a function of data and parameters -def loss(x: Tensor2[Batch, Feature, Float], y: Tensor1[Batch, Float])(params: Params): Tensor0[Float] = +def loss(x: Tensor2[Batch, Feature, Float32], y: Tensor1[Batch, Float32])(params: Params): Tensor0[Float32] = val pred = model(x, y)(params) (pred - y).pow(Tensor0(2.0f)).mean // the training loop, which produces an iterator of parameters -def fit(x: Tensor2[Batch, Feature, Float], y: Tensor1[Batch, Float]): Iterator[Params] = +def fit(x: Tensor2[Batch, Feature, Float32], y: Tensor1[Batch, Float32]): Iterator[Params] = // initialize parameters val p0 = Params( @@ -146,7 +146,7 @@ Let's inspect the tensor more closely: ```scala mdoc tensor ``` -We see that the full type of the tensor is `Tensor[(Batch, Feature), Float]`, which indicates that the tensor has two axes, `Batch` and `Feature`, and that the data type of the tensor is `Float`. As this type is rather bulky to write, we can also use the convenient type aliases `Tensor2`, `Tensor1` and `Tensor0` for tensors of rank 2, rank 1 and rank 0 respectively. In this case, the type of the tensor could also be written as `Tensor2[Batch, Feature, Float]`. +We see that the full type of the tensor is `Tensor[(Batch, Feature), Float32]`, which indicates that the tensor has two axes, `Batch` and `Feature`, and that the data type of the tensor is `Float`. As this type is rather bulky to write, we can also use the convenient type aliases `Tensor2`, `Tensor1` and `Tensor0` for tensors of rank 2, rank 1 and rank 0 respectively. In this case, the type of the tensor could also be written as `Tensor2[Batch, Feature, Float32]`. In addition to the type aliases, we also have convenient factory methods for tensors of rank 0 to 2. These allow us to create tensors without having to explicitly create the shape. @@ -179,8 +179,8 @@ val matrix = Tensor2(Axis[Feature], Axis[Batch]).fromArray( ##### A note on type annotations for tensors -In DimWit, the type of a tensor is represented at the type level as `Tensor[ShapeTuple, VType]`, where `ShapeTuple` is a tuple of the labels of the axes and `DataType` is the type of the data. Hence a tensor with shape `Shape(Axis[Batch] -> 3, Axis[Feature] -> 2)` and data type `Float` has the type `Tensor[(Batch, Feature), Float]`. To make type annotations more convenient, we have the type aliases `Tensor0`, `Tensor1` and `Tensor2`, etc. to -refer to Tensors of a specific rank. For example, a `Tensor[(Batch, Feature), Float]` can be referred to as `Tensor2[Batch, Feature, Float]`, a Tensor `Tensor[Tuple1[Batch], Float]` can be referred to as `Tensor1[Batch, Float]` and a `Tensor[EmptyTuple, Float]` can be referred to as `Tensor0[Float]`. +In DimWit, the type of a tensor is represented at the type level as `Tensor[ShapeTuple, VType]`, where `ShapeTuple` is a tuple of the labels of the axes and `DataType` is the type of the data. Hence a tensor with shape `Shape(Axis[Batch] -> 3, Axis[Feature] -> 2)` and data type `Float` has the type `Tensor[(Batch, Feature), Float32]`. To make type annotations more convenient, we have the type aliases `Tensor0`, `Tensor1` and `Tensor2`, etc. to +refer to Tensors of a specific rank. For example, a `Tensor[(Batch, Feature), Float32]` can be referred to as `Tensor2[Batch, Feature, Float32]`, a Tensor `Tensor[Tuple1[Batch], Float32]` can be referred to as `Tensor1[Batch, Float32]` and a `Tensor[EmptyTuple, Float32]` can be referred to as `Tensor0[Float32]`. ### Arithmetic Operations on Tensors and broadcasting @@ -230,8 +230,8 @@ axis to sum over using the labels of the axes, which ensures that we are summing The resulting tensor has the correct shape, which is inferred from the labels of the axes: ```scala mdoc -val sumOverB : Tensor1[A, Float] = tensor1.sum(Axis[B]) -val sumOverA : Tensor1[B, Float] = tensor1.sum(Axis[A]) +val sumOverB : Tensor1[A, Float32] = tensor1.sum(Axis[B]) +val sumOverA : Tensor1[B, Float32] = tensor1.sum(Axis[A]) ``` ### Transforming the shape of tensors @@ -253,12 +253,12 @@ val tensor = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2, Axis[C] -> 4)).fill(1.0f) #### Flattening and unflattening axes The first operation we will discuss is `flatten` which flattens part of the tensor into a single axis. Invoked without arguments, `flatten` will flatten all axes into a single axis, resulting in a Tensor1. ```scala mdoc:silent -val flattened : Tensor1[A |*| B |*| C, Float] = tensor.flatten +val flattened : Tensor1[A |*| B |*| C, Float32] = tensor.flatten ``` Note that the resulting axis has a label that is a combination of the labels of the original axes. When flattening, we can also specify which axes to flatten, and the resulting axis will have a label that is a combination of the labels of the flattened axes. For example, we can flatten only the last two axes as follows: ```scala mdoc:silent - val partiallyFlattened: Tensor2[A, B |*| C, Float] = tensor.flatten((Axis[B], Axis[C])) + val partiallyFlattened: Tensor2[A, B |*| C, Float32] = tensor.flatten((Axis[B], Axis[C])) ``` The counterpart of flatten is `unflatten`, which takes an axis that was previously flattened and restores the original axes. Since in the process of flattening we lost the information about the original shape, @@ -282,7 +282,7 @@ Given two tensors with the same shape except for one axis, we can concatenate th val part1 = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2)).fill(1.0f) val part2 = Tensor(Shape(Axis[A] -> 7, Axis[B] -> 2)).fill(1.0f) - val concatenated: Tensor[(A, B), Float] = concatenate(Seq(part1, part2), Axis[A]) + val concatenated: Tensor[(A, B), Float32] = concatenate(Seq(part1, part2), Axis[A]) ``` The concatenated tensor can be split back into the original tensors using the split method @@ -299,13 +299,13 @@ it returns a single tensor that is a slice of the original tensor. For example, at index 1 along axis A as follows: ```scala mdoc:silent - val sliced1 : Tensor1[B, Float]= concatenated.slice(Axis[A].at(1)) + val sliced1 : Tensor1[B, Float32]= concatenated.slice(Axis[A].at(1)) ``` As for split, we can also provide multiple a tuple (or sequence) of indices, which will then select all slices in the tuple. ```scala mdoc:silent - val slicedMultiple : Tensor2[A, B, Float] = concatenated.slice(Axis[A].at((0, 2))) + val slicedMultiple : Tensor2[A, B, Float32] = concatenated.slice(Axis[A].at((0, 2))) ``` #### Squeezing, Expanding and transposing axes @@ -315,17 +315,17 @@ method. ```scala mdoc:silent val squeezableTensor = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 1, Axis[C] -> 4)).fill(1.0f) -val squeezedTensor : Tensor[(A, C), Float] = squeezableTensor.squeeze(Axis[B]) +val squeezedTensor : Tensor[(A, C), Float32] = squeezableTensor.squeeze(Axis[B]) ``` Similarly, we can add a new axis with extent 1 using the method `appendAxis`: ```scala mdoc:silent -val appendedTensor : Tensor[(A, C, B), Float] = squeezedTensor.appendAxis(Axis[B]) +val appendedTensor : Tensor[(A, C, B), Float32] = squeezedTensor.appendAxis(Axis[B]) ``` Finally, we can permute the axes of a tensor using the `transpose` method. The order of the axes is the order of the labels in the tuple that we pass as an argument to the method. For example, we can reorder the above tensor to have the order of axes `A, B, C` as follows: ```scala mdoc:silent -val restoredTensor : Tensor[(A, B, C), Float] = appendedTensor.transpose((Axis[A], Axis[B], Axis[C])) +val restoredTensor : Tensor[(A, B, C), Float32] = appendedTensor.transpose((Axis[A], Axis[B], Axis[C])) ``` ### Mapping over axes @@ -348,13 +348,13 @@ val tensor = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2, Axis[C] -> 4)).fill(1.0f) The simplest method is `vapply`. `vapply` applies a function from a `Tensor1` to a `Tensor1` to each slice of the tensor along the specified axis. For example, we can apply the function that multiplies each element by 2 to each slice along axis A as follows: ```scala mdoc:silent -val doubled : Tensor3[A, B, C, Float] = tensor.vapply(Axis[A])((slice : Tensor1[A, Float]) => slice *! Tensor0(2.0f)) +val doubled : Tensor3[A, B, C, Float32] = tensor.vapply(Axis[A])((slice : Tensor1[A, Float32]) => slice *! Tensor0(2.0f)) ``` Similar to `vapply`is `vreduce`. `vreduce` applies a function that reduces a `Tensor1` to a `Tensor0` to each slice of the tensor along the specified axis. It effectively reduces the specified axis to a scalar. ```scala mdoc:silent -val summedA : Tensor2[B, C, Float] = tensor.vreduce(Axis[A])((slice : Tensor1[A, Float]) => slice.sum) +val summedA : Tensor2[B, C, Float32] = tensor.vreduce(Axis[A])((slice : Tensor1[A, Float32]) => slice.sum) ``` A more general method is `vmap`, which applies a function to the slice of the tensor along the specified axis. The function can return a tensor of any shape, not just a `Tensor1`. The resulting tensor will have the same shape as the original tensor, except that the specified axis will be replaced by the shape of the output of the function. The following example takes a Tensor2 as input and computes the @@ -362,7 +362,7 @@ mean of each slice along axis C ```scala mdoc:silent - val res : Tensor[(A, B), Float] = tensor.vmap(Axis[A])((slice : Tensor2[B, C, Float]) => slice.mean(Axis[C])) + val res : Tensor[(A, B), Float32] = tensor.vmap(Axis[A])((slice : Tensor2[B, C, Float32]) => slice.mean(Axis[C])) ``` `zipmap` is a variant of `vmap` that applies a function to the slices of multiple tensors along the specified axis. The function takes as input a tuple of slices, one from each tensor, and, as `vmap` returns a tensor of any shape. The resulting tensor will have the same shape as the original tensors, except that the specified axis will be replaced by the shape of the output of the function. For example, we can use `zipmap` to add two tensors along axis A as follows: @@ -370,7 +370,7 @@ mean of each slice along axis C val t1 = Tensor(Shape(Axis[A] -> 3, Axis[B] -> 2, Axis[C] -> 4)).fill(1.0f) val t2 = Tensor(Shape(Axis[A] -> 3, Axis[C] -> 3)).fill(2.0f) -val sumAlongA : Tensor1[A, Float] = zipvmap(Axis[A])(t1, t2)((s1: Tensor2[B, C, Float], s2: Tensor1[C, Float]) => s1.sum + s2.sum) +val sumAlongA : Tensor1[A, Float32] = zipvmap(Axis[A])(t1, t2)((s1: Tensor2[B, C, Float32], s2: Tensor1[C, Float32]) => s1.sum + s2.sum) ``` @@ -394,20 +394,20 @@ import Autodiff.grad Let's take a simple quadratic function as an example. ```scala mdoc:silent -def f(x: Tensor1[A, Float]): Tensor0[Float] = x.dot(Axis[A])(x) +def f(x: Tensor1[A, Float32]): Tensor0[Float32] = x.dot(Axis[A])(x) ``` To compute the gradient of this function with respect to its input, we can use the `grad` method as follows: ```scala mdoc:silent val x = Tensor1(Axis[A]).fromArray(Array(1.0f, 2.0f)) -val gradient : Tensor1[A, Float] => Grad[Tensor1[A, Float]] = grad(f) +val gradient : Tensor1[A, Float32] => Grad[Tensor1[A, Float32]] = grad(f) ``` Note that the result of `grad` is a function that takes as input a tensor and returns the gradient of the function with respect to that tensor. The gradient is a normal tensor, but wrapped in a `Grad` object, to make sure that we don't accidentally use it in a computation without realizing that it is a gradient. To get the actual tensor from the `Grad` object, we can use the `value` method as follows: ```scala mdoc:silent -val gradValue : Tensor1[A, Float] = gradient(x).value +val gradValue : Tensor1[A, Float32] = gradient(x).value ``` #### Tensor trees and gradients of multiple parameters @@ -426,13 +426,13 @@ trait Batch derives Label ``` ```scala mdoc:silent -case class Params(w: Tensor1[Feature, Float], b: Tensor0[Float]) derives TensorTree +case class Params(w: Tensor1[Feature, Float32], b: Tensor0[Float32]) derives TensorTree ``` To compute the gradient of a function that takes as input a tensor tree, we can use the same `grad` method, as long as the input type of the function is a tensor tree. The resulting gradient will then be a tensor tree of the same shape as the input tensor tree, as illustrated in the following example: ```scala mdoc:silent -def f(params: Params): Tensor0[Float] = params.w.dot(Axis[Feature])(params.w) + params.b.pow(Tensor0(2.0f)) +def f(params: Params): Tensor0[Float32] = params.w.dot(Axis[Feature])(params.w) + params.b.pow(Tensor0(2.0f)) val params = Params( w = Tensor1(Axis[Feature]).fromArray(Array(1.0f, 2.0f)), @@ -480,7 +480,7 @@ val keys = key.split(5) Often we need to create a sequence of random numbers. In this case, we can use the `splitvmap` method, which splits a key into a tensor of new keys and applies a function to each of the new keys. For example, we can create a vector of random numbers drawn from the normal distribution as follows: ```scala mdoc:silent -val sampleVec: Tensor1[A, Float] = key.splitvmap(Axis[A] -> 3)((k: Key) => normalDist.sample(k)) +val sampleVec: Tensor1[A, Float32] = key.splitvmap(Axis[A] -> 3)((k: Key) => normalDist.sample(k)) ``` A more flexible, but less performant way to create a tensor of keys and to use it together with the `vmap` or `zipvmap` method. For example, we can create a tensor of keys and use it to create a tensor of random numbers as follows: diff --git a/nn/src/main/scala/nn/Conv2DLayer.scala b/nn/src/main/scala/nn/Conv2DLayer.scala index bd2ec315..4e12b541 100644 --- a/nn/src/main/scala/nn/Conv2DLayer.scala +++ b/nn/src/main/scala/nn/Conv2DLayer.scala @@ -6,21 +6,21 @@ import dimwit.stats.Normal object Conv2DLayer: - case class Params[S1, S2, InChannel, OutChannel]( - kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, Float] + case class Params[S1, S2, InChannel, OutChannel, V]( + kernel: Tensor[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple, V] ) object Params: - given [S1: Label, S2: Label, InChannel: Label, OutChannel: Label]: TensorTree[Params[S1, S2, InChannel, OutChannel]] = TensorTree.derived + given [S1: Label, S2: Label, InChannel: Label, OutChannel: Label, V]: TensorTree[Params[S1, S2, InChannel, OutChannel, V]] = TensorTree.derived - def apply[S1: Label, S2: Label, InChannel: Label, OutChannel: Label](paramKey: Key)(kernelShape: Shape[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple])(using executionType: ExecutionType[Float]): Params[S1, S2, InChannel, OutChannel] = - Params(kernel = Normal.standardNormal(kernelShape).sample(paramKey)) + def apply[S1: Label, S2: Label, InChannel: Label, OutChannel: Label, V: IsFloating](paramKey: Key)(kernelShape: Shape[S1 *: S2 *: InChannel *: OutChannel *: EmptyTuple]): Params[S1, S2, InChannel, OutChannel, V] = + Params(kernel = Normal.standardNormal(kernelShape).sample(paramKey).asFloat(VType[V])) -case class Conv2DLayer[S1: Label, S2: Label, InChannel: Label, OutChannel: Label]( - params: Conv2DLayer.Params[S1, S2, InChannel, OutChannel], +case class Conv2DLayer[S1: Label, S2: Label, InChannel: Label, OutChannel: Label, V: IsFloating]( + params: Conv2DLayer.Params[S1, S2, InChannel, OutChannel, V], stride: Stride2[S1, S2] | Int = 1, padding: Padding = Padding.SAME ): - def apply(x: Tensor[S1 *: S2 *: InChannel *: EmptyTuple, Float]): Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, Float] = + def apply(x: Tensor[S1 *: S2 *: InChannel *: EmptyTuple, V]): Tensor[S1 *: S2 *: OutChannel *: EmptyTuple, V] = x.conv2d(params.kernel, stride, padding) diff --git a/nn/src/main/scala/nn/GradientOptimizer.scala b/nn/src/main/scala/nn/GradientOptimizer.scala index 5237e1b2..d6a348a7 100644 --- a/nn/src/main/scala/nn/GradientOptimizer.scala +++ b/nn/src/main/scala/nn/GradientOptimizer.scala @@ -1,7 +1,6 @@ package nn import dimwit.* -import dimwit.Conversions.given import dimwit.autodiff.FloatTree.ops.* import dimwit.autodiff.FloatTree.* import dimwit.autodiff.* @@ -31,41 +30,40 @@ trait GradientOptimizer: type State[_] // Core API - def init[Params: TensorTree: FloatTree](params: Params): State[Params] - def update[Params: TensorTree: FloatTree](gradients: Grad[Params], params: Params, state: State[Params]): (Params, State[Params]) + def init[Params: TensorTree: FloatTreeFor[Float32]](params: Params): State[Params] + def update[Params: TensorTree: FloatTreeFor[Float32]](gradients: Grad[Params], params: Params, state: State[Params]): (Params, State[Params]) // Convenience: iterator with fixed gradient function - def iterateWithState[Params: TensorTree: FloatTree](init: Params)(df: Params => Grad[Params]): Iterator[(Params, State[Params])] = + def iterateWithState[Params: TensorTree: FloatTreeFor[Float32]](init: Params)(df: Params => Grad[Params]): Iterator[(Params, State[Params])] = Iterator.iterate((init, this.init(init))): (params, state) => val grads = df(params) update(grads, params, state) - def iterate[Params: TensorTree: FloatTree](init: Params)(df: Params => Grad[Params]): Iterator[Params] = + def iterate[Params: TensorTree: FloatTreeFor[Float32]](init: Params)(df: Params => Grad[Params]): Iterator[Params] = iterateWithState(init)(df).map(_._1) -case class GradientDescent(learningRate: Tensor0[Float]) extends GradientOptimizer: - import dimwit.Conversions.given +case class GradientDescent(learningRate: Tensor0[Float32]) extends GradientOptimizer: type State[P] = Unit // Stateless optimizer - def init[Params: TensorTree: FloatTree](params: Params): Unit = () + def init[Params: TensorTree: FloatTreeFor[Float32]](params: Params): Unit = () - def update[Params: TensorTree: FloatTree](gradients: Grad[Params], params: Params, state: Unit): (Params, Unit) = + def update[Params: TensorTree: FloatTreeFor[Float32]](gradients: Grad[Params], params: Params, state: Unit): (Params, Unit) = val newParams = params -- gradients.value.scale(learningRate) (newParams, ()) -case class Lion(learningRate: Tensor0[Float], weightDecay: Tensor0[Float] = Tensor0(0.0f), beta1: Tensor0[Float] = Tensor0(0.9f), beta2: Tensor0[Float] = Tensor0(0.99f)) extends GradientOptimizer: +case class Lion(learningRate: Tensor0[Float32], weightDecay: Tensor0[Float32] = Tensor0(0.0f), beta1: Tensor0[Float32] = Tensor0(0.9f), beta2: Tensor0[Float32] = Tensor0(0.99f)) extends GradientOptimizer: type State[P] = P // momentum state has same structure as params - def init[Params: TensorTree: FloatTree](params: Params): Params = + def init[Params: TensorTree: FloatTreeFor[Float32]](params: Params): Params = params.map([T <: Tuple] => (n: Labels[T]) ?=> - (t: Tensor[T, Float]) => + (t: Tensor[T, Float32]) => Tensor(t.shape).fill(0f) ) - def update[Params: TensorTree: FloatTree](gradients: Grad[Params], params: Params, momentums: Params): (Params, Params) = + def update[Params: TensorTree: FloatTreeFor[Float32]](gradients: Grad[Params], params: Params, momentums: Params): (Params, Params) = // the direction (1 or -1) // is determined by the sign of the momentum + gradient val updateDirection = (momentums **! beta1 ++ gradients.value **! (1f - beta1)).sign @@ -78,8 +76,8 @@ case class Lion(learningRate: Tensor0[Float], weightDecay: Tensor0[Float] = Tens case class AdamState[P]( momentums: P, // momentums velocities: P, // velocities - b1: Tensor0[Float], // decay rate for momentums mᵗ - b2: Tensor0[Float] // decay rate for velocities vᵗ + b1: Tensor0[Float32], // decay rate for momentums mᵗ + b2: Tensor0[Float32] // decay rate for velocities vᵗ ) /** Implements the Adam optimization algorithm. @@ -87,10 +85,10 @@ case class AdamState[P]( * @see [[https://arxiv.org/abs/1412.6980 Adam: A Method for Stochastic Optimization]] */ case class Adam( - learningRate: Tensor0[Float], // step size (learning rate) - b1: Tensor0[Float] = Tensor0(0.9f), // decay rate for momentums mᵗ - b2: Tensor0[Float] = Tensor0(0.999f), // decay rate for velocities vᵗ - epsilon: Tensor0[Float] = Tensor0(1e-8f) // small constant to prevent division by zero + learningRate: Tensor0[Float32], // step size (learning rate) + b1: Tensor0[Float32] = Tensor0(0.9f), // decay rate for momentums mᵗ + b2: Tensor0[Float32] = Tensor0(0.999f), // decay rate for velocities vᵗ + epsilon: Tensor0[Float32] = Tensor0(1e-8f) // small constant to prevent division by zero ) extends GradientOptimizer: private val β1 = b1 @@ -98,11 +96,11 @@ case class Adam( type State[P] = AdamState[P] - def init[Params: TensorTree: FloatTree](params: Params): State[Params] = + def init[Params: TensorTree: FloatTreeFor[Float32]](params: Params): State[Params] = def zeros = params.fillCopy(0f) AdamState(zeros, zeros, b1 = Tensor0(1f), b2 = Tensor0(1f)) - def update[Params: TensorTree: FloatTree]( + def update[Params: TensorTree: FloatTreeFor[Float32]]( gradients: Grad[Params], params: Params, state: State[Params] @@ -144,14 +142,14 @@ case class Adam( */ case class AdamW( val adam: Adam, - val weightDecayFactor: Tensor0[Float] + val weightDecayFactor: Tensor0[Float32] ) extends GradientOptimizer: type State[P] = adam.State[P] - def init[Params: TensorTree: FloatTree](params: Params): State[Params] = adam.init(params) + def init[Params: TensorTree: FloatTreeFor[Float32]](params: Params): State[Params] = adam.init(params) - def update[Params: TensorTree: FloatTree]( + def update[Params: TensorTree: FloatTreeFor[Float32]]( gradients: Grad[Params], params: Params, state: State[Params] diff --git a/nn/src/main/scala/nn/LinearLayer.scala b/nn/src/main/scala/nn/LinearLayer.scala index 0aaa1a7c..8934f997 100644 --- a/nn/src/main/scala/nn/LinearLayer.scala +++ b/nn/src/main/scala/nn/LinearLayer.scala @@ -4,12 +4,11 @@ import dimwit.* import dimwit.random.Random import dimwit.random.Random.Key import dimwit.tensor.VType -import dimwit.tensor.ExecutionType import dimwit.stats.Normal object LinearLayer: - case class Params[In, Out](weight: Tensor2[In, Out, Float], bias: Tensor1[Out, Float]) + case class Params[In, Out](weight: Tensor2[In, Out, Float32], bias: Tensor1[Out, Float32]) object Params: given [I: Label, O: Label]: TensorTree[Params[I, O]] = TensorTree.derived @@ -17,35 +16,31 @@ object LinearLayer: def apply[In: Label, Out: Label](paramKey: Key)( inputDim: AxisExtent[In], outputDim: AxisExtent[Out] - )(using - executionType: ExecutionType[Float] ): Params[In, Out] = Params( weight = Normal.standardNormal(Shape(inputDim, outputDim)).sample(paramKey), bias = Tensor(Shape(outputDim)).fill(0.0f) ) -case class LinearLayer[In: Label, Out: Label](params: LinearLayer.Params[In, Out]) extends Function[Tensor1[In, Float], Tensor1[Out, Float]]: - override def apply(x: Tensor1[In, Float]): Tensor1[Out, Float] = +case class LinearLayer[In: Label, Out: Label](params: LinearLayer.Params[In, Out]) extends Function[Tensor1[In, Float32], Tensor1[Out, Float32]]: + override def apply(x: Tensor1[In, Float32]): Tensor1[Out, Float32] = import params.{weight, bias} x.dot(Axis[In])(weight) + bias object LinearMap: - case class Params[In](weight: Tensor1[In, Float], bias: Tensor0[Float]) + case class Params[In](weight: Tensor1[In, Float32], bias: Tensor0[Float32]) object Params: given [In: Label]: TensorTree[Params[In]] = TensorTree.derived - def apply[In: Label](paramKey: Key)(inputDim: AxisExtent[In])(using - executionType: ExecutionType[Float] - ): Params[In] = + def apply[In: Label](paramKey: Key)(inputDim: AxisExtent[In]): Params[In] = Params( weight = Normal.standardNormal(Shape(inputDim)).sample(paramKey), bias = Tensor0(0.0f) ) -case class LinearMap[In: Label](params: LinearMap.Params[In]) extends Function[Tensor1[In, Float], Tensor0[Float]]: - override def apply(x: Tensor1[In, Float]): Tensor0[Float] = +case class LinearMap[In: Label](params: LinearMap.Params[In]) extends Function[Tensor1[In, Float32], Tensor0[Float32]]: + override def apply(x: Tensor1[In, Float32]): Tensor0[Float32] = import params.{weight, bias} x.dot(Axis[In])(weight) + bias diff --git a/nn/src/main/scala/nn/Loss.scala b/nn/src/main/scala/nn/Loss.scala index 872ace73..f24603ae 100644 --- a/nn/src/main/scala/nn/Loss.scala +++ b/nn/src/main/scala/nn/Loss.scala @@ -6,15 +6,15 @@ import nn.ActivationFunctions.softmax object Loss: // TODO move this to a more general utils place? - private def logsumexp[L: Label](logits: Tensor1[L, Float]): Tensor0[Float] = + private def logsumexp[L: Label](logits: Tensor1[L, Float32]): Tensor0[Float32] = val maxLogit = logits.max(Axis[L]) val logSumShifted = (logits -! maxLogit).exp.sum.log maxLogit + logSumShifted def crossEntropy[L: Label]( - logits: Tensor1[L, Float], - label: Tensor0[Int] - ): Tensor0[Float] = + logits: Tensor1[L, Float32], + label: Tensor0[Int32] + ): Tensor0[Float32] = val targetLogit = logits.slice(Axis[L].at(label)) val logNormalizer = logsumexp(logits) logNormalizer - targetLogit diff --git a/nn/src/main/scala/nn/TransposeConv2DLayer.scala b/nn/src/main/scala/nn/TransposeConv2DLayer.scala index 31e7a136..096f8067 100644 --- a/nn/src/main/scala/nn/TransposeConv2DLayer.scala +++ b/nn/src/main/scala/nn/TransposeConv2DLayer.scala @@ -7,7 +7,7 @@ import dimwit.stats.Normal object TransposeConvLayer: case class Params[S1, S2, InChannels, OutChannels]( - kernel: Tensor[S1 *: S2 *: InChannels *: OutChannels *: EmptyTuple, Float] + kernel: Tensor[S1 *: S2 *: InChannels *: OutChannels *: EmptyTuple, Float32] ) object Params: @@ -22,8 +22,6 @@ object TransposeConvLayer: */ def apply[S1: Label, S2: Label, InChannels: Label, OutChannels: Label](paramKey: Key)( kernelShape: Shape[S1 *: S2 *: InChannels *: OutChannels *: EmptyTuple] - )(using - executionType: ExecutionType[Float] ): Params[S1, S2, InChannels, OutChannels] = Params( kernel = Normal.standardNormal(kernelShape).sample(paramKey) @@ -40,5 +38,5 @@ case class TransposeConvLayer[S1: Label, S2: Label, InChannels: Label, OutChanne * * Input: (Spatial..., OutChannels) Output: (Spatial..., InChannels) */ - def apply(x: Tensor[S1 *: S2 *: OutChannels *: EmptyTuple, Float]): Tensor[S1 *: S2 *: InChannels *: EmptyTuple, Float] = + def apply(x: Tensor[S1 *: S2 *: OutChannels *: EmptyTuple, Float32]): Tensor[S1 *: S2 *: InChannels *: EmptyTuple, Float32] = x.transposeConv2d(params.kernel, stride, padding)