diff --git a/README.md b/README.md index acbdfe4..a72d98f 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,6 @@ A Julia package to explore a new system of array views. - A systematic approach to detect contiguous views (statically) - Views work with linear algebra functions - ## Overview The key function in this package is ``view``. This function is similar to ``sub`` in Julia Base, except that it returns an view instance with more efficient representation: @@ -30,6 +29,8 @@ view(a, 1:2, 1:2:5, 4) view(a, 2, :, 3:6) ``` +The ``@view`` macro can be used to convert an array indexed with square bracket syntax to a call to the `view` function. For example, ``@view(a[:, 2])`` translates to ``view(a, :, 2)``. + The ``view`` function returns a view of type ``ArrayView``. Here, ``ArrayView`` is an abstract type with two derived types (``ContiguousView`` and ``StridedView``), defined as: ```julia diff --git a/src/ArrayViews.jl b/src/ArrayViews.jl index 71759c2..2869aab 100644 --- a/src/ArrayViews.jl +++ b/src/ArrayViews.jl @@ -9,6 +9,9 @@ export ContiguousArray, ContiguousVector, ContiguousMatrix export contiguous_view, strided_view, view, ellipview, reshape_view export iscontiguous, contiguousrank +export @view + +include("viewmacro.jl") ################################################# # diff --git a/src/viewmacro.jl b/src/viewmacro.jl new file mode 100644 index 0000000..f87a7f6 --- /dev/null +++ b/src/viewmacro.jl @@ -0,0 +1,39 @@ +#fixend replaces symbol("end") in an index with length(objname) if d==0 +#or size(objname, d) for d > 0 +fixend(arg::Symbol, objnam::Symbol, d::Int) = + if arg == symbol("end") + if d == 0 + :(length($objnam)) + else + :(size($objnam, $d)) + end + else + arg + end +function fixend(arg::Expr, objnam::Symbol, d::Int) + #if this expr is another :ref, then any symbol("end")s in subexpressions refer to this :ref, + # so return + arg.head == :ref && return arg + #check for the special case :(1:end) and convert that to :(:) + arg == :(1:$(symbol("end"))) && return :(:) + #otherwise fix any ends in the args + map!(a -> fixend(a, objnam, d), arg.args) + arg +end +fixend(arg, objnam::Symbol, d::Int) = arg + +macro view(ex) + isa(ex, Symbol) && return :(view($(esc(ex)))) + isa(ex, Expr) && ex.head == :ref && isa(ex.args[1], Symbol) || + error("@view accepts a named object or indexed named object (e.g. A(I_1, I_2,...,I_n))") + objnam = ex.args[1] + + if length(ex.args) == 2 + ex.args[2] = fixend(ex.args[2], objnam, 0) + else + for (d, arg) in enumerate(ex.args[2:end]) + ex.args[d+1] = fixend(arg, objnam, d) + end + end + :(view($(map(esc, ex.args)...))) +end diff --git a/test/runtests.jl b/test/runtests.jl index ec7c01d..75dc163 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ tests = ["viewtypes", "contrank", - "subviews"] + "subviews", + "viewmacro"] for t in tests fp = string(t, ".jl") diff --git a/test/viewmacro.jl b/test/viewmacro.jl new file mode 100644 index 0000000..6277375 --- /dev/null +++ b/test/viewmacro.jl @@ -0,0 +1,15 @@ +using ArrayViews +using Base.Test +A = reshape(1:12, 4,3) + +@test @view(A) == A +@test isa(@view(A[1,1:end]), ArrayView) +@test @view(A[1,1:end]) == A[1,1:end] +@test @view(A[1,:]) == A[1,:] +#check symbol("end") is replaced by length instead of size with only 1 dim indexing +@test @view(A[end-4:end-2]) == A[end-4:end-2] +#check that symbol("end") is not replaced when used in an index in a subexpression +r = 1:2 +@test @view(A[1:end, r[1:end]]) == A[1:end, r[1:end]] +#check @view requires an expr of form A[....] +@test_throws ErrorException eval(:(@view((A+A)[1,1:end])))