-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathextension.jl
103 lines (82 loc) · 3.78 KB
/
extension.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
using StatsModels: collect_matrix_terms, MatrixTerm, Schema
# Plain fallback for `poly`: raise `x` to the power `n`.
# This is what a `poly(x, n)` call in a formula evaluates to when it is kept
# as an ordinary function call.
function poly(x, n)
    return x^n
end
# Marker type for a model context. The `apply_schema` method defined below for
# `FunctionTerm{typeof(poly)}` only applies when the context is a subtype of
# `PolyModel`.
abstract type PolyModel end
# Term holding the name of a data column (`term`) and a polynomial degree (`deg`).
struct PolyTerm <: AbstractTerm
    term::Symbol
    deg::Int
end

# Build a `PolyTerm` from schema-less terms: the column name is taken from the
# `Term`'s symbol and the degree from the `ConstantTerm`'s numeric value.
function PolyTerm(t::Term, deg::ConstantTerm)
    return PolyTerm(t.sym, deg.n)
end
# Term-level method of `poly`: called with a term instead of a number, it
# constructs the corresponding `PolyTerm` rather than computing a power.
function poly(t::AbstractTerm, deg)
    return PolyTerm(t, deg)
end
# In a `PolyModel` context, a captured `poly(...)` call is re-dispatched on its
# term arguments (yielding a `PolyTerm`) and the result is then passed through
# `apply_schema` with the same schema and context.
function StatsModels.apply_schema(
    t::FunctionTerm{typeof(poly)}, sch::Schema, Mod::Type{<:PolyModel}
)
    poly_term = poly(t.args...)
    return apply_schema(poly_term, sch, Mod)
end
# Model columns for a `PolyTerm`: the data column named `p.term` raised
# elementwise to each power 1, 2, ..., `p.deg`, concatenated horizontally.
function StatsModels.modelcols(p::PolyTerm, d::NamedTuple)
    column = d[p.term]
    powers = [column .^ n for n in 1:p.deg]
    return reduce(hcat, powers)
end
# Wrapper around another term that opts out of being a matrix term.
struct NonMatrixTerm{T} <: AbstractTerm
    term::T
end

# Trait method: report that `NonMatrixTerm`s are not matrix terms.
function StatsModels.is_matrix_term(::Type{<:NonMatrixTerm})
    return false
end
# Apply the schema to the wrapped term and re-wrap the result, so the
# `NonMatrixTerm` wrapper is still present after `apply_schema`.
function StatsModels.apply_schema(t::NonMatrixTerm, sch, Mod::Type)
    inner = apply_schema(t.term, sch, Mod)
    return NonMatrixTerm(inner)
end
# Model columns of a `NonMatrixTerm` are exactly those of the term it wraps.
function StatsModels.modelcols(t::NonMatrixTerm, d)
    return modelcols(t.term, d)
end
# Field-less term type with no `modelcols` method of its own; the "Fallback"
# testset uses it to check that `modelcols` throws an `ArgumentError`.
struct DummyTerm <: AbstractTerm end
@testset "Extended formula/models" begin
    # Shared fixture: two random columns (z, y) and one integer column (x).
    d = (z = rand(10), y = rand(10), x = collect(1:10))
    sch = schema(d)

    @testset "Poly term" begin
        frm = @formula(y ~ poly(x, 3))

        # Default context: `poly` stays an ordinary function call, so only x^3
        # is produced.
        plain = apply_schema(frm, sch)
        @test plain.rhs.terms[1] isa FunctionTerm
        # this works but == is not defined correctly and apply_schema creates a
        # new instance
        @test_broken plain == apply_schema(frm, sch, Nothing)
        @test last(modelcols(plain, d)) == hcat(d.x .^ 3)

        # PolyModel context: `poly` becomes a PolyTerm with columns x, x^2, x^3.
        special = apply_schema(frm, sch, PolyModel)
        @test special.rhs.terms[1] isa PolyTerm
        @test last(modelcols(special, d)) == hcat(d.x, d.x .^ 2, d.x .^ 3)
    end

    @testset "Non-matrix term" begin
        x_t, y_t, z_t = term(:x), term(:y), term(:z)

        all_matrix = @formula(z ~ x + y)
        y_held_out = z_t ~ x_t + NonMatrixTerm(y_t)
        x_held_out = z_t ~ NonMatrixTerm(x_t) + y_t
        none_matrix = z_t ~ NonMatrixTerm.(all_matrix.rhs)
        y_twice = z_t ~ x_t + NonMatrixTerm(y_t) + y_t

        # Before the schema is applied: check how the right-hand sides are
        # grouped by collect_matrix_terms.
        @test collect_matrix_terms(all_matrix.rhs) == MatrixTerm(x_t + y_t)
        @test collect_matrix_terms(y_held_out.rhs) ==
            (MatrixTerm((x_t,)), NonMatrixTerm(y_t))
        @test collect_matrix_terms(x_held_out.rhs) ==
            (MatrixTerm((y_t,)), NonMatrixTerm(x_t))
        @test collect_matrix_terms(none_matrix.rhs) == none_matrix.rhs
        @test collect_matrix_terms(y_twice.rhs) ==
            (MatrixTerm((x_t, y_t)), NonMatrixTerm(y_t))

        # After the schema is applied: same groupings, plus the generated columns.
        all_matrix_applied = apply_schema(all_matrix, sch)
        @test all_matrix_applied.rhs isa MatrixTerm
        @test all_matrix_applied.rhs == collect_matrix_terms(all_matrix_applied.rhs)
        @test modelcols(all_matrix_applied.rhs, d) == hcat(d.x, d.y)

        y_held_out_applied = apply_schema(y_held_out, sch)
        @test y_held_out_applied.rhs isa Tuple{MatrixTerm, NonMatrixTerm}
        @test y_held_out_applied.rhs ==
            apply_schema((MatrixTerm(x_t), NonMatrixTerm(y_t)), sch)
        @test modelcols(y_held_out_applied.rhs, d) == (hcat(d.x), d.y)

        # matrix term goes first
        x_held_out_applied = apply_schema(x_held_out, sch)
        @test x_held_out_applied.rhs isa Tuple{MatrixTerm, NonMatrixTerm}
        @test x_held_out_applied.rhs ==
            apply_schema((MatrixTerm(y_t), NonMatrixTerm(x_t)), sch)
        @test modelcols(x_held_out_applied.rhs, d) == (hcat(d.y), d.x)

        none_matrix_applied = apply_schema(none_matrix, sch)
        @test none_matrix_applied.rhs isa Tuple{NonMatrixTerm, NonMatrixTerm}
        @test none_matrix_applied.rhs ==
            apply_schema((NonMatrixTerm(x_t), NonMatrixTerm(y_t)), sch)
        @test modelcols(none_matrix_applied.rhs, d) == (d.x, d.y)

        # matrix terms are gathered
        y_twice_applied = apply_schema(y_twice, sch)
        @test y_twice_applied.rhs isa Tuple{MatrixTerm, NonMatrixTerm}
        @test y_twice_applied.rhs ==
            apply_schema((MatrixTerm((x_t, y_t)), NonMatrixTerm(y_t)), sch)
        @test modelcols(y_twice_applied.rhs, d) == (hcat(d.x, d.y), d.y)
    end

    @testset "Fallback" begin
        # DummyTerm defines no modelcols method of its own.
        @test_throws ArgumentError modelcols(DummyTerm(), (a=[1], ))
    end

    @testset "Ambiguity detection" begin
        # ambiguities are introduced by adding additional methods here
        @test_broken isempty(Test.detect_ambiguities(StatsModels))
    end
end