Skip to content
This repository has been archived by the owner on Jan 20, 2025. It is now read-only.

Commit

Permalink
Add support for cat and slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Dec 15, 2024
1 parent 18f7245 commit e8056f1
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 0 deletions.
121 changes: 121 additions & 0 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ using ArrayLayouts: ArrayLayouts
return ArrayLayouts.layout_getindex(a, I...)
end

@interface interface::AbstractArrayInterface function Base.setindex!(

Check warning on line 22 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L22

Added line #L22 was not covered by tests
a::AbstractArray, value, I...
)
# TODO: Change to this once broadcasting in `@interface` is supported:
# @interface interface a[I...] .= value
@interface interface map!(identity, @view(a[I...]), value)
return a

Check warning on line 28 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L27-L28

Added lines #L27 - L28 were not covered by tests
end

# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or
# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`.
# TODO: Use `MethodError`?
Expand All @@ -28,6 +37,27 @@ end
return error("Not implemented.")
end

# TODO: Make this more general, use `Base.to_index`.
@interface interface::AbstractArrayInterface function Base.getindex(

Check warning on line 41 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L41

Added line #L41 was not covered by tests
a::AbstractArray{<:Any,N}, I::CartesianIndex{N}
) where {N}
return @interface interface getindex(a, Tuple(I)...)

Check warning on line 44 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L44

Added line #L44 was not covered by tests
end

# TODO: Use `MethodError`?
@interface ::AbstractArrayInterface function Base.setindex!(

Check warning on line 48 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L48

Added line #L48 was not covered by tests
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
) where {N}
return error("Not implemented.")

Check warning on line 51 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L51

Added line #L51 was not covered by tests
end

# TODO: Make this more general, use `Base.to_index`.
@interface interface::AbstractArrayInterface function Base.setindex!(

Check warning on line 55 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L55

Added line #L55 was not covered by tests
a::AbstractArray{<:Any,N}, value, I::CartesianIndex{N}
) where {N}
return @interface interface setindex!(a, value, Tuple(I)...)

Check warning on line 58 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L58

Added line #L58 was not covered by tests
end

@interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type)
return Broadcast.DefaultArrayStyle{ndims(type)}()
end
Expand Down Expand Up @@ -203,3 +233,94 @@ end
## @interface ::AbstractMatrixInterface function Base.*(a1, a2)
## return ArrayLayouts.mul(a1, a2)
## end

# Concatenation

axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
function axis_cat(

Check warning on line 240 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L239-L240

Added lines #L239 - L240 were not covered by tests
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
)
return axis_cat(axis_cat(a1, a2), a_rest...)

Check warning on line 243 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L243

Added line #L243 was not covered by tests
end

unval(x) = x
unval(::Val{x}) where {x} = x

Check warning on line 247 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L246-L247

Added lines #L246 - L247 were not covered by tests

function cat_axes(as::AbstractArray...; dims)
return ntuple(length(first(axes.(as)))) do dim
return if dim in unval(dims)
axis_cat(map(axes -> axes[dim], axes.(as))...)

Check warning on line 252 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L249-L252

Added lines #L249 - L252 were not covered by tests
else
axes(first(as))[dim]

Check warning on line 254 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L254

Added line #L254 was not covered by tests
end
end
end

function cat! end

# Represents concatenating `args` over `dims`.
struct Cat{Args<:Tuple{Vararg{AbstractArray}},dims}
args::Args
end
to_cat_dims(dim::Integer) = Int(dim)
to_cat_dims(dim::Int) = (dim,)
to_cat_dims(dims::Val) = to_cat_dims(unval(dims))
to_cat_dims(dims::Tuple) = dims
Cat(args::AbstractArray...; dims) = Cat{typeof(args),to_cat_dims(dims)}(args)
cat_dims(::Cat{<:Any,dims}) where {dims} = dims

Check warning on line 270 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L265-L270

Added lines #L265 - L270 were not covered by tests

function Base.axes(a::Cat)
return cat_axes(a.args...; dims=cat_dims(a))

Check warning on line 273 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L272-L273

Added lines #L272 - L273 were not covered by tests
end
Base.eltype(a::Cat) = promote_type(eltype.(a.args)...)
function Base.similar(a::Cat)
ax = axes(a)
elt = eltype(a)

Check warning on line 278 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L275-L278

Added lines #L275 - L278 were not covered by tests
# TODO: This drops GPU information, maybe use MemoryLayout?
return similar(arraytype(interface(a.args...), elt), ax)

Check warning on line 280 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L280

Added line #L280 was not covered by tests
end

# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation
# This is very similar to the `Base.cat` implementation but handles zero values better.
function cat_offset!(

Check warning on line 286 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L286

Added line #L286 was not covered by tests
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims
)
inds = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim)

Check warning on line 290 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L289-L290

