-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTests.cpp
135 lines (102 loc) · 3.81 KB
/
Tests.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include <cassert>
#include <cstdio>

#include "EGraph.h"
#include "TestLanguage.h"
#include "Serialization.h"
using namespace TestLanguage;
void rewriteIdentityRuleTest()
{
// given
e::Graph eGraph;
const auto abbc = makeExpression("(a * b) * (b + c)", eGraph);
const auto id1 = makeExpression("(a * b) * ((b + c) * 1)", eGraph);
const auto id2 = makeExpression("((a * 1) * b) * (b + (c * 1))", eGraph);
const auto id3 = makeExpression("((a * b) * (b + c)) * 1", eGraph);
const auto id4 = makeExpression("(((a * b) * (b + c)) * 1) * 1", eGraph);
const auto identityRule = makeRewriteRule("$x * 1 => $x");
// when
eGraph.restoreInvariants();
// then
assert(eGraph.find(id1) != eGraph.find(abbc));
// and when
eGraph.rewrite(identityRule);
// then
assert(eGraph.find(id1) == eGraph.find(abbc));
assert(eGraph.find(id2) == eGraph.find(abbc));
assert(eGraph.find(id3) == eGraph.find(abbc));
assert(eGraph.find(id4) == eGraph.find(abbc));
}
void rewriteZeroRuleTest()
{
// given
e::Graph eGraph;
const auto zeroTerm = eGraph.addTerm("0");
const auto zero1 = makeExpression("((a - b) + c) * ((b - c) * 0)", eGraph);
const auto zero2 = makeExpression("((a * (b + c)) * d) * 0", eGraph);
const auto zero3 = makeExpression("((a - b) * 0) * ((b + c) * 0)", eGraph);
const auto zeroRule = makeRewriteRule("$x * 0 => 0");
// when
eGraph.restoreInvariants();
// then
assert(eGraph.find(zero1) != eGraph.find(zeroTerm));
// and when
eGraph.rewrite(zeroRule);
eGraph.rewrite(zeroRule); // needs one more iteration to know that 0 * 0 == 0
// then
assert(eGraph.find(zero1) == eGraph.find(zeroTerm));
assert(eGraph.find(zero2) == eGraph.find(zeroTerm));
assert(eGraph.find(zero3) == eGraph.find(zeroTerm));
}
void rewriteAssociativityRuleTest()
{
// given
e::Graph eGraph;
const auto abc1 = makeExpression("(a + b) + c", eGraph);
const auto abc2 = makeExpression("a + (b + c)", eGraph);
const auto abcd1 = makeExpression("a + (b + (c + d))", eGraph);
const auto abcd2 = makeExpression("((a + b) + c) + d", eGraph);
const auto associativityRule = makeRewriteRule("($x + $y) + $z => $x + ($y + $z)");
// when
eGraph.rewrite(associativityRule);
// then
assert(eGraph.find(abc1) == eGraph.find(abc2));
assert(eGraph.find(abcd1) != eGraph.find(abcd2));
// and when
eGraph.rewrite(associativityRule); // needs one more iteration
// then
assert(eGraph.find(abcd1) == eGraph.find(abcd2));
assert(eGraph.find(abc1) != eGraph.find(abcd1));
}
/// Checks that one pass of the distributivity rule
/// `($x + $y) * $z => ($x * $z) + ($y * $z)` places three differently
/// bracketed expressions into a single e-class.
///
/// All three expressions have the same numeric value (48400). The baseline
/// check below therefore confirms they start in distinct e-classes, so the
/// final equalities can only come from the rewrite itself.
void rewriteDistributivityRuleTest()
{
    // given
    e::Graph eGraph;
    const auto expr1 = makeExpression("(10 + ((20 + 20) * 30)) * 40", eGraph);
    const auto expr2 = makeExpression("(10 * 40) + (((20 * 30) + (20 * 30)) * 40)", eGraph);
    const auto expr3 = makeExpression("(10 * 40) + (((20 + 20) * 30) * 40)", eGraph);
    const auto distributivityRule = makeRewriteRule("($x + $y) * $z => ($x * $z) + ($y * $z)");
    // when
    eGraph.restoreInvariants();
    // then: no equality is known before the rule is applied
    assert(eGraph.find(expr1) != eGraph.find(expr2));
    assert(eGraph.find(expr2) != eGraph.find(expr3));
    assert(eGraph.find(expr1) != eGraph.find(expr3));
    // and when
    eGraph.rewrite(distributivityRule);
    // then
    assert(eGraph.find(expr1) == eGraph.find(expr2));
    assert(eGraph.find(expr2) == eGraph.find(expr3));
}
/// Checks that serializing an e-graph and deserializing it again preserves
/// both what was proven equal (by the commutativity rule `$x + $y => $y + $x`)
/// and what was not. Without the "not equal" check, a deserializer that
/// collapsed every e-class into one would still pass.
void serializationTest()
{
    // given
    e::Graph eGraph;
    const auto expr1 = makeExpression("(10 + ((20 + 30) + 40)) + 50", eGraph);
    const auto expr2 = makeExpression("50 + ((40 + (30 + 20)) + 10)", eGraph);
    // A lone leaf with no '+' node, so commutativity can never merge it with expr1.
    const auto unrelated = eGraph.addTerm("60");
    const auto commutativityRule = makeRewriteRule("$x + $y => $y + $x");
    eGraph.restoreInvariants();
    // baseline: the two sums start out in different e-classes
    assert(eGraph.find(expr1) != eGraph.find(expr2));
    // when
    eGraph.rewrite(commutativityRule);
    const auto serializedData = e::serialize(eGraph);
    const auto otherGraph = e::deserialize(serializedData);
    // then: the equality proven before serialization survives the round trip...
    assert(otherGraph.find(expr1) == otherGraph.find(expr2));
    assert(otherGraph.find(expr1) == eGraph.find(expr1));
    // ...and so does the distinction between unrelated e-classes
    assert(otherGraph.find(unrelated) != otherGraph.find(expr1));
}
/// Test driver: runs every test in this file. Each test aborts via assert()
/// on the first failed check, so reaching the end means all checks passed.
/// That only holds when assertions are enabled: under NDEBUG every assert()
/// expands to nothing, so in that build the driver reports failure instead of
/// a misleading success.
int main()
{
    rewriteIdentityRuleTest();
    rewriteZeroRuleTest();
    rewriteAssociativityRuleTest();
    rewriteDistributivityRuleTest();
    serializationTest();
#ifdef NDEBUG
    std::fputs("Tests.cpp: assertions are disabled (NDEBUG); no checks were performed.\n", stderr);
    return 1;
#else
    std::puts("All tests passed.");
    return 0;
#endif
}