-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCompletion_search_space_efficient.cpp
executable file
·105 lines (97 loc) · 2.59 KB
/
Completion_search_space_efficient.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// Copyright (c) 2013 Elements of Programming Interviews. All rights reserved.
#include <iostream>
#include <iterator>
#include <cmath>
#include <limits>
#include <numeric>
#include <algorithm>
#include <ctime>
#include <cstdlib>
#include <vector>
#include <cassert>
using namespace std;
// @include
// Finds the salary cap c for which the capped payroll, sum(min(A[i], c)),
// equals budget.
//
// A      - current salaries; sorted in place (ascending) as a side effect.
// budget - target total payroll.
//
// Returns the cap, or -1.0 when none is found: A is empty, or budget is larger
// than the sum of all salaries so that no salary would need to be cut.
double completion_search(std::vector<double> &A, double budget) {
  // An empty payroll has no cap. Bail out before A.front() and the
  // size-based arithmetic below, both of which are invalid on an empty vector.
  if (A.empty()) {
    return -1.0;
  }
  std::sort(A.begin(), A.end());
  const std::size_t n = A.size();

  // The cap is at or below the smallest salary, so everyone is paid the same
  // amount. Using <= (rather than <) also covers a single salary that exactly
  // matches the budget, a case the scan below never visits.
  if (budget / n <= A.front()) {
    return budget / n;
  }

  // Pay the i + 1 smallest salaries in full and split what is left evenly
  // among the remaining ones; the split is the answer once it falls between
  // A[i] and A[i + 1].
  double remainder = budget;
  for (std::size_t i = 0; i + 1 < n; ++i) {
    remainder -= A[i];
    if (remainder < 0.0) {
      // The salaries paid in full so far already exceed the budget.
      return -1.0;
    }
    const double cutoff = remainder / (n - i - 1);
    if (A[i] <= cutoff && cutoff <= A[i + 1]) {
      return cutoff;
    }
  }
  return -1.0;
}
// @exclude
// Reference solver that main() uses to cross-check completion_search.
//
// Sorts A in place, tabulates the total payroll obtained when the cap is set
// to each salary in turn, and then interpolates between the two neighbouring
// table entries that bracket budget. Returns -1.0 when budget is larger than
// every table entry (which includes the case of an empty A).
double check_answer(std::vector<double> &A, double budget) {
  std::sort(A.begin(), A.end());
  const std::size_t count = A.size();

  // running_total[k] == A[0] + ... + A[k]
  std::vector<double> running_total(count);
  std::partial_sum(A.cbegin(), A.cend(), running_total.begin());

  // payroll_at[k] is the total payroll when the cap is A[k]: the k + 1
  // smallest salaries are paid in full, the rest are paid A[k] each. The
  // table is non-decreasing because A is sorted.
  std::vector<double> payroll_at;
  payroll_at.reserve(count);
  for (std::size_t k = 0; k < count; ++k) {
    payroll_at.push_back(running_total[k] + (count - k - 1) * A[k]);
  }

  // Index of the first table entry that is not below the budget.
  const std::size_t pos = static_cast<std::size_t>(
      std::lower_bound(payroll_at.cbegin(), payroll_at.cend(), budget) -
      payroll_at.cbegin());

  if (pos == count) {
    return -1.0;  // budget is too large: no cap is needed
  }
  if (pos == 0) {
    return budget / count;  // cap is at or below the smallest salary
  }

  // The cap lies between A[below] and A[below + 1]; spread the budget left
  // over after capping at A[below] across the salaries above it.
  const std::size_t below = pos - 1;
  return A[below] + (budget - payroll_at[below]) / (count - below - 1);
}
// Randomized stress test: runs completion_search against check_answer on
// 10000 inputs and verifies that the returned cap spends the budget.
//
// Usage:
//   prog            random size (1..1000) and random budget
//   prog n          fixed size n, random budget
//   prog n budget   fixed size and budget
//
// Returns 0 on success, 1 if n is not positive; a failed check aborts via
// assert.
int main(int argc, char *argv[]) {
  // srand(time(nullptr));
  for (int times = 0; times < 10000; ++times) {
    int n;
    double tar;
    if (argc == 2) {
      n = std::atoi(argv[1]);
      tar = std::rand() % 100000;
    } else if (argc == 3) {
      n = std::atoi(argv[1]);
      tar = std::atoi(argv[2]);
    } else {
      n = 1 + std::rand() % 1000;
      tar = std::rand() % 100000;
    }
    // A non-positive (or unparsable) size would produce an empty salary list,
    // for which there is nothing to test.
    if (n <= 0) {
      std::cerr << "n must be a positive integer" << std::endl;
      return 1;
    }

    std::vector<double> A;
    A.reserve(n);
    for (int i = 0; i < n; ++i) {
      A.emplace_back(std::rand() % 10000);
    }
    /*
    cout << "A = ";
    copy(A.begin(), A.end(), ostream_iterator<double>(cout, " "));
    cout << endl;
    cout << "tar = " << tar << endl;
    //*/
    const double ret = completion_search(A, tar);
    const double ret2 = check_answer(A, tar);
    // endl (flush) is deliberate: output must be visible if an assert below
    // aborts the process.
    std::cout << ret << " " << ret2 << std::endl;
    assert(std::fabs(ret - ret2) <= 1.0e-10);
    if (ret != -1) {
      std::cout << "ret = " << ret << std::endl;
      // Recompute the payroll under the returned cap.
      double sum = 0.0;
      for (const double salary : A) {
        sum += std::min(salary, ret);
      }
      tar -= sum;
      std::cout << "sum = " << sum << std::endl;
      // Two-sided check: the capped payroll must neither undershoot nor
      // overshoot the budget.
      assert(std::fabs(tar) < 1.0e-8);
    }
  }
  return 0;
}