-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathparallel.jl
121 lines (98 loc) · 3.07 KB
/
parallel.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------
"""
    Parallel(transforms)

A transform where `transforms` are applied in parallel.

Every transform receives the same input table and the resulting tables
are concatenated column-wise; column names that clash with names produced
by an earlier transform are disambiguated by appending `_`. Reverting a
`Parallel` transform reverts only the first revertible transform in
`transforms`.
"""
struct Parallel <: Transform
    # transforms applied side by side to the same input table, in order
    transforms::Vector{Transform}
end
# A `Parallel` transform can be reverted as soon as one of its transforms
# can, since `revert` only undoes the first revertible one.
function isrevertible(p::Parallel)
    for transform in p.transforms
        isrevertible(transform) && return true
    end
    return false
end
# Apply every transform in `p` to `table` in parallel and concatenate the
# resulting tables column-wise with `tablehcat`.
#
# Returns `(newtable, cache)` with `cache = (onames, caches, rinfo)`:
# - `onames`: column names of the input `table` (kept in slot 1 so that
#   `caches` stays in slot 2, which is where `reapply` reads it)
# - `caches`: the cache returned by each transform, in order
# - `rinfo`: `nothing` when no transform is revertible, otherwise
#   `(ind, rrange, rnames)` holding the index of the first revertible
#   transform, the range of its columns inside `newtable`, and the column
#   names it produced *before* concatenation (which may differ from the
#   names in `newtable` because `tablehcat` can append `_` suffixes)
function apply(p::Parallel, table)
    # apply transforms in parallel
    f(transform) = apply(transform, table)
    vals = tcollect(f(t) for t in p.transforms)

    # retrieve tables and caches
    tables = first.(vals)
    caches = last.(vals)

    # table with concatenated columns
    newtable = tablehcat(tables)

    # save original column names
    onames = Tables.columnnames(table)

    # find first revertible transform
    ind = findfirst(isrevertible, p.transforms)

    # save info to revert transform
    rinfo = if isnothing(ind)
        nothing
    else
        tnames = Tables.columnnames.(tables)
        ncols = length.(tnames)
        # columns of earlier transforms come first in `newtable`
        start = sum(ncols[1:ind-1]) + 1
        finish = start + ncols[ind] - 1
        # keep the transform's own output names: revert must hand it back
        # a table that looks exactly like the one it produced
        (ind, start:finish, tnames[ind])
    end

    newtable, (onames, caches, rinfo)
end

# Revert a `Parallel` transform by reverting only the first revertible
# transform, using the columns it contributed to `newtable`.
#
# The subtable is rebuilt with the column names that transform originally
# produced (stored by `apply`), not with the input table's names: its
# output may have a different number of columns or different names, and
# `tablehcat` may have renamed clashing columns with `_` suffixes.
function revert(p::Parallel, newtable, cache)
    # retrieve cache
    caches = cache[2]
    rinfo = cache[3]

    @assert !isnothing(rinfo) "transform is not revertible"

    # retrieve info to revert transform
    ind, rrange, rnames = rinfo
    rtrans = p.transforms[ind]
    rcache = caches[ind]

    # columns of transformed table
    cols = Tables.columns(newtable)

    # rebuild the subtable exactly as `rtrans` produced it
    rcols = [Tables.getcolumn(cols, j) for j in rrange]
    𝒯 = (; zip(rnames, rcols)...)
    rtable = 𝒯 |> Tables.materializer(newtable)

    # revert transform on subtable
    revert(rtrans, rtable, rcache)
end
# Reapply every transform of `p` to `table`, each with the cache it
# produced during `apply`, in parallel, and concatenate the outputs
# column-wise with `tablehcat`.
function reapply(p::Parallel, table, cache)
    # per-transform caches are stored in the second slot of the cache
    tcaches = cache[2]

    # pair each transform with its own cache and reapply in parallel
    outputs = tcollect(
        reapply(transform, table, tcache)
        for (transform, tcache) in zip(p.transforms, tcaches)
    )

    # single table holding the columns of all outputs
    tablehcat(outputs)
end
# Concatenate the columns of `tables` side by side, in order.
#
# A column whose name was already emitted gets `_` appended repeatedly
# until it is unique. The result is materialized with the table type of
# the first element of `tables`.
function tablehcat(tables)
    outnames = Symbol[]
    outcols = []
    # every name emitted so far, used to detect clashes
    seen = Set{Symbol}()

    for tab in tables
        tabcols = Tables.columns(tab)
        for name in Tables.columnnames(tab)
            col = Tables.getcolumn(tabcols, name)

            uniquename = name
            while uniquename ∈ seen
                uniquename = Symbol(uniquename, :_)
            end

            push!(seen, uniquename)
            push!(outnames, uniquename)
            push!(outcols, col)
        end
    end

    # named tuple of columns, converted to the first table's type
    nt = (; zip(outnames, outcols)...)
    nt |> Tables.materializer(first(tables))
end
"""
    transform₁ ⊔ transform₂ ⊔ ⋯ ⊔ transformₙ

Create a [`Parallel`](@ref) transform with
`[transform₁, transform₂, …, transformₙ]`.

Operands that are already `Parallel` transforms are flattened: their
inner transforms are spliced into the new list instead of being nested.
"""
function ⊔(t1::Transform, t2::Transform)
    return Parallel([t1, t2])
end

function ⊔(t1::Transform, t2::Parallel)
    return Parallel(vcat(t1, t2.transforms))
end

function ⊔(t1::Parallel, t2::Transform)
    return Parallel(vcat(t1.transforms, t2))
end

function ⊔(t1::Parallel, t2::Parallel)
    return Parallel(vcat(t1.transforms, t2.transforms))
end