
Simplify types for binary operations #95

@marcelluethi


The following simple code fails to compile, as expected:

```scala
trait A derives Label
trait B derives Label

val t1: Tensor1[A, Float] = ???
val t2: Tensor1[B, Float] = ???
t1 + t2 // should not compile
```

Unfortunately, the error message is misleading:

```
A tuple of axis labels T was given or inferred that does not have a valid Labels instance. 

Ensure that all of the types in the tuple have a 'derives Label' clause.
.
I found:

    dimwit.tensor.Labels.concat[head, tail](
      /* missing */summon[dimwit.tensor.Label[head]], ???)

But no implicit values were found that match type dimwit.tensor.Label[head]
```

The cause seems to be that `add` is implemented in a very generic way:

```scala
def add[T <: Tuple: Labels, T1 <: T, T2 <: T, V: IsNumber](
    t1: Tensor[T1, V],
    t2: Tensor[T2, V]
): Tensor[T, V] =
  Tensor(Jax.jnp.add(t1.jaxValue, t2.jaxValue))
```
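To see why the message talks about `Label[head]`, here is a minimal, self-contained Scala 3 model of that signature. `Label`, `Labels`, `Tensor`, and `addGeneric` are hypothetical stand-ins, not dimwit's definitions: a `List` replaces the Jax array, `Numeric` replaces `IsNumber`, explicit `given`s replace `derives Label`, and `Tensor1[A, Float]` is assumed to correspond to `Tensor[A *: EmptyTuple, Float]`. With `T1 = A *: EmptyTuple` and `T2 = B *: EmptyTuple`, the compiler has to pick a `T` that is a supertype of both, and the only thing that can then fail is the `Labels[T]` context bound. So the mismatch shows up as a failed `Label` search rather than as "A is not B". `scala.compiletime.testing.typeChecks`/`typeCheckErrors` are used so the rejected call can be inspected without breaking the build.

```scala
import scala.compiletime.testing.{typeCheckErrors, typeChecks}

// Hypothetical stand-ins for dimwit.tensor.Label / Labels (not the real code).
trait Label[L]
trait Labels[T <: Tuple]
object Labels:
  given empty: Labels[EmptyTuple] = new Labels[EmptyTuple] {}
  // A tuple has Labels if its head has a Label and its tail has Labels.
  given concat[H, Tl <: Tuple](using Label[H], Labels[Tl]): Labels[H *: Tl] =
    new Labels[H *: Tl] {}

// Hypothetical tensor: a List stands in for the Jax array.
final case class Tensor[T <: Tuple, V](values: List[V])

trait A
trait B
// Explicit instances stand in for `derives Label`.
given Label[A] = new Label[A] {}
given Label[B] = new Label[B] {}

// Same shape as the signature in the issue: T is only constrained as an
// upper bound of T1 and T2, plus the Labels[T] context bound.
def addGeneric[T <: Tuple: Labels, T1 <: T, T2 <: T, V: Numeric](
    t1: Tensor[T1, V],
    t2: Tensor[T2, V]
): Tensor[T, V] =
  val num = summon[Numeric[V]]
  Tensor[T, V](t1.values.zip(t2.values).map((a, b) => num.plus(a, b)))

val ta: Tensor[A *: EmptyTuple, Float] = Tensor(List(1f, 2f))
val ta2: Tensor[A *: EmptyTuple, Float] = Tensor(List(3f, 4f))
val tb: Tensor[B *: EmptyTuple, Float] = Tensor(List(5f, 6f))

@main def demoGeneric(): Unit =
  // Same labels on both sides: T is found and Labels[T] resolves.
  println(addGeneric(ta, ta2).values)
  // Different labels: rejected, but the errors are about the missing
  // Labels/Label instance for T, not about A versus B.
  typeCheckErrors("addGeneric(ta, tb)").foreach(e => println(e.message))
```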

Similar issues occur with other binary operations.

The error message should be improved for this case. A simple fix would be to use T directly for both operands instead of the separate type parameters T1 and T2, which are only bounded above by T. However, I am not sure whether this generality is intentional and needed anywhere.
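A sketch of the simpler signature, again built on hypothetical stand-ins (`Label`, `Labels`, `Tensor`, `addSame`) rather than dimwit's actual code. It assumes `Tensor` is invariant in its label tuple; under that assumption, requiring both operands to be `Tensor[T, V]` forces `T` to equal both label tuples, so `A` versus `B` is rejected as an ordinary type mismatch before any `Labels` search happens. If the real `Tensor` were covariant in `T`, the compiler could still widen `T` to a common supertype and this change alone would not be enough. The cost of the stricter signature is that operands with different (but compatible) label tuples are no longer accepted, which is exactly the generality the question above is about.

```scala
import scala.compiletime.testing.typeChecks

// Hypothetical stand-ins for dimwit.tensor.Label / Labels (not the real code).
trait Label[L]
trait Labels[T <: Tuple]
object Labels:
  given empty: Labels[EmptyTuple] = new Labels[EmptyTuple] {}
  given concat[H, Tl <: Tuple](using Label[H], Labels[Tl]): Labels[H *: Tl] =
    new Labels[H *: Tl] {}

// Assumed invariant in T: Tensor[T1, V] <: Tensor[T, V] only if T1 = T.
final case class Tensor[T <: Tuple, V](values: List[V])

trait A
trait B
given Label[A] = new Label[A] {}
given Label[B] = new Label[B] {}

// Simplified signature: both operands must carry exactly the labels T.
def addSame[T <: Tuple: Labels, V: Numeric](
    t1: Tensor[T, V],
    t2: Tensor[T, V]
): Tensor[T, V] =
  val num = summon[Numeric[V]]
  Tensor[T, V](t1.values.zip(t2.values).map((a, b) => num.plus(a, b)))

val ta: Tensor[A *: EmptyTuple, Float] = Tensor(List(1f, 2f))
val ta2: Tensor[A *: EmptyTuple, Float] = Tensor(List(3f, 4f))
val tb: Tensor[B *: EmptyTuple, Float] = Tensor(List(5f, 6f))

@main def demoSame(): Unit =
  println(addSame(ta, ta2).values)
  // Mismatched labels now fail as a plain type mismatch between the operands.
  println(typeChecks("addSame(ta, tb)")) // false
```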
