# Copyright 2017 The 'Scalable Private Learning with PATE' Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for smooth sensitivity analysis for PATE mechanisms.
This library implements functionality for doing smooth sensitivity analysis
for Gaussian Noise Max (GNMax), Threshold with Gaussian noise, and Gaussian
Noise with Smooth Sensitivity (GNSS) mechanisms.
"""
import math

import numpy as np
import scipy.optimize
import scipy.special
import scipy.stats

import pate_core as pate
################################
# SMOOTH SENSITIVITY FOR GNMAX #
################################
# Global dictionary for storing cached q0 values keyed by (sigma, order).
_logq0_cache = {}
def _compute_logq0(sigma, order):
key = (sigma, order)
if key in _logq0_cache:
return _logq0_cache[key]
logq0 = compute_logq0_gnmax(sigma, order)
  _logq0_cache[key] = logq0  # Update the global cache.
return logq0
def _compute_logq1(sigma, order, num_classes):
logq0 = _compute_logq0(sigma, order) # Most likely already cached.
logq1 = math.log(_compute_bl_gnmax(math.exp(logq0), sigma, num_classes))
assert logq1 <= logq0
return logq1
def _compute_mu1_mu2_gnmax(sigma, logq):
# Computes mu1, mu2 according to Proposition 10.
mu2 = sigma * math.sqrt(-logq)
mu1 = mu2 + 1
return mu1, mu2
def _compute_data_dep_bound_gnmax(sigma, logq, order):
  # Applies Theorem 6 from the Appendix without checking that logq satisfies
  # the necessary constraints. The caller must ensure the preconditions hold
  # by comparing logq against logq0.
variance = sigma**2
mu1, mu2 = _compute_mu1_mu2_gnmax(sigma, logq)
eps1 = mu1 / variance
eps2 = mu2 / variance
log1q = np.log1p(-math.exp(logq)) # log1q = log(1-q)
log_a = (order - 1) * (
log1q - (np.log1p(-math.exp((logq + eps2) * (1 - 1 / mu2)))))
log_b = (order - 1) * (eps1 - logq / (mu1 - 1))
return np.logaddexp(log1q + log_a, logq + log_b) / (order - 1)
def _compute_rdp_gnmax(sigma, logq, order):
logq0 = _compute_logq0(sigma, order)
if logq >= logq0:
return pate.rdp_data_independent_gaussian(sigma, order)
else:
return _compute_data_dep_bound_gnmax(sigma, logq, order)
def compute_logq0_gnmax(sigma, order):
"""Computes the point where we start using data-independent bounds.
Args:
sigma: std of the Gaussian noise
order: Renyi order lambda
Returns:
logq0: the point above which the data-ind bound overtakes data-dependent
bound.
"""
  def _check_validity_conditions(logq):
    # Returns True iff logq is in the range where the data-dependent bound is
    # valid (Theorem 6 in the Appendix).
mu1, mu2 = _compute_mu1_mu2_gnmax(sigma, logq)
if mu1 < order:
return False
eps2 = mu2 / sigma**2
# Do computation in the log space. The condition below comes from Lemma 9
# from Appendix.
    return logq <= (mu2 - 1) * eps2 - mu2 * math.log(
        mu1 / (mu1 - 1) * mu2 / (mu2 - 1))
def _compare_dep_vs_ind(logq):
return (_compute_data_dep_bound_gnmax(sigma, logq, order) -
pate.rdp_data_independent_gaussian(sigma, order))
# Natural upper bounds on q0.
logub = min(-(1 + 1. / sigma)**2, -((order - .99) / sigma)**2, -1 / sigma**2)
assert _check_validity_conditions(logub)
  # If the data-dependent bound is already better at the upper end, we are
  # done.
  if _compare_dep_vs_ind(logub) < 0:
    return logub

  # Find a lower bound that brackets logq0.
loglb = 2 * logub # logub is negative, and thus loglb < logub.
while _compare_dep_vs_ind(loglb) > 0:
assert loglb > -10000, "The lower bound on q0 is way too low."
loglb *= 1.5
logq0, r = scipy.optimize.brentq(
_compare_dep_vs_ind, loglb, logub, full_output=True)
assert r.converged, "The root finding procedure failed to converge."
assert _check_validity_conditions(logq0) # just in case.
return logq0
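

# Illustrative sketch, not part of the original module: logq0 is the crossover
# point between the two RDP bounds. Below logq0 the data-dependent bound of
# Theorem 6 is the tighter one; at or above it, _compute_rdp_gnmax falls back
# to the data-independent Gaussian bound. The parameter values below are
# made-up assumptions for illustration only.
def _demo_logq0_crossover(sigma=40., order=20.):
  logq0 = compute_logq0_gnmax(sigma, order)
  ind = pate.rdp_data_independent_gaussian(sigma, order)
  dep = _compute_data_dep_bound_gnmax(sigma, logq0 - 1., order)
  print("logq0 = %g, data-independent RDP = %g, data-dependent RDP just "
        "below the crossover = %g" % (logq0, ind, dep))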
def _compute_bl_gnmax(q, sigma, num_classes):
return ((num_classes - 1) / 2 * scipy.special.erfc(
1 / sigma + scipy.special.erfcinv(2 * q / (num_classes - 1))))
def _compute_bu_gnmax(q, sigma, num_classes):
return min(1, (num_classes - 1) / 2 * scipy.special.erfc(
-1 / sigma + scipy.special.erfcinv(2 * q / (num_classes - 1))))
def _compute_local_sens_gnmax(logq, sigma, num_classes, order):
"""Implements Algorithm 3 (computes an upper bound on local sensitivity).
(See Proposition 13 for proof of correctness.)
"""
logq0 = _compute_logq0(sigma, order)
logq1 = _compute_logq1(sigma, order, num_classes)
if logq1 <= logq <= logq0:
logq = logq1
beta = _compute_rdp_gnmax(sigma, logq, order)
  logq_bu = math.log(_compute_bu_gnmax(math.exp(logq), sigma, num_classes))
  beta_bu_q = _compute_rdp_gnmax(sigma, logq_bu, order)
  logq_bl = math.log(_compute_bl_gnmax(math.exp(logq), sigma, num_classes))
  beta_bl_q = _compute_rdp_gnmax(sigma, logq_bl, order)
return max(beta_bu_q - beta, beta - beta_bl_q)
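

# Hypothetical illustration (not in the original file) of Algorithm 3 above:
# the local sensitivity at q is the larger of the RDP gaps to the neighboring
# values BU(q) and BL(q). All parameter values below are made-up assumptions.
def _demo_local_sens_at_q(logq=-10., sigma=40., num_classes=10, order=20.):
  ls = _compute_local_sens_gnmax(logq, sigma, num_classes, order)
  print("local sensitivity bound at logq = %g: %g" % (logq, ls))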
def compute_local_sensitivity_bounds_gnmax(votes, num_teachers, sigma, order):
"""Computes a list of max-LS-at-distance-d for the GNMax mechanism.
A more efficient implementation of Algorithms 4 and 5 working in time
O(teachers*classes). A naive implementation is O(teachers^2*classes) or worse.
Args:
votes: A numpy array of votes.
num_teachers: Total number of voting teachers.
sigma: Standard deviation of the Guassian noise.
order: The Renyi order.
Returns:
A numpy array of local sensitivities at distances d, 0 <= d <= num_teachers.
"""
num_classes = len(votes) # Called m in the paper.
logq0 = _compute_logq0(sigma, order)
logq1 = _compute_logq1(sigma, order, num_classes)
logq = pate.compute_logq_gaussian(votes, sigma)
plateau = _compute_local_sens_gnmax(logq1, sigma, num_classes, order)
res = np.full(num_teachers, plateau)
if logq1 <= logq <= logq0:
return res
  # Invariant: votes is sorted in non-increasing order.
votes = sorted(votes, reverse=True)
res[0] = _compute_local_sens_gnmax(logq, sigma, num_classes, order)
curr_d = 0
go_left = logq > logq0 # Otherwise logq < logq1 and we go right.
# Iterate while the following is true:
# 1. If we are going left, logq is still larger than logq0 and we may still
# increase the gap between votes[0] and votes[1].
# 2. If we are going right, logq is still smaller than logq1.
  while ((go_left and logq > logq0 and votes[1] > 0) or
         (not go_left and logq < logq1)):
curr_d += 1
if go_left: # Try decreasing logq.
votes[0] += 1
votes[1] -= 1
idx = 1
# Restore the invariant. (Can be implemented more efficiently by keeping
# track of the range of indices equal to votes[1]. Does not seem to matter
# for the overall running time.)
while idx < len(votes) - 1 and votes[idx] < votes[idx + 1]:
votes[idx], votes[idx + 1] = votes[idx + 1], votes[idx]
idx += 1
else: # Go right, i.e., try increasing logq.
votes[0] -= 1
votes[1] += 1 # The invariant holds since otherwise logq >= logq1.
logq = pate.compute_logq_gaussian(votes, sigma)
res[curr_d] = _compute_local_sens_gnmax(logq, sigma, num_classes, order)
return res
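

# A minimal usage sketch, not part of the original code. The vote histogram
# and the parameters are invented; in practice, votes comes from aggregating
# the teachers' predictions for a single query.
def _demo_ls_bounds_gnmax():
  votes = np.array([500, 300, 200])  # hypothetical 3-class vote histogram
  ls = compute_local_sensitivity_bounds_gnmax(
      votes, num_teachers=1000, sigma=40., order=20.)
  print("GNMax LS at distances 0..4:", ls[:5])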
##################################################
# SMOOTH SENSITIVITY FOR THE THRESHOLD MECHANISM #
##################################################
# A global dictionary of RDPs for various threshold values. Indexed by a 4-tuple
# (num_teachers, threshold, sigma, order).
_rdp_thresholds = {}
def _compute_rdp_list_threshold(num_teachers, threshold, sigma, order):
key = (num_teachers, threshold, sigma, order)
if key in _rdp_thresholds:
return _rdp_thresholds[key]
res = np.zeros(num_teachers + 1)
for v in range(0, num_teachers + 1):
logp = scipy.stats.norm.logsf(threshold - v, scale=sigma)
res[v] = pate.compute_rdp_threshold(logp, sigma, order)
_rdp_thresholds[key] = res
return res
def compute_local_sensitivity_bounds_threshold(counts, num_teachers, threshold,
                                               sigma, order):
"""Computes a list of max-LS-at-distance-d for the threshold mechanism."""
def _compute_ls(v):
ls_step_up, ls_step_down = float("-inf"), float("-inf")
if v > 0:
ls_step_down = abs(rdp_list[v - 1] - rdp_list[v])
if v < num_teachers:
ls_step_up = abs(rdp_list[v + 1] - rdp_list[v])
    return max(ls_step_down, ls_step_up)  # max(-inf, x) = x covers the ends.
cur_max = int(round(max(counts)))
rdp_list = _compute_rdp_list_threshold(num_teachers, threshold, sigma, order)
ls = np.zeros(num_teachers)
for d in range(max(cur_max, num_teachers - cur_max)):
ls_up, ls_down = float("-inf"), float("-inf")
if cur_max + d <= num_teachers:
ls_up = _compute_ls(cur_max + d)
if cur_max - d >= 0:
ls_down = _compute_ls(cur_max - d)
ls[d] = max(ls_up, ls_down)
  # Note: the smooth sensitivity definition takes the max over all datasets
  # within distance *at most* d, so ls should arguably be made non-decreasing
  # in d (this would also cover the trailing zero entries). Testing indicates
  # that the correction below does not change the results, so it is disabled
  # for now.
  corrected = False
  if corrected:
    for d in range(1, num_teachers):
      ls[d] = max(ls[d - 1], ls[d])
return ls
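

# A minimal usage sketch, not part of the original code; the counts, threshold,
# and noise scale below are illustrative assumptions.
def _demo_ls_bounds_threshold():
  counts = np.array([120, 50, 30])  # hypothetical vote histogram
  ls = compute_local_sensitivity_bounds_threshold(
      counts, num_teachers=200, threshold=150, sigma=150., order=20.)
  print("threshold LS at distances 0..4:", ls[:5])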
#############################################
# PROCEDURES FOR SMOOTH SENSITIVITY RELEASE #
#############################################
# A global dictionary of exponentially decaying arrays. Indexed by beta.
dict_beta_discount = {}
def compute_discounted_max(beta, a):
n = len(a)
if beta not in dict_beta_discount or (len(dict_beta_discount[beta]) < n):
dict_beta_discount[beta] = np.exp(-beta * np.arange(n))
  return max(a * dict_beta_discount[beta][:n])
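

# Illustrative check, not part of the original code: compute_discounted_max
# realizes the smooth-sensitivity formula SS_beta = max_d exp(-beta * d) *
# LS(d) for a precomputed array of local sensitivities. The toy values below
# are assumptions.
def _demo_discounted_max():
  ls = np.array([0.5, 0.8, 1.0, 1.0])  # hypothetical LS at distances 0..3
  # max(0.5, 0.8 * e^-0.5, e^-1, e^-1.5) = 0.5
  print(compute_discounted_max(0.5, ls))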
def compute_smooth_sensitivity_gnmax(beta, counts, num_teachers, sigma, order):
"""Computes smooth sensitivity of a single application of GNMax."""
  ls = compute_local_sensitivity_bounds_gnmax(counts, num_teachers, sigma,
                                              order)
return compute_discounted_max(beta, ls)
def compute_rdp_of_smooth_sensitivity_gaussian(beta, sigma, order):
"""Computes the RDP curve for the GNSS mechanism.
Implements Theorem 23 (https://arxiv.org/pdf/1802.08908.pdf).
"""
if beta > 0 and not 1 < order < 1 / (2 * beta):
raise ValueError("Order outside the (1, 1/(2*beta)) range.")
return order * math.exp(2 * beta) / sigma**2 + (
-.5 * math.log(1 - 2 * order * beta) + beta * order) / (
order - 1)
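

# End-to-end sketch, not part of the original module: compute the smooth
# sensitivity of one GNMax query and the RDP cost of releasing it via GNSS.
# All numeric parameters are illustrative assumptions; note that beta must
# satisfy order < 1 / (2 * beta) for the RDP bound to apply, and the GNSS
# noise scale is taken equal to the GNMax sigma purely for simplicity.
if __name__ == "__main__":
  votes = np.array([500, 300, 200])  # hypothetical vote histogram
  num_teachers = int(np.sum(votes))
  sigma, order = 40., 20.
  beta = 0.4 / order  # 1 / (2 * beta) = 25 > order, so the bound applies
  ss = compute_smooth_sensitivity_gnmax(beta, votes, num_teachers, sigma,
                                        order)
  rdp = compute_rdp_of_smooth_sensitivity_gaussian(beta, sigma, order)
  print("smooth sensitivity:", ss)
  print("GNSS RDP at order %g: %g" % (order, rdp))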