Skip to content

Commit

Permalink
Generalise some routines
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Jan 28, 2025
1 parent c7e0e90 commit d183947
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 17 deletions.
12 changes: 1 addition & 11 deletions src/calculus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,12 @@ cumsum_size(::NTuple{N,Integer}, A, dims) where N = error("Not implemented")
####

@inline diff(a::AbstractQuasiArray, order...; dims::Integer=1) = diff_layout(MemoryLayout(a), a, order...; dims)
function diff_layout(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVector, order...; dims=1)
function diff_layout(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVecOrMat, order...; dims=1)

Check warning on line 55 in src/calculus.jl

View check run for this annotation

Codecov / codecov/patch

src/calculus.jl#L54-L55

Added lines #L54 - L55 were not covered by tests
a = arguments(LAY, V)
dims == 1 || throw(ArgumentError("cannot differentiate a vector along dimension $dims"))
*(diff(a[1], order...), tail(a)...)

Check warning on line 58 in src/calculus.jl

View check run for this annotation

Codecov / codecov/patch

src/calculus.jl#L57-L58

Added lines #L57 - L58 were not covered by tests
end

function diff_layout(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiMatrix, order...; dims=1)
a = arguments(LAY, V)
@assert dims == 1 #for type stability, for now
# if dims == 1
*(diff(a[1], order...), tail(a)...)
# else
# *(front(a)..., diff(a[end]; dims=dims))
# end
end

diff_layout(::MemoryLayout, A, order...; dims...) = diff_size(size(A), A, order...; dims...)
diff_size(sz, a; dims...) = error("diff not implemented for $(typeof(a))")
function diff_size(sz, a, order; dims...)
Expand Down
12 changes: 6 additions & 6 deletions src/quasibroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,13 @@ LazyArrays._broadcast_mul_mul((A,b)::Tuple{AbstractQuasiMatrix,AbstractQuasiVect
# support (A .* B) * y
_broadcasted_mul(a::Tuple{Number,Vararg{Any}}, b::AbstractQuasiVector) = (first(a)*sum(b), _broadcasted_mul(tail(a), b)...)
_broadcasted_mul(a::Tuple{Number,Vararg{Any}}, B::AbstractQuasiMatrix) = (first(a)*sum(B; dims=1), _broadcasted_mul(tail(a), B)...)
_broadcasted_mul(a::Tuple{AbstractQuasiVector,Vararg{Any}}, b::AbstractQuasiVector) = (first(a)*sum(b), _broadcasted_mul(tail(a), b)...)
_broadcasted_mul(a::Tuple{AbstractQuasiVector,Vararg{Any}}, B::AbstractQuasiMatrix) = (first(a)*sum(B; dims=1), _broadcasted_mul(tail(a), B)...)
_broadcasted_mul(A::Tuple{AbstractQuasiMatrix,Vararg{Any}}, b::AbstractQuasiVector) = (axes(first(A),2) == Base.OneTo(1) ? first(A)*sum(b) : (first(A)*b), _broadcasted_mul(tail(A), b)...)
_broadcasted_mul(A::Tuple{AbstractQuasiMatrix,Vararg{Any}}, B::AbstractQuasiMatrix) = (axes(first(A),2) == Base.OneTo(1) ? first(A)*sum(B; dims=1) : (first(A)*B), _broadcasted_mul(tail(A), B)...)
_broadcasted_mul(a::Tuple{AbstractQuasiVector,Vararg{Any}}, b::AbstractQuasiOrVector) = (first(a)*sum(b), _broadcasted_mul(tail(a), b)...)
_broadcasted_mul(a::Tuple{AbstractQuasiVector,Vararg{Any}}, B::AbstractQuasiOrMatrix) = (first(a)*sum(B; dims=1), _broadcasted_mul(tail(a), B)...)
_broadcasted_mul(A::Tuple{AbstractQuasiMatrix,Vararg{Any}}, b::AbstractQuasiOrVector) = (axes(first(A),2) == Base.OneTo(1) ? first(A)*sum(b) : (first(A)*b), _broadcasted_mul(tail(A), b)...)
_broadcasted_mul(A::Tuple{AbstractQuasiMatrix,Vararg{Any}}, B::AbstractQuasiOrMatrix) = (axes(first(A),2) == Base.OneTo(1) ? first(A)*sum(B; dims=1) : (first(A)*B), _broadcasted_mul(tail(A), B)...)

Check warning on line 192 in src/quasibroadcast.jl

View check run for this annotation

Codecov / codecov/patch

src/quasibroadcast.jl#L189-L192

Added lines #L189 - L192 were not covered by tests
_broadcasted_mul(A::AbstractQuasiMatrix, b::Tuple{Number,Vararg{Any}}) = (sum(A; dims=2)*first(b)[1], _broadcasted_mul(A, tail(b))...)
_broadcasted_mul(A::AbstractQuasiMatrix, b::Tuple{Union{AbstractVector,AbstractQuasiVector},Vararg{Any}}) = (size(first(b),1) == 1 ? (sum(A; dims=2)*first(b)[1]) : (A*first(b)), _broadcasted_mul(A, tail(b))...)
_broadcasted_mul(A::AbstractQuasiMatrix, B::Tuple{Union{AbstractMatrix,AbstractQuasiMatrix},Vararg{Any}}) = (size(first(B),1) == 1 ? (sum(A; dims=2) * first(B)) : (A * first(B)), _broadcasted_mul(A, tail(B))...)
_broadcasted_mul(A::AbstractQuasiMatrix, b::Tuple{AbstractQuasiOrVector,Vararg{Any}}) = (size(first(b),1) == 1 ? (sum(A; dims=2)*first(b)[1]) : (A*first(b)), _broadcasted_mul(A, tail(b))...)
_broadcasted_mul(A::AbstractQuasiMatrix, B::Tuple{AbstractQuasiOrMatrix,Vararg{Any}}) = (size(first(B),1) == 1 ? (sum(A; dims=2) * first(B)) : (A * first(B)), _broadcasted_mul(A, tail(B))...)

Check warning on line 195 in src/quasibroadcast.jl

View check run for this annotation

Codecov / codecov/patch

src/quasibroadcast.jl#L194-L195

Added lines #L194 - L195 were not covered by tests



Expand Down

0 comments on commit d183947

Please sign in to comment.