Skip to content

Commit c2fdb40

Browse files
mtfishmanclaude
andcommitted
Strip LinearBroadcasted types to minimal interface
Remove adjoint, transpose, permuteddims, conj, StridedViews methods from LinearBroadcasted types and Mul. These were algebraic rewrite rules carried over from the AbstractArray era. Will bring back as needed when validating against NamedDimsArrays and GradedArrays PRs. Each type now only defines: axes, eltype, ndims, similar, operation, arguments. Plus the materialization chain (copy, copyto!, add!). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ea8a252 commit c2fdb40

2 files changed

Lines changed: 1 addition & 113 deletions

File tree

src/linearbroadcasted.jl

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import Base.Broadcast as BC
2-
import FunctionImplementations as FI
32
import LinearAlgebra as LA
4-
import StridedViews as SV
53

64
# TermInterface-like interface.
75
iscall(x) = false
@@ -54,17 +52,6 @@ function Base.similar(a::ScaledBroadcasted, elt::Type, ax)
5452
return similar(unscaled(a), elt, ax)
5553
end
5654

57-
function Base.adjoint(a::ScaledBroadcasted)
58-
return ScaledBroadcasted(coeff(a), adjoint(unscaled(a)))
59-
end
60-
function Base.transpose(a::ScaledBroadcasted)
61-
return ScaledBroadcasted(coeff(a), transpose(unscaled(a)))
62-
end
63-
64-
function FI.permuteddims(a::ScaledBroadcasted, perm)
65-
return ScaledBroadcasted(coeff(a), FI.permuteddims(unscaled(a), perm))
66-
end
67-
6855
operation(::ScaledBroadcasted) = *
6956
arguments(a::ScaledBroadcasted) = (coeff(a), unscaled(a))
7057

@@ -84,17 +71,6 @@ function Base.similar(a::ConjBroadcasted, elt::Type, ax)
8471
return similar(unconj(a), elt, ax)
8572
end
8673

87-
Base.conj(a::ConjBroadcasted) = unconj(a)
88-
Base.adjoint(a::ConjBroadcasted) = transpose(unconj(a))
89-
Base.transpose(a::ConjBroadcasted) = adjoint(unconj(a))
90-
91-
function FI.permuteddims(a::ConjBroadcasted, perm)
92-
return ConjBroadcasted(FI.permuteddims(unconj(a), perm))
93-
end
94-
95-
SV.isstrided(a::ConjBroadcasted) = SV.isstrided(unconj(a))
96-
SV.StridedView(a::ConjBroadcasted) = conj(SV.StridedView(unconj(a)))
97-
9874
operation(::ConjBroadcasted) = conj
9975
arguments(a::ConjBroadcasted) = (unconj(a),)
10076

@@ -120,17 +96,6 @@ function Base.similar(a::AddBroadcasted, elt::Type, ax)
12096
return similar(BC.Broadcasted(+, addends(a)), elt, ax)
12197
end
12298

123-
function Base.adjoint(a::AddBroadcasted)
124-
return AddBroadcasted(adjoint.(addends(a))...)
125-
end
126-
function Base.transpose(a::AddBroadcasted)
127-
return AddBroadcasted(transpose.(addends(a))...)
128-
end
129-
130-
function FI.permuteddims(a::AddBroadcasted, perm)
131-
return AddBroadcasted(Base.Fix2(FI.permuteddims, perm).(addends(a))...)
132-
end
133-
13499
operation(::AddBroadcasted) = +
135100
arguments(a::AddBroadcasted) = addends(a)
136101

@@ -166,20 +131,6 @@ function Base.show(io::IO, a::Mul)
166131
return nothing
167132
end
168133

169-
function Base.adjoint(a::Mul)
170-
f = factors(a)
171-
return Mul(adjoint(f[2]), adjoint(f[1]))
172-
end
173-
function Base.transpose(a::Mul)
174-
f = factors(a)
175-
return Mul(transpose(f[2]), transpose(f[1]))
176-
end
177-
178-
function FI.permuteddims(a::Mul, perm)
179-
perm == (1, 2) && return a
180-
return transpose(a)
181-
end
182-
183134
iscall(::Mul) = true
184135
operation(::Mul) = *
185136
arguments(a::Mul) = factors(a)

