-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathsolver.cpp
185 lines (163 loc) · 5.28 KB
/
solver.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#include "solver.hpp"

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>
namespace {
// Board geometry: number of rows, number of columns, and number of candidate
// digits per cell.
constexpr int rows = 9;
constexpr int columns = 9;
constexpr int values = 9;

// Maps a (row, column, value) triple to its Minisat variable index.
// `value` is zero-based (digit - 1). The layout is row-major with `value`
// varying fastest, which matches the nested order in which
// Solver::init_variables calls newVar().
Minisat::Var toVar(int row, int column, int value) {
    assert(row >= 0 && row < rows && "Attempt to get var for nonexistant row");
    assert(column >= 0 && column < columns && "Attempt to get var for nonexistant column");
    assert(value >= 0 && value < values && "Attempt to get var for nonexistant value");
    return row * columns * values + column * values + value;
}

// Returns true when `b` has exactly `rows` rows of `columns` cells each and
// every cell holds 0 (skipped by Solver::apply_board) or a digit 1..values.
bool is_valid(board const& b) {
    // Compare in the unsigned size domain to avoid signed/unsigned mismatch.
    if (b.size() != static_cast<std::size_t>(rows)) {
        return false;
    }
    for (int row = 0; row < rows; ++row) {
        if (b[row].size() != static_cast<std::size_t>(columns)) {
            return false;
        }
        for (int col = 0; col < columns; ++col) {
            const auto value = b[row][col];
            // Upper bound comes from `values` rather than a literal 9 so the
            // check stays in sync with toVar's range.
            if (value < 0 || value > values) {
                return false;
            }
        }
    }
    return true;
}

// Writes one literal in DIMACS style: '-' when sign(lit) is set (Minisat's
// negation flag), then the 1-based variable number, then a space.
void log_var(Minisat::Lit lit) {
    if (sign(lit)) {
        std::clog << '-';
    }
    std::clog << var(lit) + 1 << ' ';
}

// Writes a whole clause as a DIMACS line terminated by "0".
void log_clause(Minisat::vec<Minisat::Lit> const& clause) {
    for (int i = 0; i < clause.size(); ++i) {
        log_var(clause[i]);
    }
    std::clog << "0\n";
}

// Writes a two-literal clause as a DIMACS line terminated by "0".
void log_clause(Minisat::Lit lhs, Minisat::Lit rhs) {
    log_var(lhs);
    log_var(rhs);
    std::clog << "0\n";
}
} //end anonymous namespace
void Solver::init_variables() {
if (m_write_dimacs) {
std::clog << "c (row, column, value) = variable\n";
}
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < columns; ++c) {
for (int v = 0; v < values; ++v) {
auto var = solver.newVar();
if (m_write_dimacs) {
std::clog << "c (" << r << ", " << c << ", " << v + 1 << ") = " << var + 1 << '\n';
}
}
}
}
std::clog << std::flush;
}
// Constrains exactly one of `literals` to be true, using the naive encoding:
//   * one "at least one" clause containing every literal, and
//   * one "at most one" clause (~a | ~b) for every unordered pair (a, b).
// For n literals this adds 1 + n*(n-1)/2 clauses. Every clause is also echoed
// in DIMACS form to std::clog when m_write_dimacs is set.
void Solver::exactly_one_true(Minisat::vec<Minisat::Lit> const& literals) {
    if (m_write_dimacs) {
        log_clause(literals);
    }
    solver.addClause(literals);

    // Minisat::vec::size() is a signed int; use int indices (as log_clause
    // does) to avoid a signed/unsigned comparison.
    const int count = literals.size();
    for (int i = 0; i < count; ++i) {
        for (int j = i + 1; j < count; ++j) {
            if (m_write_dimacs) {
                log_clause(~literals[i], ~literals[j]);
            }
            solver.addClause(~literals[i], ~literals[j]);
        }
    }
}
void Solver::one_square_one_value() {
for (int row = 0; row < rows; ++row) {
for (int column = 0; column < columns; ++column) {
Minisat::vec<Minisat::Lit> literals;
for (int value = 0; value < values; ++value) {
literals.push(Minisat::mkLit(toVar(row, column, value)));
}
exactly_one_true(literals);
}
}
}
void Solver::non_duplicated_values() {
// In each row, for each value, forbid two column sharing that value
for (int row = 0; row < rows; ++row) {
for (int value = 0; value < values; ++value) {
Minisat::vec<Minisat::Lit> literals;
for (int column = 0; column < columns; ++column) {
literals.push(Minisat::mkLit(toVar(row, column, value)));
}
exactly_one_true(literals);
}
}
// In each column, for each value, forbid two rows sharing that value
for (int column = 0; column < columns; ++column) {
for (int value = 0; value < values; ++value) {
Minisat::vec<Minisat::Lit> literals;
for (int row = 0; row < rows; ++row) {
literals.push(Minisat::mkLit(toVar(row, column, value)));
}
exactly_one_true(literals);
}
}
// Now forbid sharing in the 3x3 boxes
for (int r = 0; r < 9; r += 3) {
for (int c = 0; c < 9; c += 3) {
for (int value = 0; value < values; ++value) {
Minisat::vec<Minisat::Lit> literals;
for (int rr = 0; rr < 3; ++rr) {
for (int cc = 0; cc < 3; ++cc) {
literals.push(Minisat::mkLit(toVar(r + rr, c + cc, value)));
}
}
exactly_one_true(literals);
}
}
}
}
Solver::Solver(bool write_dimacs):
m_write_dimacs(write_dimacs) {
// Initialize the board
init_variables();
one_square_one_value();
non_duplicated_values();
}
// Adds the pre-filled digits of `b` as unit clauses. A cell holding 0 adds no
// clause; a cell holding d (1..values) forces toVar(row, col, d - 1) true.
// Clauses are added permanently to the underlying solver.
// Returns false if `b` is malformed (see is_valid; also asserted in debug
// builds) or if any solver.addClause call returned false. Otherwise it
// returns true.
// NOTE(review): unlike the constructor's clauses, these unit clauses are not
// echoed to std::clog when m_write_dimacs is set. Confirm that this is intended.
bool Solver::apply_board(board const& b) {
    const bool valid = is_valid(b);
    assert(valid && "Provided board is not valid!");
    // With NDEBUG the assert vanishes. Reject malformed input explicitly
    // instead of indexing out of bounds or building out-of-range variables.
    if (!valid) {
        return false;
    }

    bool ret = true;
    for (int row = 0; row < rows; ++row) {
        for (int col = 0; col < columns; ++col) {
            const auto value = b[row][col];
            if (value == 0) {
                continue;
            }
            // Always attempt the add; remember any failure.
            if (!solver.addClause(Minisat::mkLit(toVar(row, col, value - 1)))) {
                ret = false;
            }
        }
    }
    return ret;
}
// Runs the underlying Minisat solver on the clauses added so far and forwards
// its result unchanged.
bool Solver::solve() {
    const bool satisfiable = solver.solve();
    return satisfiable;
}
// Decodes the solver's model into a rows x columns board of digits 1..values.
// Each cell gets the digit whose variable modelValue(...) reports true.
// Debug builds assert that exactly one value is true per cell.
// NOTE(review): this reads the model left by the last solve(). It presumably
// must only be called after solve() returned true; confirm with callers.
board Solver::get_solution() const {
    board b(rows, std::vector<int>(columns));
    for (int row = 0; row < rows; ++row) {
        for (int col = 0; col < columns; ++col) {
            int found = 0;
            for (int val = 0; val < values; ++val) {
                if (solver.modelValue(toVar(row, col, val)).isTrue()) {
                    ++found;
                    b[row][col] = val + 1;
                }
            }
            // Fires on zero as well as on multiple assigned values, so the
            // message must cover both cases.
            assert(found == 1 && "The SAT solver did not assign exactly one value to a position");
            (void)found;  // silence unused-variable warning when NDEBUG is set
        }
    }
    return b;
}