Added lines #L289 - L290 were not covered by tests
end
a_dest[inds...] = a1
new_offsets = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim]

Check warning on line 294 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L292-L294

Added lines #L292 - L294 were not covered by tests
end
cat_offset!(a_dest, new_offsets, a_rest...; dims)
return a_dest

Check warning on line 297 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L296-L297

Added lines #L296 - L297 were not covered by tests
end
function cat_offset!(a_dest::AbstractArray, offsets; dims)
return a_dest

Check warning on line 300 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L299-L300

Added lines #L299 - L300 were not covered by tests
end

@interface ::AbstractArrayInterface function cat!(

Check warning on line 303 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L303

Added line #L303 was not covered by tests
a_dest::AbstractArray, as::AbstractArray...; dims
)
offsets = ntuple(zero, ndims(a_dest))

Check warning on line 306 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L306

Added line #L306 was not covered by tests
# TODO: Fill `a_dest` with zeros if needed using `zero!`.
cat_offset!(a_dest, offsets, as...; dims)
return a_dest

Check warning on line 309 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L308-L309

Added lines #L308 - L309 were not covered by tests
end

@interface interface::AbstractArrayInterface function Base.cat(as::AbstractArray...; dims)
a_dest = similar(Cat(as...; dims))
@interface interface cat!(a_dest, as...; dims)
return a_dest

Check warning on line 315 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L312-L315

Added lines #L312 - L315 were not covered by tests
end

# TODO: Use `@derive` instead:
# ```julia
# @derive (T=AbstractArray,) begin
# cat!(a_dest::AbstractArray, as::T...; dims)
# end
# ```
function cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
return @interface interface(as...) cat!(a_dest, as...; dims)

Check warning on line 325 in src/abstractarrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractarrayinterface.jl#L324-L325

Added lines #L324 - L325 were not covered by tests
end
1 change: 1 addition & 0 deletions src/abstractinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
interface(x) = interface(typeof(x))
# TODO: Define as `DefaultInterface()`.
interface(::Type) = error("Interface unknown.")
interface(x1, x_rest...) = combine_interfaces(x1, x_rest...)

Check warning on line 5 in src/abstractinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractinterface.jl#L5

Added line #L5 was not covered by tests

# Adapted from `Base.Broadcast.combine_styles`.
# Get the combined interfaces of the input objects.
Expand Down
2 changes: 2 additions & 0 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ function derive(::Val{:AbstractArrayOps}, type)
return quote
Base.getindex(::$type, ::Any...)
Base.getindex(::$type, ::Int...)
Base.setindex!(::$type, ::Any, ::Any...)

Check warning on line 27 in src/traits.jl

View check run for this annotation

Codecov / codecov/patch

src/traits.jl#L27

Added line #L27 was not covered by tests
Base.setindex!(::$type, ::Any, ::Int...)
Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}})
Base.similar(::$type, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}})
Expand All @@ -44,6 +45,7 @@ function derive(::Val{:AbstractArrayOps}, type)
Base.permutedims!(::Any, ::$type, ::Any)
Broadcast.BroadcastStyle(::Type{<:$type})
Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}})
Base.cat(::$type...; kwargs...)

Check warning on line 48 in src/traits.jl

View check run for this annotation

Codecov / codecov/patch

src/traits.jl#L48

Added line #L48 was not covered by tests
ArrayLayouts.MemoryLayout(::Type{<:$type})
LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number)
end
Expand Down
34 changes: 34 additions & 0 deletions test/basics/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a[1, 2] = 12
b = similar(a)
copyto!(b, a)
@test b isa SparseArrayDOK{elt,2}
@test b == a
@test b[1, 2] == 12
@test storedlength(b) == 1
Expand Down Expand Up @@ -114,6 +115,39 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = zero(a)
@test b isa SparseArrayDOK{elt,2}
@test iszero(b)
@test iszero(storedlength(b))

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = SparseArrayDOK{elt}(4, 4)
b[2:3, 2:3] .= a
@test isone(storedlength(b))
@test b[2, 3] == 12

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = SparseArrayDOK{elt}(4, 4)
b[2:3, 2:3] = a
@test isone(storedlength(b))
@test b[2, 3] == 12

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = SparseArrayDOK{elt}(4, 4)
c = @view b[2:3, 2:3]
c .= a
@test isone(storedlength(b))
@test b[2, 3] == 12

a1 = SparseArrayDOK{elt}(2, 2)
a1[1, 2] = 12
a2 = SparseArrayDOK{elt}(2, 2)
a2[2, 1] = 21
b = cat(a1, a2; dims=(1, 2))
@test b isa SparseArrayDOK{elt,2}
@test storedlength(b) == 2
@test b[1, 2] == 12
@test b[4, 3] == 21
end

0 comments on commit e8056f1

Please sign in to comment.