test/test_linearbroadcasted.jl

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import FunctionImplementations as FI
21
using Base.Broadcast: Broadcast as BC
32
using TensorAlgebra: TensorAlgebra as TA, linearbroadcasted
43
using Test: @test, @test_throws, @testset
@@ -16,7 +15,6 @@ using Test: @test, @test_throws, @testset
1615
x = linearbroadcasted(conj, a)
1716
@test x TA.ConjBroadcasted(a)
1817
@test copy(x) conj(a)
19-
@test conj(x) a
2018

2119
x = linearbroadcasted(+, a, b)
2220
@test x TA.AddBroadcasted(a, b)
@@ -39,67 +37,6 @@ using Test: @test, @test_throws, @testset
3937
)
4038
@test copy(x) 2 * a * b + 3 * c
4139
end
42-
@testset "adjoint" begin
43-
a = randn(ComplexF64, 2, 2)
44-
b = randn(ComplexF64, 2, 2)
45-
46-
x = linearbroadcasted(*, 2, a)'
47-
@test x linearbroadcasted(*, 2, a')
48-
@test copy(x) 2a'
49-
50-
x = linearbroadcasted(conj, a)'
51-
@test x transpose(a)
52-
@test copy(x) permutedims(a)
53-
54-
x = linearbroadcasted(+, a, b)'
55-
@test x linearbroadcasted(+, a', b')
56-
@test copy(x) a' + b'
57-
58-
x = TA.Mul(a, b)'
59-
@test x TA.Mul(b', a')
60-
@test copy(x) b' * a'
61-
end
62-
@testset "transpose" begin
63-
a = randn(ComplexF64, 2, 2)
64-
b = randn(ComplexF64, 2, 2)
65-
66-
x = transpose(linearbroadcasted(*, 2, a))
67-
@test x linearbroadcasted(*, 2, transpose(a))
68-
@test copy(x) 2transpose(a)
69-
70-
x = transpose(linearbroadcasted(conj, a))
71-
@test x adjoint(a)
72-
@test copy(x) permutedims(conj(a))
73-
74-
x = transpose(linearbroadcasted(+, a, b))
75-
@test x linearbroadcasted(+, transpose(a), transpose(b))
76-
@test copy(x) transpose(a) + transpose(b)
77-
78-
x = transpose(TA.Mul(a, b))
79-
@test x TA.Mul(transpose(b), transpose(a))
80-
@test copy(x) transpose(b) * transpose(a)
81-
end
82-
@testset "permuteddims" begin
83-
a = randn(ComplexF64, 2, 2)
84-
b = randn(ComplexF64, 2, 2)
85-
perm = (2, 1)
86-
87-
x = FI.permuteddims(linearbroadcasted(*, 2, a), perm)
88-
@test x linearbroadcasted(*, 2, FI.permuteddims(a, perm))
89-
@test copy(x) 2permutedims(a, perm)
90-
91-
x = FI.permuteddims(linearbroadcasted(conj, a), perm)
92-
@test x linearbroadcasted(conj, FI.permuteddims(a, perm))
93-
@test copy(x) conj(permutedims(a, perm))
94-
95-
x = FI.permuteddims(linearbroadcasted(+, a, b), perm)
96-
@test x
97-
linearbroadcasted(+, FI.permuteddims(a, perm), FI.permuteddims(b, perm))
98-
@test copy(x) permutedims(a, perm) + permutedims(b, perm)
99-
100-
x = FI.permuteddims(TA.Mul(a, b), perm)
101-
@test copy(x) permutedims(a * b, perm)
102-
end
10340
@testset "linear broadcast lowering" begin
10441
a = randn(ComplexF64, 2, 2)
10542
style = BC.DefaultArrayStyle{2}()
@@ -118,7 +55,7 @@ using Test: @test, @test_throws, @testset
11855
@test TA.broadcasted_linear(style, conj, a) linearbroadcasted(conj, a)
11956
@test_throws ArgumentError TA.broadcasted_linear(style, exp, a)
12057
end
121-
@testset "LinearBroadcastFunction algebra" begin
58+
@testset "linearbroadcasted algebra" begin
12259
a = randn(ComplexF64, 3, 3)
12360

12461
# Scaling absorbs coefficients

0 commit comments

Comments
 (0)