-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsolve.py
137 lines (116 loc) · 4.56 KB
/
solve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from pprint import pprint
from statistics import mean
from typing import List
from ortools.sat.python import cp_model
from constants import FOOD_OFFSET
from data.download_data import download_data_if_needed
from solver.find_n_greatest import find_max_error
from solver.initialize import initialize
from solver.load_data import load_data
from solver.utils import get_arg_parser
class VarArraySolutionPrinter(cp_model.CpSolverSolutionCallback):
    """CP-SAT solution callback that records which foods a solution uses.

    Each time the solver reports a solution, the callback stores a sorted
    tuple of the integer IDs (parsed from the variable names) of every
    variable whose value is non-zero.

    By default the search is stopped after the first reported solution.
    The recorded tuple is therefore the first feasible solution, not
    necessarily an optimal one. Pass ``stop_on_first_solution=False`` to let
    the solver continue; the callback then retains the most recently
    reported solution.
    """

    def __init__(self, variables, error_for_quantity, stop_on_first_solution=True):
        """
        :param variables: CP-SAT variables, one per food, whose names must
            parse as ints (the food IDs).
        :param error_for_quantity: Per-nutrient error variables. Stored for
            the caller's convenience; not read by this callback.
        :param stop_on_first_solution: When True (the default, matching the
            original behaviour), call ``StopSearch()`` as soon as the first
            solution is reported.
        """
        super().__init__()
        self.__variables = variables
        self.__error_for_quantity = error_for_quantity
        self.__stop_on_first_solution = stop_on_first_solution
        self.__solution = None

    def get_solution(self):
        """Return the last recorded tuple of food IDs, or None if no solution was reported."""
        return self.__solution

    def on_solution_callback(self):
        # Just the ordered IDs of the foods in the solution.
        self.__solution = tuple(
            sorted(int(v.Name()) for v in self.__variables if self.Value(v) != 0)
        )
        if self.__stop_on_first_solution:
            self.StopSearch()
def print_info(status, solver, solution_printer):
    """Print a short summary of a CP-SAT solve.

    For OPTIMAL/FEASIBLE outcomes, prints solver statistics and the solution
    recorded by ``solution_printer``. For any other outcome, prints only the
    status name.

    :param status: Status code returned by ``solver.Solve``.
    :param solver: The ``CpSolver`` that produced ``status``.
    :param solution_printer: Callback object exposing ``get_solution()``.
    """
    if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        # StatusName knows every status code; a hand-maintained table indexed
        # by the raw integer would raise IndexError on an unexpected code.
        print(solver.StatusName(status))
        return
    print("\nStatistics")
    print(f" status : {solver.StatusName(status)}")
    print(f" conflicts: {solver.NumConflicts()}")
    print(f" branches : {solver.NumBranches()}")
    print(f" wall time: {solver.WallTime()} s")
    print(f" sol found: {solution_printer.get_solution()}")
def solve_it(
    foods,
    max_qty: List[int],
    min_requirements: List[int],
    max_requirements: List[int],
    num_foods: int,
    log_level: int = 0,
):
    """Choose ``num_foods`` foods and quantities that satisfy nutrient bounds.

    Builds a CP-SAT model with, for each food, a quantity in
    ``[0, max_qty[i]]`` and a 0/1 "use" flag; the effective quantity is
    their product.

    Constraints:
    - exactly ``num_foods`` foods are used;
    - each nutrient total lies in
      ``[min_requirements[i], max_requirements[i]]``.

    The objective minimises the summed excess of each nutrient total over
    its minimum. With ``VarArraySolutionPrinter``'s default settings the
    search stops at the first solution reported, so the returned solution
    is not necessarily optimal.

    :param foods: A list of food rows. ``food[0]`` is the food ID (its
        ``str`` form must parse back as an int). ``food[i + FOOD_OFFSET]``
        is the amount of nutrient ``i`` per unit of quantity.
    :param max_qty: Upper bound on the quantity of each food, aligned with
        ``foods``.
    :param min_requirements: Lower bound of each nutritional requirement.
    :param max_requirements: Upper bound of each nutritional requirement,
        aligned with ``min_requirements``.
    :param num_foods: Restrict the solution to use exactly this many foods.
    :param log_level: 0 = No logging, 1 = Log solution status,
        2 = Log solution status and solver progress.
    :return: A sorted tuple of the IDs of the foods in the recorded
        solution, or ``None`` if the solver found no solution.
    """
    model = cp_model.CpModel()
    # If the data file and nutritional requirements are static,
    # then food_max_value and max_error could be cached somewhere.
    max_error = find_max_error(foods, max_qty, num_foods, min_requirements)

    quantity_of_food = [
        model.NewIntVar(0, max_qty[i], name=str(food[0]))
        for i, food in enumerate(foods)
    ]
    # Effective quantity = quantity * use-flag. These are the variables the
    # solution callback inspects, so their names must be the food IDs.
    intermediate_values = [
        model.NewIntVar(0, max_qty[i], name=str(food[0]))
        for i, food in enumerate(foods)
    ]
    # The upper bound max_error[i] also caps how far nutrient i may exceed
    # its minimum (presumably a precomputed worst case — see find_max_error).
    error_for_quantity = [
        model.NewIntVar(0, max_error[i], f"Error {nutrient}")
        for i, nutrient in enumerate(min_requirements)
    ]
    should_use_food = [model.NewBoolVar(str(food[0])) for food in foods]

    for effective, quantity, use in zip(
        intermediate_values, quantity_of_food, should_use_food
    ):
        model.AddMultiplicationEquality(effective, quantity, use)
    model.Add(sum(should_use_food) == num_foods)

    for i, lower in enumerate(min_requirements):
        nutrient_intake = sum(
            food[i + FOOD_OFFSET] * intermediate_values[j]
            for j, food in enumerate(foods)
        )
        # min_requirements[i] <= nutrient_intake <= max_requirements[i]
        model.AddLinearConstraint(nutrient_intake, lower, max_requirements[i])
        # Traditional absolute-value error metric. Because intake >= lower is
        # already enforced, this equals intake - lower.
        model.AddAbsEquality(
            target=error_for_quantity[i], expr=nutrient_intake - lower
        )
    model.Minimize(sum(error_for_quantity))

    solver = cp_model.CpSolver()
    solver.parameters.log_search_progress = log_level >= 2
    solver.parameters.enumerate_all_solutions = True
    solution_printer = VarArraySolutionPrinter(intermediate_values, error_for_quantity)
    status = solver.Solve(model, solution_printer)

    if log_level >= 1:
        print_info(status, solver, solution_printer)
    if status not in [cp_model.OPTIMAL, cp_model.FEASIBLE]:
        return None
    solution = solution_printer.get_solution()
    if log_level:
        print(solution)
    return solution
if __name__ == "__main__":
    # initialize() returns the inputs in the positional order solve_it
    # expects, followed by the verbosity level used as log_level.
    (
        foods,
        max_foods,
        min_requirements,
        max_requirements,
        verbose,
    ) = initialize()
    chosen_foods = solve_it(
        foods,
        max_foods,
        min_requirements,
        max_requirements,
        num_foods=7,
        log_level=verbose,
    )
    print(chosen_foods)