-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmul.jl
102 lines (84 loc) · 3.07 KB
/
mul.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#################### MATRIX MULTIPLICATION ####################
"""
    mul(A, B, s)
    mul(A, B) = A *ᵃ B

Matrix multiplication for two `NamedDimsArray`s, which automatically transposes as required.
If given a name `s`, it arranges to sum over this index.
If not, it looks for exactly one name shared between the two.
The infix form is typed `*\\^a<tab>`.

See also the more general `contract(A, B, s)` for higher-rank tensors.
"""
function mul(A::NamedUnion, B::NamedUnion)
    namesA = getnames(A)
    namesB = getnames(B)
    # Candidate contraction indices: every name of `A` that also occurs in `B`.
    shared = filter(name -> Base.sym_in(name, namesB), namesA)
    if isempty(shared)
        error("no name in common between $namesA and $namesB")
    elseif length(shared) > 1
        error("no unique way to contract $namesA and $namesB")
    end
    # Repeated names inside `B` are not checked here; the 3-argument `mul`
    # methods reject matrices whose two dimensions share a name.
    return mul(A, B, first(shared))
end
const *ᵃ = mul
# Vector–vector product: both vectors must be indexed by the contraction name `s`.
function mul(
    x::NamedDimsArray{Lx,Tx,1}, y::NamedDimsArray{Ly,Ty,1}, s::Symbol,
) where {Lx,Tx,Ly,Ty}
    names_agree = s == Lx[1] && Lx[1] == Ly[1]
    names_agree || contract_error(x, y, s)
    # Non-numeric eltypes go through `permutedims` (which does not transpose the
    # elements themselves); the 1-element product is unwrapped with `first`.
    if !(Tx <: Number)
        return first(permutedims(x) * y)
    end
    return transpose(x) * y
end
# Matrix–vector product: orient `x` so that `s` is its trailing dimension, then multiply.
function mul(
    x::NamedDimsArray{Lx,Tx,2}, y::NamedDimsArray{Ly,Ty,1}, s::Symbol,
) where {Lx,Tx,Ly,Ty}
    # A matrix whose two dimensions share a name gives no unambiguous orientation.
    Lx[1] == Lx[2] && contract_error(x, y, s)
    if s == Ly[1]
        s == Lx[2] && return x * y
        s == Lx[1] && return transpose1(x) * y
    end
    return contract_error(x, y, s)
end
# Vector–matrix product: treat `x` as a row via `transpose1`, orient `y` so that `s`
# is its leading dimension, and transpose the product back.
# The double transpose is kept (rather than `transpose1(y) * x`) so the element
# products stay in `x`-then-`y` order.
function mul(
    x::NamedDimsArray{Lx,Tx,1}, y::NamedDimsArray{Ly,Ty,2}, s::Symbol,
) where {Lx,Tx,Ly,Ty}
    Ly[1] == Ly[2] && contract_error(x, y, s)
    s == Lx[1] || contract_error(x, y, s)
    row = transpose1(x)
    if s == Ly[1]
        return transpose1(row * y)
    elseif s == Ly[2]
        return transpose1(row * transpose1(y))
    end
    return contract_error(x, y, s)
end
# Matrix–matrix product: transpose each side as needed so that `s` is the trailing
# dimension of the left factor and the leading dimension of the right factor.
function mul(
    x::NamedDimsArray{Lx,Tx,2}, y::NamedDimsArray{Ly,Ty,2}, s::Symbol,
) where {Lx,Tx,Ly,Ty}
    # Each matrix needs two distinct names, otherwise the orientation is ambiguous.
    (Lx[1] == Lx[2] || Ly[1] == Ly[2]) && contract_error(x, y, s)
    if s == Lx[2]
        s == Ly[1] && return x * y
        s == Ly[2] && return x * transpose1(y)
    elseif s == Lx[1]
        s == Ly[1] && return transpose1(x) * y
        s == Ly[2] && return transpose1(x) * transpose1(y)
    end
    return contract_error(x, y, s)
end
"""
    contract_error(x, y, s)

Throw a `DimensionMismatch` reporting that index `s` cannot be contracted between
`x` and `y`. The message lists the names of both arrays (via `getnames`).
"""
function contract_error(x, y, s)
    # Report the names of BOTH operands; previously `x`'s names were shown twice.
    msg = "cannot contract index :$s between arrays with indices " *
          "$(getnames(x)) and $(getnames(y))"
    throw(DimensionMismatch(msg))
end
"""
    transpose1(x::AbstractArray)

Swap the dimensions of `x` without transposing its elements.

Arrays of numbers use the lazy `transpose`. Every other element type falls back to
`permutedims`, since `transpose` would recurse into the elements (and fails on
elements such as strings that have no `transpose` method).
"""
function transpose1 end

function transpose1(arr::AbstractArray{<:Number})
    return transpose(arr)
end

function transpose1(arr::AbstractArray)
    return permutedims(arr)
end
#################### PACKAGES ####################
"""
    contract(A, B) == A ⊙ᵃ B    # using TensorOperations
    contract(A, B, C)           # using OMEinsum

This generalises matrix multiplication to higher-dimensional arrays.
To contract two arrays, you must load the package `TensorOperations`.
To contract three or more arrays (which is not really matrix multiplication at all)
you must load the package `OMEinsum`.
Those packages do all the work, and `NamedPlus` just handles the names.

The infix form `⊙ᵃ` is a constant alias for `contract`, typed `\\odot\\^a<tab>`.
No methods are defined at this point; presumably they are added by code that is
loaded together with `TensorOperations` / `OMEinsum` (confirm where they live).
"""
function contract end
const ⊙ᵃ = contract
####################