Skip to content

Commit

Permalink
Merge pull request #86 from mcabbott/unsafe
Browse files Browse the repository at this point in the history
Mark all shifted indices unsafe in gradients
  • Loading branch information
mcabbott authored Mar 19, 2021
2 parents e7b6244 + 1d7d150 commit 34d47bd
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ function parse_input(expr, store)
unique!(store.leftind)
store.sharedind = unique!(setdiff(store.sharedind, store.notfree))
store.rightind = unique!(setdiff(store.rightind, store.notfree))
union!(store.unsaferight, store.shiftedind)
any(==(:_), vcat(store.leftind, store.rightind)) && throw("can't use _ as an index name")

unique!(store.outpre) # kill mutiple assertions, and evaluate any f(A) only once
Expand Down
4 changes: 2 additions & 2 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using Tullio, Test
using CUDA, KernelAbstractions
CUDA.allowscalar(false)
using Tracker
using Tracker, ForwardDiff
@tullio grad=Base

# matmul
Expand All @@ -17,14 +17,14 @@ A = rand(3,40); B = rand(40,500);
@test ΔA ones(3,500) * B'
@test cu(ΔA) Tracker.gradient((A,B) -> sum(mul(A, B)), cu(A), cu(B))[1]

#=
# shifts
@tullio D[i,j] := A[i,j+k] k in 0:10
@test axes(D) == (1:3, 1:30)
@tullio cD[i,j] := cu(A)[i,j+k] k in 0:10
@test cD isa CuArray
@test cD cu(D)

#=
# ranges
@tullio E[i,j] := A[i,j+k-1] + (-1:0.5:1)[k]
@test axes(E) == (1:3, 1:36)
Expand Down

0 comments on commit 34d47bd

Please sign in to comment.