From 7fe63f082ee5f329a85c3b1e1a5ceb74013a41de Mon Sep 17 00:00:00 2001 From: Benjamin Meyer Date: Thu, 19 Mar 2026 18:04:20 +0100 Subject: [PATCH] Add foreachWithName to TensorTree --- .../scala/dimwit/autodiff/TensorTree.scala | 41 ++++++++++++- .../dimwit/autodiff/TensorTreeSuite.scala | 57 +++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/dimwit/autodiff/TensorTree.scala b/core/src/main/scala/dimwit/autodiff/TensorTree.scala index 9353b3f6..8ef033a2 100644 --- a/core/src/main/scala/dimwit/autodiff/TensorTree.scala +++ b/core/src/main/scala/dimwit/autodiff/TensorTree.scala @@ -3,17 +3,31 @@ package dimwit.autodiff import dimwit.tensor.* import scala.deriving.* import scala.compiletime.* +import scala.quoted.* // TODO hot fix with retag and context parameter... maybe this can be improved? trait TensorTree[P]: + def foreach(p: P, consumer: TensorTree.TensorConsumerWithPath, path: String = "", pathSeparator: String = TensorTree.DEFAULT_PATH_SEPARATOR): Unit + def fill(provider: TensorTree.TensorProvider, path: String = "", pathSeparator: String = TensorTree.DEFAULT_PATH_SEPARATOR): P def map(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): P def zipMap(p1: P, p2: P, f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): P object TensorTree extends TensorTreeLowPriority: + + val DEFAULT_PATH_SEPARATOR = "." + + type TensorConsumerWithPath = [T <: Tuple, V] => (Labels[T]) ?=> (String, Tensor[T, V]) => Unit + type TensorProvider = [T <: Tuple, V] => (Labels[T]) ?=> String => Tensor[T, V] + def apply[P](using pt: TensorTree[P]): TensorTree[P] = pt given [Q <: Tuple, V](using n: Labels[Q]): TensorTree[Tensor[Q, V]] with + def foreach(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> (String, Tensor[T, V2]) => Unit, path: String, pathSeparator: String): Unit = + f[Q, V](using n)(path, t) + + def fill(provider: TensorProvider, path: String, pathSeparator: String): Tensor[Q, V] = provider[Q, V](using n)(path) + def map(t: Tensor[Q, V], f: [T <: Tuple, V2] => (Labels[T]) ?=> (Tensor[T, V2] => Tensor[T, V2])): Tensor[Q, V] = import TensorOps.retag f[Q, V](using n)(t.retag[Q](using n)) @@ -25,12 +39,29 @@ object TensorTree extends TensorTreeLowPriority: inline given derived[P <: Product](using m: Mirror.ProductOf[P]): TensorTree[P] = val elemInstances = summonAll[Tuple.Map[m.MirroredElemTypes, TensorTree]] val instances = elemInstances.toList.asInstanceOf[List[TensorTree[Any]]] - derivedImpl(instances, m) + val fieldNames = constValueTuple[m.MirroredElemLabels].toList.map(_.toString) + derivedImpl(instances, fieldNames, m) private def derivedImpl[P <: Product]( instances: List[TensorTree[Any]], + fieldNames: List[String], m: Mirror.ProductOf[P] ): TensorTree[P] = new TensorTree[P]: + def foreach(p: P, consumer: TensorConsumerWithPath, path: String, pathSeparator: String): Unit = + val inputs = p.productIterator.toList + inputs.zip(instances).zip(fieldNames).foreach: + case ((elem, inst), fieldName) => + // Construct the nested path + val nextName = if path.isEmpty then fieldName else s"$path$pathSeparator$fieldName" + inst.foreach(elem, consumer, nextName, pathSeparator) + + def fill(provider: TensorProvider, path: String, pathSeparator: String): P = + val generatedElems = instances.zip(fieldNames).map: + case (inst, fieldName) => + val newPath = if path.isEmpty then fieldName else s"$path$pathSeparator$fieldName" + inst.fill(provider, newPath, pathSeparator) + m.fromProduct(Tuple.fromArray(generatedElems.map(_.asInstanceOf[Object]).toArray)) + def map(p: P, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): P = val inputs = p.productIterator.toList val mappedElems = inputs @@ -49,7 +80,15 @@ object TensorTree extends TensorTreeLowPriority: case ((e1, e2), inst) => inst.zipMap(e1, e2, f) m.fromProduct(Tuple.fromArray(mappedElems.map(_.asInstanceOf[Object]).toArray)) + def foreach[P](p: P, consumer: TensorConsumerWithPath, pathSeparator: String = DEFAULT_PATH_SEPARATOR)(using tt: TensorTree[P]): Unit = + tt.foreach(p, consumer, "", pathSeparator) + + def fill[P](provider: TensorProvider, pathSeparator: String = DEFAULT_PATH_SEPARATOR)(using tt: TensorTree[P]): P = + tt.fill(provider, "", pathSeparator) + trait TensorTreeLowPriority: given identity[A]: TensorTree[A] = new TensorTree[A]: + def foreach(p: A, consumer: TensorTree.TensorConsumerWithPath, path: String, pathSeparator: String): Unit = () + def fill(provider: [T <: Tuple, V] => (x: Labels[T]) ?=> String => Tensor[T, V], path: String, pathSeparator: String): A = ??? def map(p: A, f: [T <: Tuple, V] => (Labels[T]) ?=> (Tensor[T, V] => Tensor[T, V])): A = p def zipMap(p1: A, p2: A, f: [T <: Tuple, V] => (Labels[T]) ?=> ((Tensor[T, V], Tensor[T, V]) => Tensor[T, V])): A = p1 diff --git a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala index 30b322a8..4fd50c85 100644 --- a/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala +++ b/core/src/test/scala/dimwit/autodiff/TensorTreeSuite.scala @@ -8,6 +8,63 @@ import org.scalatest.matchers.should.Matchers class TensorTreeSuite extends AnyFunSpec with Matchers: + describe("foreach"): + it("2-level case class"): + case class Child( + val numbers: Tensor1[A, Float], + val counts: Tensor1[A, Int] + ) + case class Parent( + val child: Child, + val flags: Tensor1[A, Boolean] + ) + val params = Parent( + Child( + Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), + Tensor1(Axis[A]).fromArray(Array(1, 2, 3)) + ), + Tensor1(Axis[A]).fromArray(Array(true, false, true)) + ) + TensorTree.foreach( + params, + [T <: Tuple, V] => + (labels: Labels[T]) ?=> + (name: String, tensor: Tensor[T, V]) => + name match + case "child-numbers" => tensor should equal(params.child.numbers) + case "child-counts" => tensor should equal(params.child.counts) + case "flags" => tensor should equal(params.flags) + case _ => fail(s"Unexpected name: $name"), + pathSeparator = "-" + ) + + describe("fill"): + it("2-level case class"): + case class Child( + val numbers: Tensor1[A, Float], + val counts: Tensor1[A, Int] + ) + case class Parent( + val child: Child, + val flags: Tensor1[A, Boolean] + ) + val expected = Parent( + Child( + Tensor1(Axis[A]).fromArray(Array(0.1f, 0.2f, 0.3f)), + Tensor1(Axis[A]).fromArray(Array(1, 2, 3)) + ), + Tensor1(Axis[A]).fromArray(Array(true, false, true)) + ) + val provider: TensorTree.TensorProvider = [T <: Tuple, V] => + (labels: Labels[T]) ?=> + (name: String) => + name match + case "child-numbers" => expected.child.numbers.asInstanceOf[Tensor[T, V]] + case "child-counts" => expected.child.counts.asInstanceOf[Tensor[T, V]] + case "flags" => expected.flags.asInstanceOf[Tensor[T, V]] + case _ => fail(s"Unexpected name: $name") + TensorTree.fill[Parent](provider, pathSeparator = "-") should equal(expected) + describe("map"): it("1-level case class"): case class Data(