Add foreachWithName to TensorTree #87
marcelluethi
left a comment
It is not clear to me what this PR is trying to achieve and why it is needed.
Can you add some motivation?
General motivation: This PR adds a foreach and a fill operation over a TensorTree structure.

foreach is equivalent to List.foreach in Scala: it consumes the structure for some effect (e.g. printing), except that it additionally provides a String with the name of each tensor within the internal tree structure. I think a foreach without names makes no sense, as the structure is unknown; however, an alternative design would be "zipWithNames.foreach", analogous to "zipWithIndex.foreach".

fill is roughly the opposite of foreach. It requires a function that maps a tree structure name back to a Tensor. It traverses the structure and creates each tensor, rebuilding the tree from its flat, per-name representation (e.g. a Map of names to Tensors).

My use case: I require these operations in DeepWit for my logging concept. The idea is to log Tensors and TensorTrees during training and visualize them after the run in a notebook (instead of logging images during training in e.g. WandB):

val logger = new TenZarrLogger(f"out/$time")
// During training
logger.logTree("train/grad", i, grads)
// After training
val gradientsAtStep100 = logger.loadTree[MLPParams]("train/grad", 100)
// Class
class TenZarrLogger(storePath: String = "logs.zarr"):
  private val zarr = py.module("zarr")
  private val np = py.module("numpy")
  private val root = zarr.open(storePath, mode = "a")

  def logTree[Data: TensorTree](name: String, iteration: Int, data: Data): Unit =
    TensorTree.foreach(
      data,
      [T <: Tuple, V] =>
        (labels: Labels[T]) ?=>
          (tensorName: String, tensor: Tensor[T, V]) =>
            log(s"$name/$tensorName", iteration, tensor) // log logs (among other things) the tensor in "$name/$tensorName/values"
    )

  def loadTree[Data: TensorTree](name: String, iteration: Int): Option[Data] =
    Option.when(root.as[py.Dynamic].__contains__(name).as[Boolean]):
      val pyData = root.bracketAccess(name)
      TensorTree[Data].fill([T <: Tuple, V] =>
        (labels: Labels[T]) ?=>
          (path: String) =>
            val raw = pyData.bracketAccess(path).bracketAccess("values").bracketAccess(iteration)
            liftPyTensor[T, V](raw)
      )
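
To illustrate the foreach/fill round trip described above independently of the library, here is a minimal sketch on a toy tree. Everything in it is an assumption made for illustration: the `Tree` enum, the `foreachWithName` and `fill` signatures, and the slash-separated path names are hypothetical stand-ins, not the TensorTree API from this PR. Leaves hold a `Vector[Double]` instead of a Tensor, and `fill` takes an existing tree as the shape, whereas the PR derives the shape from the type.

```scala
// Hypothetical stand-in for a TensorTree: leaves hold plain vectors, not Tensors.
enum Tree:
  case Leaf(values: Vector[Double])
  case Node(children: List[(String, Tree)])

import Tree.*

// Build the path of a child from the path of its parent.
def childPath(prefix: String, name: String): String =
  if prefix.isEmpty then name else s"$prefix/$name"

// Visit every leaf together with its slash-separated path name.
def foreachWithName(tree: Tree, prefix: String = "")(f: (String, Vector[Double]) => Unit): Unit =
  tree match
    case Leaf(values) => f(prefix, values)
    case Node(children) =>
      children.foreach { case (name, child) =>
        foreachWithName(child, childPath(prefix, name))(f)
      }

// Rebuild a tree of the same shape, asking `load` for each leaf by its path name.
def fill(shape: Tree, prefix: String = "")(load: String => Vector[Double]): Tree =
  shape match
    case Leaf(_) => Leaf(load(prefix))
    case Node(children) =>
      Node(children.map { case (name, child) =>
        name -> fill(child, childPath(prefix, name))(load)
      })

@main def roundTrip(): Unit =
  val grads: Tree = Node(List(
    "layer1" -> Node(List("weight" -> Leaf(Vector(0.1, 0.2)), "bias" -> Leaf(Vector(0.3)))),
    "layer2" -> Node(List("weight" -> Leaf(Vector(0.4))))
  ))
  // "Log" the tree into a flat map keyed by path name.
  val store = scala.collection.mutable.Map.empty[String, Vector[Double]]
  foreachWithName(grads)((name, values) => store.update(name, values))
  // "Load" it back by looking each path up again.
  val restored = fill(grads)(name => store(name))
  assert(restored == grads)
```

In this sketch the flat map plays the role of the Zarr store: foreachWithName flattens the tree into named entries, and fill reverses that by looking each name up again.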
marcelluethi
left a comment
This is an interesting use case and I agree that we need such methods. While I am fine with merging this for the moment, I think we should create an issue and try to think it through. For example, your suggestion of zipWithNames would be much more general than the foreach with name, as it could be used together with map. The fill could maybe be replaced by a map, as the tree structure seems already to be in place.
My suggestion is to merge it for now, and after we have finalized #89 we can think about the generic methods. At that point we should also add API docs to such fundamental methods.
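
The point about zipWithNames being more general can be sketched on a toy generic tree. This is only an illustration under stated assumptions: the `Tree[A]` enum and its `map`, `foreach`, and `zipWithNames` extension methods are hypothetical and do not reflect the TensorTree API or whatever design comes out of #89. It shows that once names are paired with leaves, both the named foreach and a named map come from the ordinary combinators.

```scala
// Hypothetical generic tree; leaves hold any value A.
enum Tree[+A]:
  case Leaf(value: A)
  case Node(children: List[(String, Tree[A])])

import Tree.*

extension [A](tree: Tree[A])
  // Structure-preserving transformation of every leaf.
  def map[B](f: A => B): Tree[B] = tree match
    case Leaf(value)    => Leaf(f(value))
    case Node(children) => Node(children.map { case (name, child) => name -> child.map(f) })

  // Visit every leaf for its side effect.
  def foreach(f: A => Unit): Unit = tree match
    case Leaf(value)    => f(value)
    case Node(children) => children.foreach { case (_, child) => child.foreach(f) }

  // Pair every leaf with its slash-separated path, analogous to zipWithIndex.
  def zipWithNames: Tree[(String, A)] =
    def go(t: Tree[A], prefix: String): Tree[(String, A)] = t match
      case Leaf(value) => Leaf(prefix -> value)
      case Node(children) =>
        Node(children.map { case (name, child) =>
          val path = if prefix.isEmpty then name else s"$prefix/$name"
          name -> go(child, path)
        })
    go(tree, "")

@main def zipDemo(): Unit =
  val params: Tree[Double] = Node(List("w" -> Leaf(2.0), "b" -> Leaf(0.5)))
  // Named foreach falls out of zipWithNames + foreach.
  params.zipWithNames.foreach { case (name, value) => println(s"$name: $value") }
  // Named map falls out of zipWithNames + map, e.g. zeroing every leaf named "b".
  val frozen = params.zipWithNames.map { case (name, value) => if name == "b" then 0.0 else value }
  assert(frozen == Node(List("w" -> Leaf(2.0), "b" -> Leaf(0.0))))
```

With this shape, a dedicated foreach-with-name becomes a convenience rather than a primitive, which is the trade-off raised in the comment above.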