-
Notifications
You must be signed in to change notification settings - Fork 714
/
Copy pathevaluator.cpp
2821 lines (2443 loc) · 116 KB
/
evaluator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "seal/evaluator.h"
#include "seal/util/common.h"
#include "seal/util/galois.h"
#include "seal/util/numth.h"
#include "seal/util/polyarithsmallmod.h"
#include "seal/util/polycore.h"
#include "seal/util/scalingvariant.h"
#include "seal/util/uintarith.h"
#include <algorithm>
#include <cmath>
#include <functional>
using namespace std;
using namespace seal::util;
namespace seal
{
namespace
{
/**
Tells whether two objects report the same scale. Both arguments only need to expose a scale() accessor;
the two values are compared with util::are_close<double> rather than with operator ==.
*/
template <typename T, typename S>
SEAL_NODISCARD inline bool are_same_scale(const T &value1, const S &value2) noexcept
{
    const double first_scale = value1.scale();
    const double second_scale = value2.scale();
    return util::are_close<double>(first_scale, second_scale);
}
/**
Checks that a scale is usable at the given level of the modulus switching chain.

The scale must be a positive, finite number, and the integer part of log2(scale) must be strictly smaller
than a scheme-dependent bit count:
  - BFV/BGV: the bit count of the plain modulus;
  - CKKS: the total bit count of the coefficient modulus of context_data.
For any other scheme the check always fails.

@param[in] scale The scale to validate
@param[in] context_data The context data whose parameters determine the bound
@return true if the scale is within bounds, false otherwise
*/
SEAL_NODISCARD inline bool is_scale_within_bounds(
    double scale, const SEALContext::ContextData &context_data) noexcept
{
    // Reject non-positive, NaN and infinite scales up front. Written as a negated comparison so that NaN
    // (for which every comparison is false) is rejected too. Without this guard log2(scale) could be NaN or
    // infinite, and converting such a value to int is undefined behavior.
    if (!(scale > 0) || !std::isfinite(scale))
    {
        return false;
    }

    int scale_bit_count_bound = 0;
    switch (context_data.parms().scheme())
    {
    case scheme_type::bfv:
    case scheme_type::bgv:
        scale_bit_count_bound = context_data.parms().plain_modulus().bit_count();
        break;
    case scheme_type::ckks:
        scale_bit_count_bound = context_data.total_coeff_modulus_bit_count();
        break;
    default:
        // Unsupported scheme; fail explicitly. A sentinel bound of -1 would still let scales with
        // log2(scale) <= -2 pass the comparison below.
        return false;
    }

    return static_cast<int>(log2(scale)) < scale_bit_count_bound;
}
/**
Finds scalars that bring two correction factors to a common value. Returns (f, e1, e2) such that
(1) e1 * factor1 = e2 * factor2 = f mod p;
(2) gcd(e1, p) = 1 and gcd(e2, p) = 1;
(3) abs(e1_bal) + abs(e2_bal) is minimal, where e1_bal and e2_bal represent e1 and e2 in (-p/2, p/2].
Here p is the value of plain_modulus.

Throws std::logic_error if factor1 is not invertible modulo plain_modulus.
*/
SEAL_NODISCARD inline auto balance_correction_factors(
    uint64_t factor1, uint64_t factor2, const Modulus &plain_modulus) -> tuple<uint64_t, uint64_t, uint64_t>
{
    uint64_t t = plain_modulus.value();
    uint64_t half_t = t / 2;

    // |x| + |y| where x and y are first mapped to their balanced representatives modulo t
    auto balanced_abs_sum = [&](uint64_t x, uint64_t y) {
        int64_t x_bal = static_cast<int64_t>(x > half_t ? x - t : x);
        int64_t y_bal = static_cast<int64_t>(y > half_t ? y - t : y);
        return abs(x_bal) + abs(y_bal);
    };

    // Maps a signed value to its residue in [0, t)
    auto to_residue = [&](int64_t value) -> uint64_t {
        uint64_t residue = barrett_reduce_64(static_cast<uint64_t>(abs(value)), plain_modulus);
        return (value < 0) ? negate_uint_mod(residue, plain_modulus) : residue;
    };

    // ratio = f2 / f1 mod p
    uint64_t ratio = 1;
    if (!try_invert_uint_mod(factor1, plain_modulus, ratio))
    {
        throw logic_error("invalid correction factor1");
    }
    ratio = multiply_uint_mod(ratio, factor2, plain_modulus);

    // Starting candidate: (e1, e2) = (ratio, 1)
    uint64_t best_e1 = ratio;
    uint64_t best_e2 = 1;
    int64_t best_sum = balanced_abs_sum(best_e1, best_e2);

    // Extended Euclidean on (p, ratio); every (remainder, coefficient) pair produced along the way is
    // examined as a candidate (e1, e2).
    int64_t rem_prev = static_cast<int64_t>(plain_modulus.value());
    int64_t rem_curr = static_cast<int64_t>(ratio);
    int64_t coef_prev = static_cast<int64_t>(0);
    int64_t coef_curr = 1;
    while (rem_curr != 0)
    {
        int64_t quotient = rem_prev / rem_curr;

        int64_t rem_next = rem_prev % rem_curr;
        rem_prev = rem_curr;
        rem_curr = rem_next;

        int64_t coef_next = sub_safe(coef_prev, mul_safe(coef_curr, quotient));
        coef_prev = coef_curr;
        coef_curr = coef_next;

        uint64_t cand_e1 = to_residue(rem_curr);
        uint64_t cand_e2 = to_residue(coef_curr);

        // Only candidates invertible modulo t qualify; this also implies gcd(cand_e2, t) == 1
        if (cand_e1 == 0 || gcd(cand_e1, t) != 1)
        {
            continue;
        }

        int64_t cand_sum = balanced_abs_sum(cand_e1, cand_e2);
        if (cand_sum < best_sum)
        {
            best_sum = cand_sum;
            best_e1 = cand_e1;
            best_e2 = cand_e2;
        }
    }

    uint64_t common_factor = multiply_uint_mod(best_e1, factor1, plain_modulus);
    return make_tuple(common_factor, best_e1, best_e2);
}
} // namespace
/**
Creates an Evaluator that operates under the given SEALContext; the context is stored in context_.

@param[in] context The SEALContext
@throws std::invalid_argument if context.parameters_set() reports false
*/
Evaluator::Evaluator(const SEALContext &context) : context_(context)
{
    // Nothing more to do when the parameters validated successfully
    const bool parameters_ok = context_.parameters_set();
    if (parameters_ok)
    {
        return;
    }
    throw invalid_argument("encryption parameters are not set correctly");
}
/**
Negates a ciphertext in place: every polynomial of encrypted is negated modulo the coefficient modulus
of the parameter set encrypted belongs to.

@param[in,out] encrypted The ciphertext to negate
@throws std::invalid_argument if encrypted fails is_metadata_valid_for or is_buffer_valid
@throws std::logic_error if SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT is defined and the result is transparent
*/
void Evaluator::negate_inplace(Ciphertext &encrypted) const
{
    // The input must match the context before any of its data is touched
    const bool input_ok = is_metadata_valid_for(encrypted, context_) && is_buffer_valid(encrypted);
    if (!input_ok)
    {
        throw invalid_argument("encrypted is not valid for encryption parameters");
    }

    // Coefficient modulus of the parameter set the ciphertext is attached to
    auto context_data_ptr = context_.get_context_data(encrypted.parms_id());
    auto &rns_base = context_data_ptr->parms().coeff_modulus();

    // Negate all polynomials, writing the result over the input
    negate_poly_coeffmod(encrypted, encrypted.size(), rns_base, encrypted);
#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT
    // Transparent ciphertext output is not allowed.
    if (encrypted.is_transparent())
    {
        throw logic_error("result ciphertext is transparent");
    }
#endif
}
/**
Adds encrypted2 to encrypted1 and leaves the sum in encrypted1.

When the two correction factors are equal, the polynomials present in both operands are added
component-wise and, if encrypted2 has more polynomials, the extra ones are copied over. When the correction
factors differ, encrypted1 and a copy of encrypted2 are first multiplied by the scalars computed by
balance_correction_factors so that both carry the same correction factor, and the addition is then
performed on those.

@param[in,out] encrypted1 First summand; overwritten with the sum
@param[in] encrypted2 Second summand; never modified
@throws std::invalid_argument if either operand is invalid for the context, or if the operands disagree on
parms_id, NTT form or scale
@throws std::logic_error if the size check fails, if balance_correction_factors rejects the factors, or if
SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT is defined and the result is transparent
*/
void Evaluator::add_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2) const
{
    // Each operand must be valid for the context on its own ...
    const bool first_ok = is_metadata_valid_for(encrypted1, context_) && is_buffer_valid(encrypted1);
    if (!first_ok)
    {
        throw invalid_argument("encrypted1 is not valid for encryption parameters");
    }
    const bool second_ok = is_metadata_valid_for(encrypted2, context_) && is_buffer_valid(encrypted2);
    if (!second_ok)
    {
        throw invalid_argument("encrypted2 is not valid for encryption parameters");
    }

    // ... and the two must be compatible with each other
    if (encrypted1.parms_id() != encrypted2.parms_id())
    {
        throw invalid_argument("encrypted1 and encrypted2 parameter mismatch");
    }
    if (encrypted1.is_ntt_form() != encrypted2.is_ntt_form())
    {
        throw invalid_argument("NTT form mismatch");
    }
    if (!are_same_scale(encrypted1, encrypted2))
    {
        throw invalid_argument("scale mismatch");
    }

    // Parameters of the level both operands live at
    auto &context_data = *context_.get_context_data(encrypted1.parms_id());
    auto &parms = context_data.parms();
    auto &rns_base = parms.coeff_modulus();
    size_t poly_degree = parms.poly_modulus_degree();
    size_t rns_size = rns_base.size();
    size_t size1 = encrypted1.size();
    size_t size2 = encrypted2.size();
    size_t larger_size = max(size1, size2);
    size_t shared_size = min(size1, size2);

    // Size check
    if (!product_fits_in(larger_size, poly_degree))
    {
        throw logic_error("invalid parameters");
    }

    if (encrypted1.correction_factor() == encrypted2.correction_factor())
    {
        // Make room for the larger of the two operands
        encrypted1.resize(context_, context_data.parms_id(), larger_size);

        // Add the polynomials present in both ciphertexts
        add_poly_coeffmod(encrypted1, encrypted2, shared_size, rns_base, encrypted1);

        // Copy the remaining polynomials of encrypted2, if it is the larger operand, into encrypted1
        if (size1 < size2)
        {
            set_poly_array(
                encrypted2.data(shared_size), size2 - size1, poly_degree, rns_size, encrypted1.data(size1));
        }
    }
    else
    {
        // Balance correction factors and multiply by scalars before addition in BGV
        auto &plain_modulus = parms.plain_modulus();
        auto factors = balance_correction_factors(
            encrypted1.correction_factor(), encrypted2.correction_factor(), plain_modulus);
        const uint64_t common_factor = get<0>(factors);
        const uint64_t scalar1 = get<1>(factors);
        const uint64_t scalar2 = get<2>(factors);

        // Scale encrypted1 in place
        multiply_poly_scalar_coeffmod(
            ConstPolyIter(encrypted1.data(), poly_degree, rns_size), size1, scalar1, rns_base,
            PolyIter(encrypted1.data(), poly_degree, rns_size));

        // encrypted2 is const, so the scaled version goes into a copy
        Ciphertext scaled_encrypted2 = encrypted2;
        multiply_poly_scalar_coeffmod(
            ConstPolyIter(encrypted2.data(), poly_degree, rns_size), size2, scalar2, rns_base,
            PolyIter(scaled_encrypted2.data(), poly_degree, rns_size));

        // Both operands now share the new correction factor, so the recursive call takes the other branch
        encrypted1.correction_factor() = common_factor;
        scaled_encrypted2.correction_factor() = common_factor;
        add_inplace(encrypted1, scaled_encrypted2);
    }
#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT
    // Transparent ciphertext output is not allowed.
    if (encrypted1.is_transparent())
    {
        throw logic_error("result ciphertext is transparent");
    }
#endif
}
void Evaluator::add_many(const vector<Ciphertext> &encrypteds, Ciphertext &destination) const
{
if (encrypteds.empty())
{
throw invalid_argument("encrypteds cannot be empty");
}
for (size_t i = 0; i < encrypteds.size(); i++)
{
if (&encrypteds[i] == &destination)
{
throw invalid_argument("encrypteds must be different from destination");
}
}
destination = encrypteds[0];
for (size_t i = 1; i < encrypteds.size(); i++)
{
add_inplace(destination, encrypteds[i]);
}
}
/**
Subtracts encrypted2 from encrypted1 and leaves the difference in encrypted1.

When the two correction factors are equal, the polynomials present in both operands are subtracted
component-wise and, if encrypted2 has more polynomials, the negation of the extra ones is written to
encrypted1. When the correction factors differ, encrypted1 and a copy of encrypted2 are first multiplied by
the scalars computed by balance_correction_factors so that both carry the same correction factor, and the
subtraction is then performed on those.

@param[in,out] encrypted1 Minuend; overwritten with the difference
@param[in] encrypted2 Subtrahend; never modified
@throws std::invalid_argument if either operand is invalid for the context, or if the operands disagree on
parms_id, NTT form or scale
@throws std::logic_error if the size check fails, if balance_correction_factors rejects the factors, or if
SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT is defined and the result is transparent
*/
void Evaluator::sub_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2) const
{
    // Each operand must be valid for the context on its own ...
    const bool first_ok = is_metadata_valid_for(encrypted1, context_) && is_buffer_valid(encrypted1);
    if (!first_ok)
    {
        throw invalid_argument("encrypted1 is not valid for encryption parameters");
    }
    const bool second_ok = is_metadata_valid_for(encrypted2, context_) && is_buffer_valid(encrypted2);
    if (!second_ok)
    {
        throw invalid_argument("encrypted2 is not valid for encryption parameters");
    }

    // ... and the two must be compatible with each other
    if (encrypted1.parms_id() != encrypted2.parms_id())
    {
        throw invalid_argument("encrypted1 and encrypted2 parameter mismatch");
    }
    if (encrypted1.is_ntt_form() != encrypted2.is_ntt_form())
    {
        throw invalid_argument("NTT form mismatch");
    }
    if (!are_same_scale(encrypted1, encrypted2))
    {
        throw invalid_argument("scale mismatch");
    }

    // Parameters of the level both operands live at
    auto &context_data = *context_.get_context_data(encrypted1.parms_id());
    auto &parms = context_data.parms();
    auto &rns_base = parms.coeff_modulus();
    size_t poly_degree = parms.poly_modulus_degree();
    size_t rns_size = rns_base.size();
    size_t size1 = encrypted1.size();
    size_t size2 = encrypted2.size();
    size_t larger_size = max(size1, size2);
    size_t shared_size = min(size1, size2);

    // Size check
    if (!product_fits_in(larger_size, poly_degree))
    {
        throw logic_error("invalid parameters");
    }

    if (encrypted1.correction_factor() == encrypted2.correction_factor())
    {
        // Make room for the larger of the two operands
        encrypted1.resize(context_, context_data.parms_id(), larger_size);

        // Subtract the polynomials present in both ciphertexts
        sub_poly_coeffmod(encrypted1, encrypted2, shared_size, rns_base, encrypted1);

        // Polynomials only encrypted2 has enter the difference negated
        if (size1 < size2)
        {
            negate_poly_coeffmod(
                iter(encrypted2) + shared_size, size2 - shared_size, rns_base, iter(encrypted1) + shared_size);
        }
    }
    else
    {
        // Balance correction factors and multiply by scalars before subtraction in BGV
        auto &plain_modulus = parms.plain_modulus();
        auto factors = balance_correction_factors(
            encrypted1.correction_factor(), encrypted2.correction_factor(), plain_modulus);
        const uint64_t common_factor = get<0>(factors);
        const uint64_t scalar1 = get<1>(factors);
        const uint64_t scalar2 = get<2>(factors);

        // Scale encrypted1 in place
        multiply_poly_scalar_coeffmod(
            ConstPolyIter(encrypted1.data(), poly_degree, rns_size), size1, scalar1, rns_base,
            PolyIter(encrypted1.data(), poly_degree, rns_size));

        // encrypted2 is const, so the scaled version goes into a copy
        Ciphertext scaled_encrypted2 = encrypted2;
        multiply_poly_scalar_coeffmod(
            ConstPolyIter(encrypted2.data(), poly_degree, rns_size), size2, scalar2, rns_base,
            PolyIter(scaled_encrypted2.data(), poly_degree, rns_size));

        // Both operands now share the new correction factor, so the recursive call takes the other branch
        encrypted1.correction_factor() = common_factor;
        scaled_encrypted2.correction_factor() = common_factor;
        sub_inplace(encrypted1, scaled_encrypted2);
    }
#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT
    // Transparent ciphertext output is not allowed.
    if (encrypted1.is_transparent())
    {
        throw logic_error("result ciphertext is transparent");
    }
#endif
}
/**
Multiplies encrypted1 by encrypted2 and leaves the product in encrypted1. After validating the operands,
the work is handed to bfv_multiply, ckks_multiply or bgv_multiply according to the scheme of the first
context data of the context.

@param[in,out] encrypted1 First factor; overwritten with the product
@param[in] encrypted2 Second factor
@param[in] pool Memory pool handle forwarded to the scheme-specific routine
@throws std::invalid_argument if either operand is invalid for the context, if their parms_id differ, or if
the scheme is not one of BFV, CKKS, BGV
@throws std::logic_error if SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT is defined and the result is transparent
*/
void Evaluator::multiply_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2, MemoryPoolHandle pool) const
{
    // Operand validation
    const bool first_ok = is_metadata_valid_for(encrypted1, context_) && is_buffer_valid(encrypted1);
    if (!first_ok)
    {
        throw invalid_argument("encrypted1 is not valid for encryption parameters");
    }
    const bool second_ok = is_metadata_valid_for(encrypted2, context_) && is_buffer_valid(encrypted2);
    if (!second_ok)
    {
        throw invalid_argument("encrypted2 is not valid for encryption parameters");
    }
    if (encrypted1.parms_id() != encrypted2.parms_id())
    {
        throw invalid_argument("encrypted1 and encrypted2 parameter mismatch");
    }

    // Dispatch on the scheme
    const scheme_type scheme = context_.first_context_data()->parms().scheme();
    if (scheme == scheme_type::bfv)
    {
        bfv_multiply(encrypted1, encrypted2, pool);
    }
    else if (scheme == scheme_type::ckks)
    {
        ckks_multiply(encrypted1, encrypted2, pool);
    }
    else if (scheme == scheme_type::bgv)
    {
        bgv_multiply(encrypted1, encrypted2, pool);
    }
    else
    {
        throw invalid_argument("unsupported scheme");
    }
#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT
    // Transparent ciphertext output is not allowed.
    if (encrypted1.is_transparent())
    {
        throw logic_error("result ciphertext is transparent");
    }
#endif
}
/**
Multiplies encrypted1 by encrypted2 and writes the product to encrypted1 (BFV code path). The computation
follows the BEHZ-style RNS multiplication whose steps (1)-(8) are listed inside the function body.
encrypted1 is resized to encrypted1.size() + encrypted2.size() - 1 polynomials before it is overwritten.

@param[in,out] encrypted1 First factor; overwritten with the product
@param[in] encrypted2 Second factor
@param[in] pool Memory pool handle used for every temporary allocation made here
@throws std::invalid_argument if encrypted1 or encrypted2 is in NTT form
@throws std::logic_error if dest_size, coeff_count and base_Bsk_m_tilde_size fail the product_fits_in check
*/
void Evaluator::bfv_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, MemoryPoolHandle pool) const
{
// This routine performs the NTT transforms itself, so the inputs must be in coefficient form
if (encrypted1.is_ntt_form() || encrypted2.is_ntt_form())
{
throw invalid_argument("encrypted1 or encrypted2 cannot be in NTT form");
}
// Extract encryption parameters.
auto &context_data = *context_.get_context_data(encrypted1.parms_id());
auto &parms = context_data.parms();
size_t coeff_count = parms.poly_modulus_degree();
size_t base_q_size = parms.coeff_modulus().size();
size_t encrypted1_size = encrypted1.size();
size_t encrypted2_size = encrypted2.size();
uint64_t plain_modulus = parms.plain_modulus().value();
// The RNS tool supplies the auxiliary bases (Bsk, Bsk U {m_tilde}) and the base conversion routines
auto rns_tool = context_data.rns_tool();
size_t base_Bsk_size = rns_tool->base_Bsk()->size();
size_t base_Bsk_m_tilde_size = rns_tool->base_Bsk_m_tilde()->size();
// Determine destination.size()
size_t dest_size = sub_safe(add_safe(encrypted1_size, encrypted2_size), size_t(1));
// Size check
if (!product_fits_in(dest_size, coeff_count, base_Bsk_m_tilde_size))
{
throw logic_error("invalid parameters");
}
// Set up iterators for bases
auto base_q = iter(parms.coeff_modulus());
auto base_Bsk = iter(rns_tool->base_Bsk()->base());
// Set up iterators for NTT tables
auto base_q_ntt_tables = iter(context_data.small_ntt_tables());
auto base_Bsk_ntt_tables = iter(rns_tool->base_Bsk_ntt_tables());
// Microsoft SEAL uses BEHZ-style RNS multiplication. This process is somewhat complex and consists of the
// following steps:
//
// (1) Lift encrypted1 and encrypted2 (initially in base q) to an extended base q U Bsk U {m_tilde}
// (2) Remove extra multiples of q from the results with Montgomery reduction, switching base to q U Bsk
// (3) Transform the data to NTT form
// (4) Compute the ciphertext polynomial product using dyadic multiplication
// (5) Transform the data back from NTT form
// (6) Multiply the result by t (plain_modulus)
// (7) Scale the result by q using a divide-and-floor algorithm, switching base to Bsk
// (8) Use Shenoy-Kumaresan method to convert the result to base q
// Resize encrypted1 to destination size
encrypted1.resize(context_, context_data.parms_id(), dest_size);
// This lambda function takes as input an IterTuple with three components:
//
// 1. (Const)RNSIter to read an input polynomial from
// 2. RNSIter for the output in base q
// 3. RNSIter for the output in base Bsk
//
// It performs steps (1)-(3) of the BEHZ multiplication (see above) on the given input polynomial (given as an
// RNSIter or ConstRNSIter) and writes the results in base q and base Bsk to the given output
// iterators.
auto behz_extend_base_convert_to_ntt = [&](auto I) {
// Make copy of input polynomial (in base q) and convert to NTT form
// Lazy reduction
set_poly(get<0>(I), coeff_count, base_q_size, get<1>(I));
ntt_negacyclic_harvey_lazy(get<1>(I), base_q_size, base_q_ntt_tables);
// Allocate temporary space for a polynomial in the Bsk U {m_tilde} base
SEAL_ALLOCATE_GET_RNS_ITER(temp, coeff_count, base_Bsk_m_tilde_size, pool);
// (1) Convert from base q to base Bsk U {m_tilde}
rns_tool->fastbconv_m_tilde(get<0>(I), temp, pool);
// (2) Reduce q-overflows with Montgomery reduction, switching base to Bsk
rns_tool->sm_mrq(temp, get<2>(I), pool);
// Transform to NTT form in base Bsk
// Lazy reduction
ntt_negacyclic_harvey_lazy(get<2>(I), base_Bsk_size, base_Bsk_ntt_tables);
};
// Allocate space for a base q output of behz_extend_base_convert_to_ntt for encrypted1
SEAL_ALLOCATE_GET_POLY_ITER(encrypted1_q, encrypted1_size, coeff_count, base_q_size, pool);
// Allocate space for a base Bsk output of behz_extend_base_convert_to_ntt for encrypted1
SEAL_ALLOCATE_GET_POLY_ITER(encrypted1_Bsk, encrypted1_size, coeff_count, base_Bsk_size, pool);
// Perform BEHZ steps (1)-(3) for encrypted1
// Only the first encrypted1_size polynomials are processed: these are the original input polynomials,
// as encrypted1 has already been resized to dest_size above.
SEAL_ITERATE(iter(encrypted1, encrypted1_q, encrypted1_Bsk), encrypted1_size, behz_extend_base_convert_to_ntt);
// Repeat for encrypted2
SEAL_ALLOCATE_GET_POLY_ITER(encrypted2_q, encrypted2_size, coeff_count, base_q_size, pool);
SEAL_ALLOCATE_GET_POLY_ITER(encrypted2_Bsk, encrypted2_size, coeff_count, base_Bsk_size, pool);
SEAL_ITERATE(iter(encrypted2, encrypted2_q, encrypted2_Bsk), encrypted2_size, behz_extend_base_convert_to_ntt);
// Allocate temporary space for the output of step (4)
// We allocate space separately for the base q and the base Bsk components
SEAL_ALLOCATE_ZERO_GET_POLY_ITER(temp_dest_q, dest_size, coeff_count, base_q_size, pool);
SEAL_ALLOCATE_ZERO_GET_POLY_ITER(temp_dest_Bsk, dest_size, coeff_count, base_Bsk_size, pool);
// Perform BEHZ step (4): dyadic multiplication on arbitrary size ciphertexts
SEAL_ITERATE(iter(size_t(0)), dest_size, [&](auto I) {
// We iterate over relevant components of encrypted1 and encrypted2 in increasing order for
// encrypted1 and reversed (decreasing) order for encrypted2. The bounds for the indices of
// the relevant terms are obtained as follows.
// Output polynomial I is the sum of the products encrypted1[i] * encrypted2[I - i] taken over
// i = curr_encrypted1_first, ..., curr_encrypted1_last.
size_t curr_encrypted1_last = min<size_t>(I, encrypted1_size - 1);
size_t curr_encrypted2_first = min<size_t>(I, encrypted2_size - 1);
size_t curr_encrypted1_first = I - curr_encrypted2_first;
// size_t curr_encrypted2_last = I - curr_encrypted1_last;
// The total number of dyadic products is now easy to compute
size_t steps = curr_encrypted1_last - curr_encrypted1_first + 1;
// This lambda function computes the ciphertext product for BFV multiplication. Since we use the BEHZ
// approach, the multiplication of individual polynomials is done using a dyadic product where the inputs
// are already in NTT form. The arguments of the lambda function are expected to be as follows:
//
// 1. a ConstPolyIter pointing to the beginning of the first input ciphertext (in NTT form)
// 2. a ConstPolyIter pointing to the beginning of the second input ciphertext (in NTT form)
// 3. a ConstModulusIter pointing to an array of Modulus elements for the base
// 4. the size of the base
// 5. a PolyIter pointing to the beginning of the output ciphertext
auto behz_ciphertext_product = [&](ConstPolyIter in1_iter, ConstPolyIter in2_iter,
ConstModulusIter base_iter, size_t base_size, PolyIter out_iter) {
// Create a shifted iterator for the first input
auto shifted_in1_iter = in1_iter + curr_encrypted1_first;
// Create a shifted reverse iterator for the second input
auto shifted_reversed_in2_iter = reverse_iter(in2_iter + curr_encrypted2_first);
// Create a shifted iterator for the output
auto shifted_out_iter = out_iter[I];
// Outer loop: one pass per pair of input polynomials; inner loop: one pass per modulus of the base
SEAL_ITERATE(iter(shifted_in1_iter, shifted_reversed_in2_iter), steps, [&](auto J) {
SEAL_ITERATE(iter(J, base_iter, shifted_out_iter), base_size, [&](auto K) {
// Compute the product into a scratch buffer and accumulate it into the output polynomial
SEAL_ALLOCATE_GET_COEFF_ITER(temp, coeff_count, pool);
dyadic_product_coeffmod(get<0, 0>(K), get<0, 1>(K), coeff_count, get<1>(K), temp);
add_poly_coeffmod(temp, get<2>(K), coeff_count, get<1>(K), get<2>(K));
});
});
};
// Perform the BEHZ ciphertext product both for base q and base Bsk
behz_ciphertext_product(encrypted1_q, encrypted2_q, base_q, base_q_size, temp_dest_q);
behz_ciphertext_product(encrypted1_Bsk, encrypted2_Bsk, base_Bsk, base_Bsk_size, temp_dest_Bsk);
});
// Perform BEHZ step (5): transform data from NTT form
// Lazy reduction here. The following multiply_poly_scalar_coeffmod will correct the value back to [0, p)
inverse_ntt_negacyclic_harvey_lazy(temp_dest_q, dest_size, base_q_ntt_tables);
inverse_ntt_negacyclic_harvey_lazy(temp_dest_Bsk, dest_size, base_Bsk_ntt_tables);
// Perform BEHZ steps (6)-(8)
SEAL_ITERATE(iter(temp_dest_q, temp_dest_Bsk, encrypted1), dest_size, [&](auto I) {
// Bring together the base q and base Bsk components into a single allocation
SEAL_ALLOCATE_GET_RNS_ITER(temp_q_Bsk, coeff_count, base_q_size + base_Bsk_size, pool);
// Step (6): multiply base q components by t (plain_modulus)
multiply_poly_scalar_coeffmod(get<0>(I), base_q_size, plain_modulus, base_q, temp_q_Bsk);
// Same for the base Bsk components, which are stored right after the base q components
multiply_poly_scalar_coeffmod(get<1>(I), base_Bsk_size, plain_modulus, base_Bsk, temp_q_Bsk + base_q_size);
// Allocate yet another temporary for fast divide-and-floor result in base Bsk
SEAL_ALLOCATE_GET_RNS_ITER(temp_Bsk, coeff_count, base_Bsk_size, pool);
// Step (7): divide by q and floor, producing a result in base Bsk
rns_tool->fast_floor(temp_q_Bsk, temp_Bsk, pool);
// Step (8): use Shenoy-Kumaresan method to convert the result to base q and write to encrypted1
rns_tool->fastbconv_sk(temp_Bsk, get<2>(I), pool);
});
}
/**
Multiplies encrypted1 by encrypted2 and writes the product to encrypted1 (CKKS code path). Both inputs must
be in NTT form, so the product is computed with dyadic (coefficient-wise) multiplications. encrypted1 is
resized to encrypted1.size() + encrypted2.size() - 1 polynomials and its scale is multiplied by the scale of
encrypted2.

All validation, including the bounds check on the resulting scale, is done before encrypted1 is modified,
so a rejected call leaves encrypted1 untouched.

@param[in,out] encrypted1 First factor; overwritten with the product
@param[in] encrypted2 Second factor
@param[in] pool Memory pool handle used for every temporary allocation made here
@throws std::invalid_argument if encrypted1 or encrypted2 is not in NTT form, or if the product of the two
scales fails is_scale_within_bounds
@throws std::logic_error if dest_size, coeff_count and coeff_modulus_size fail the product_fits_in check
*/
void Evaluator::ckks_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, MemoryPoolHandle pool) const
{
    if (!(encrypted1.is_ntt_form() && encrypted2.is_ntt_form()))
    {
        throw invalid_argument("encrypted1 or encrypted2 must be in NTT form");
    }

    // Extract encryption parameters.
    auto &context_data = *context_.get_context_data(encrypted1.parms_id());
    auto &parms = context_data.parms();
    size_t coeff_count = parms.poly_modulus_degree();
    size_t coeff_modulus_size = parms.coeff_modulus().size();
    size_t encrypted1_size = encrypted1.size();
    size_t encrypted2_size = encrypted2.size();

    // Determine destination.size()
    // Default is 3 (c_0, c_1, c_2)
    size_t dest_size = sub_safe(add_safe(encrypted1_size, encrypted2_size), size_t(1));

    // Size check
    if (!product_fits_in(dest_size, coeff_count, coeff_modulus_size))
    {
        throw logic_error("invalid parameters");
    }

    // Check the scale of the product now, while encrypted1 is still intact. Doing this only after the
    // multiplication would report the error with the caller's ciphertext already overwritten.
    if (!is_scale_within_bounds(encrypted1.scale() * encrypted2.scale(), context_data))
    {
        throw invalid_argument("scale out of bounds");
    }

    // Set up iterator for the base
    auto coeff_modulus = iter(parms.coeff_modulus());

    // Prepare destination
    encrypted1.resize(context_, context_data.parms_id(), dest_size);

    // Set up iterators for input ciphertexts
    PolyIter encrypted1_iter = iter(encrypted1);
    ConstPolyIter encrypted2_iter = iter(encrypted2);

    if (dest_size == 3)
    {
        // We want to keep six polynomials in the L1 cache: x[0], x[1], x[2], y[0], y[1], temp.
        // For a 32KiB cache, which can store 32768 / 8 = 4096 coefficients, = 682.67 coefficients per polynomial,
        // we should keep the tile size at 682 or below. The tile size must divide coeff_count, i.e. be a power of
        // two. Some testing shows similar performance with tile size 256 and 512, and worse performance on smaller
        // tiles. We pick the smaller of the two to prevent L1 cache misses on processors with < 32 KiB L1 cache.
        size_t tile_size = min<size_t>(coeff_count, size_t(256));
        size_t num_tiles = coeff_count / tile_size;
#ifdef SEAL_DEBUG
        if (coeff_count % tile_size != 0)
        {
            throw invalid_argument("tile_size does not divide coeff_count");
        }
#endif

        // Semantic misuse of RNSIter; each is really pointing to the data for each RNS factor in sequence
        ConstRNSIter encrypted2_0_iter(*encrypted2_iter[0], tile_size);
        ConstRNSIter encrypted2_1_iter(*encrypted2_iter[1], tile_size);
        RNSIter encrypted1_0_iter(*encrypted1_iter[0], tile_size);
        RNSIter encrypted1_1_iter(*encrypted1_iter[1], tile_size);
        RNSIter encrypted1_2_iter(*encrypted1_iter[2], tile_size);

        // Temporary buffer to store intermediate results
        SEAL_ALLOCATE_GET_COEFF_ITER(temp, tile_size, pool);

        // Computes the output tile_size coefficients at a time
        // Given input tuples of polynomials x = (x[0], x[1], x[2]), y = (y[0], y[1]), computes
        // x = (x[0] * y[0], x[0] * y[1] + x[1] * y[0], x[1] * y[1])
        // with appropriate modular reduction
        // The order of the statements below matters: each input tile is overwritten only after every product
        // that reads it has been computed.
        SEAL_ITERATE(coeff_modulus, coeff_modulus_size, [&](auto I) {
            SEAL_ITERATE(iter(size_t(0)), num_tiles, [&](SEAL_MAYBE_UNUSED auto J) {
                // Compute third output polynomial, overwriting input
                // x[2] = x[1] * y[1]
                dyadic_product_coeffmod(
                    encrypted1_1_iter[0], encrypted2_1_iter[0], tile_size, I, encrypted1_2_iter[0]);

                // Compute second output polynomial, overwriting input
                // temp = x[1] * y[0]
                dyadic_product_coeffmod(encrypted1_1_iter[0], encrypted2_0_iter[0], tile_size, I, temp);
                // x[1] = x[0] * y[1]
                dyadic_product_coeffmod(
                    encrypted1_0_iter[0], encrypted2_1_iter[0], tile_size, I, encrypted1_1_iter[0]);
                // x[1] += temp
                add_poly_coeffmod(encrypted1_1_iter[0], temp, tile_size, I, encrypted1_1_iter[0]);

                // Compute first output polynomial, overwriting input
                // x[0] = x[0] * y[0]
                dyadic_product_coeffmod(
                    encrypted1_0_iter[0], encrypted2_0_iter[0], tile_size, I, encrypted1_0_iter[0]);

                // Manually increment iterators
                encrypted1_0_iter++;
                encrypted1_1_iter++;
                encrypted1_2_iter++;
                encrypted2_0_iter++;
                encrypted2_1_iter++;
            });
        });
    }
    else
    {
        // Allocate temporary space for the result
        SEAL_ALLOCATE_ZERO_GET_POLY_ITER(temp, dest_size, coeff_count, coeff_modulus_size, pool);

        SEAL_ITERATE(iter(size_t(0)), dest_size, [&](auto I) {
            // We iterate over relevant components of encrypted1 and encrypted2 in increasing order for
            // encrypted1 and reversed (decreasing) order for encrypted2. The bounds for the indices of
            // the relevant terms are obtained as follows.
            // Output polynomial I is the sum of the products encrypted1[i] * encrypted2[I - i] taken over
            // i = curr_encrypted1_first, ..., curr_encrypted1_last.
            size_t curr_encrypted1_last = min<size_t>(I, encrypted1_size - 1);
            size_t curr_encrypted2_first = min<size_t>(I, encrypted2_size - 1);
            size_t curr_encrypted1_first = I - curr_encrypted2_first;
            // size_t curr_encrypted2_last = I - curr_encrypted1_last;

            // The total number of dyadic products is now easy to compute
            size_t steps = curr_encrypted1_last - curr_encrypted1_first + 1;

            // Create a shifted iterator for the first input
            auto shifted_encrypted1_iter = encrypted1_iter + curr_encrypted1_first;

            // Create a shifted reverse iterator for the second input
            auto shifted_reversed_encrypted2_iter = reverse_iter(encrypted2_iter + curr_encrypted2_first);

            SEAL_ITERATE(iter(shifted_encrypted1_iter, shifted_reversed_encrypted2_iter), steps, [&](auto J) {
                // Extra care needed here:
                // temp_iter must be dereferenced once to produce an appropriate RNSIter
                SEAL_ITERATE(iter(J, coeff_modulus, temp[I]), coeff_modulus_size, [&](auto K) {
                    // Compute the product into a scratch buffer and accumulate it into temp[I]
                    SEAL_ALLOCATE_GET_COEFF_ITER(prod, coeff_count, pool);
                    dyadic_product_coeffmod(get<0, 0>(K), get<0, 1>(K), coeff_count, get<1>(K), prod);
                    add_poly_coeffmod(prod, get<2>(K), coeff_count, get<1>(K), get<2>(K));
                });
            });
        });

        // Set the final result
        set_poly_array(temp, dest_size, coeff_count, coeff_modulus_size, encrypted1.data());
    }

    // Set the scale; the product was validated against the bounds before encrypted1 was modified
    encrypted1.scale() *= encrypted2.scale();
}
void Evaluator::bgv_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, MemoryPoolHandle pool) const
{
if (!encrypted1.is_ntt_form() || !encrypted2.is_ntt_form())
{
throw invalid_argument("encrypted1 or encrypted2 must be in NTT form");
}
// Extract encryption parameters.
auto &context_data = *context_.get_context_data(encrypted1.parms_id());
auto &parms = context_data.parms();
size_t coeff_count = parms.poly_modulus_degree();
size_t coeff_modulus_size = parms.coeff_modulus().size();
size_t encrypted1_size = encrypted1.size();
size_t encrypted2_size = encrypted2.size();
// Determine destination.size()
// Default is 3 (c_0, c_1, c_2)
size_t dest_size = sub_safe(add_safe(encrypted1_size, encrypted2_size), size_t(1));
// Set up iterator for the base
auto coeff_modulus = iter(parms.coeff_modulus());
// Prepare destination
encrypted1.resize(context_, context_data.parms_id(), dest_size);
// Convert c0 and c1 to ntt
// Set up iterators for input ciphertexts
PolyIter encrypted1_iter = iter(encrypted1);
ConstPolyIter encrypted2_iter = iter(encrypted2);
if (dest_size == 3)
{
// We want to keep six polynomials in the L1 cache: x[0], x[1], x[2], y[0], y[1], temp.
// For a 32KiB cache, which can store 32768 / 8 = 4096 coefficients, = 682.67 coefficients per polynomial,
// we should keep the tile size at 682 or below. The tile size must divide coeff_count, i.e. be a power of
// two. Some testing shows similar performance with tile size 256 and 512, and worse performance on smaller
// tiles. We pick the smaller of the two to prevent L1 cache misses on processors with < 32 KiB L1 cache.
size_t tile_size = min<size_t>(coeff_count, size_t(256));
size_t num_tiles = coeff_count / tile_size;
#ifdef SEAL_DEBUG
if (coeff_count % tile_size != 0)
{
throw invalid_argument("tile_size does not divide coeff_count");
}
#endif
// Semantic misuse of RNSIter; each is really pointing to the data for each RNS factor in sequence
ConstRNSIter encrypted2_0_iter(*encrypted2_iter[0], tile_size);
ConstRNSIter encrypted2_1_iter(*encrypted2_iter[1], tile_size);
RNSIter encrypted1_0_iter(*encrypted1_iter[0], tile_size);
RNSIter encrypted1_1_iter(*encrypted1_iter[1], tile_size);
RNSIter encrypted1_2_iter(*encrypted1_iter[2], tile_size);
// Temporary buffer to store intermediate results
SEAL_ALLOCATE_GET_COEFF_ITER(temp, tile_size, pool);
// Computes the output tile_size coefficients at a time
// Given input tuples of polynomials x = (x[0], x[1], x[2]), y = (y[0], y[1]), computes
// x = (x[0] * y[0], x[0] * y[1] + x[1] * y[0], x[1] * y[1])
// with appropriate modular reduction
SEAL_ITERATE(coeff_modulus, coeff_modulus_size, [&](auto I) {
SEAL_ITERATE(iter(size_t(0)), num_tiles, [&](SEAL_MAYBE_UNUSED auto J) {
// Compute third output polynomial, overwriting input
// x[2] = x[1] * y[1]
dyadic_product_coeffmod(
encrypted1_1_iter[0], encrypted2_1_iter[0], tile_size, I, encrypted1_2_iter[0]);
// Compute second output polynomial, overwriting input
// temp = x[1] * y[0]
dyadic_product_coeffmod(encrypted1_1_iter[0], encrypted2_0_iter[0], tile_size, I, temp);
// x[1] = x[0] * y[1]
dyadic_product_coeffmod(
encrypted1_0_iter[0], encrypted2_1_iter[0], tile_size, I, encrypted1_1_iter[0]);
// x[1] += temp
add_poly_coeffmod(encrypted1_1_iter[0], temp, tile_size, I, encrypted1_1_iter[0]);
// Compute first output polynomial, overwriting input
// x[0] = x[0] * y[0]
dyadic_product_coeffmod(
encrypted1_0_iter[0], encrypted2_0_iter[0], tile_size, I, encrypted1_0_iter[0]);
// Manually increment iterators
encrypted1_0_iter++;
encrypted1_1_iter++;
encrypted1_2_iter++;
encrypted2_0_iter++;
encrypted2_1_iter++;
});
});
}
else
{
// Allocate temporary space for the result
SEAL_ALLOCATE_ZERO_GET_POLY_ITER(temp, dest_size, coeff_count, coeff_modulus_size, pool);
SEAL_ITERATE(iter(size_t(0)), dest_size, [&](auto I) {
// We iterate over relevant components of encrypted1 and encrypted2 in increasing order for
// encrypted1 and reversed (decreasing) order for encrypted2. The bounds for the indices of
// the relevant terms are obtained as follows.
size_t curr_encrypted1_last = min<size_t>(I, encrypted1_size - 1);
size_t curr_encrypted2_first = min<size_t>(I, encrypted2_size - 1);
size_t curr_encrypted1_first = I - curr_encrypted2_first;
// size_t curr_encrypted2_last = secret_power_index - curr_encrypted1_last;
// The total number of dyadic products is now easy to compute
size_t steps = curr_encrypted1_last - curr_encrypted1_first + 1;
// Create a shifted iterator for the first input
auto shifted_encrypted1_iter = encrypted1_iter + curr_encrypted1_first;
// Create a shifted reverse iterator for the second input
auto shifted_reversed_encrypted2_iter = reverse_iter(encrypted2_iter + curr_encrypted2_first);
SEAL_ITERATE(iter(shifted_encrypted1_iter, shifted_reversed_encrypted2_iter), steps, [&](auto J) {
// Extra care needed here:
// temp_iter must be dereferenced once to produce an appropriate RNSIter
SEAL_ITERATE(iter(J, coeff_modulus, temp[I]), coeff_modulus_size, [&](auto K) {
SEAL_ALLOCATE_GET_COEFF_ITER(prod, coeff_count, pool);
dyadic_product_coeffmod(get<0, 0>(K), get<0, 1>(K), coeff_count, get<1>(K), prod);
add_poly_coeffmod(prod, get<2>(K), coeff_count, get<1>(K), get<2>(K));
});
});
});
// Set the final result
set_poly_array(temp, dest_size, coeff_count, coeff_modulus_size, encrypted1.data());
}
// Set the correction factor
encrypted1.correction_factor() =
multiply_uint_mod(encrypted1.correction_factor(), encrypted2.correction_factor(), parms.plain_modulus());
}
void Evaluator::square_inplace(Ciphertext &encrypted, MemoryPoolHandle pool) const
{
    // Entry point for in-place squaring: validates the input, then forwards it (together with
    // the memory pool handle) to the scheme-specific squaring routine.
    //
    // Throws invalid_argument when the input fails validation or the scheme is not one of
    // BFV / CKKS / BGV. When SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT is defined, additionally
    // throws logic_error if the resulting ciphertext is transparent.

    // Both the metadata check and the buffer check must pass; the buffer check is only
    // evaluated once the metadata check has succeeded.
    const bool input_is_valid = is_metadata_valid_for(encrypted, context_) && is_buffer_valid(encrypted);
    if (!input_is_valid)
    {
        throw invalid_argument("encrypted is not valid for encryption parameters");
    }

    // The scheme is read from the first context data of the context.
    const auto scheme = context_.first_context_data()->parms().scheme();
    if (scheme == scheme_type::bfv)
    {
        bfv_square(encrypted, std::move(pool));
    }
    else if (scheme == scheme_type::ckks)
    {
        ckks_square(encrypted, std::move(pool));
    }
    else if (scheme == scheme_type::bgv)
    {
        bgv_square(encrypted, std::move(pool));
    }
    else
    {
        throw invalid_argument("unsupported scheme");
    }
#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT
    // A transparent result is rejected when this build option is enabled.
    if (encrypted.is_transparent())
    {
        throw logic_error("result ciphertext is transparent");
    }
#endif
}
void Evaluator::bfv_square(Ciphertext &encrypted, MemoryPoolHandle pool) const
{
if (encrypted.is_ntt_form())
{
throw invalid_argument("encrypted cannot be in NTT form");
}
// Extract encryption parameters.
auto &context_data = *context_.get_context_data(encrypted.parms_id());
auto &parms = context_data.parms();
size_t coeff_count = parms.poly_modulus_degree();
size_t base_q_size = parms.coeff_modulus().size();
size_t encrypted_size = encrypted.size();
uint64_t plain_modulus = parms.plain_modulus().value();
auto rns_tool = context_data.rns_tool();
size_t base_Bsk_size = rns_tool->base_Bsk()->size();
size_t base_Bsk_m_tilde_size = rns_tool->base_Bsk_m_tilde()->size();
// Optimization implemented currently only for size 2 ciphertexts
if (encrypted_size != 2)
{
bfv_multiply(encrypted, encrypted, std::move(pool));
return;
}
// Determine destination.size()
size_t dest_size = sub_safe(add_safe(encrypted_size, encrypted_size), size_t(1));
// Size check
if (!product_fits_in(dest_size, coeff_count, base_Bsk_m_tilde_size))
{
throw logic_error("invalid parameters");
}
// Set up iterators for bases
auto base_q = iter(parms.coeff_modulus());
auto base_Bsk = iter(rns_tool->base_Bsk()->base());
// Set up iterators for NTT tables
auto base_q_ntt_tables = iter(context_data.small_ntt_tables());
auto base_Bsk_ntt_tables = iter(rns_tool->base_Bsk_ntt_tables());
// Microsoft SEAL uses BEHZ-style RNS multiplication. For details, see Evaluator::bfv_multiply. This function
// uses additionally Karatsuba multiplication to reduce the complexity of squaring a size-2 ciphertext, but the
// steps are otherwise the same as in Evaluator::bfv_multiply.
// Resize encrypted to destination size
encrypted.resize(context_, context_data.parms_id(), dest_size);
// This lambda function takes as input an IterTuple with three components:
//
// 1. (Const)RNSIter to read an input polynomial from
// 2. RNSIter for the output in base q
// 3. RNSIter for the output in base Bsk
//
// It performs steps (1)-(3) of the BEHZ multiplication on the given input polynomial (given as an RNSIter
// or ConstRNSIter) and writes the results in base q and base Bsk to the given output iterators.
auto behz_extend_base_convert_to_ntt = [&](auto I) {
// Make copy of input polynomial (in base q) and convert to NTT form
// Lazy reduction
set_poly(get<0>(I), coeff_count, base_q_size, get<1>(I));
ntt_negacyclic_harvey_lazy(get<1>(I), base_q_size, base_q_ntt_tables);
// Allocate temporary space for a polynomial in the Bsk U {m_tilde} base
SEAL_ALLOCATE_GET_RNS_ITER(temp, coeff_count, base_Bsk_m_tilde_size, pool);
// (1) Convert from base q to base Bsk U {m_tilde}
rns_tool->fastbconv_m_tilde(get<0>(I), temp, pool);
// (2) Reduce q-overflows with Montgomery reduction, switching base to Bsk
rns_tool->sm_mrq(temp, get<2>(I), pool);
// Transform to NTT form in base Bsk
// Lazy reduction
ntt_negacyclic_harvey_lazy(get<2>(I), base_Bsk_size, base_Bsk_ntt_tables);
};
// Allocate space for a base q output of behz_extend_base_convert_to_ntt
SEAL_ALLOCATE_GET_POLY_ITER(encrypted_q, encrypted_size, coeff_count, base_q_size, pool);
// Allocate space for a base Bsk output of behz_extend_base_convert_to_ntt
SEAL_ALLOCATE_GET_POLY_ITER(encrypted_Bsk, encrypted_size, coeff_count, base_Bsk_size, pool);
// Perform BEHZ steps (1)-(3)
SEAL_ITERATE(iter(encrypted, encrypted_q, encrypted_Bsk), encrypted_size, behz_extend_base_convert_to_ntt);
// Allocate temporary space for the output of step (4)
// We allocate space separately for the base q and the base Bsk components
SEAL_ALLOCATE_ZERO_GET_POLY_ITER(temp_dest_q, dest_size, coeff_count, base_q_size, pool);
SEAL_ALLOCATE_ZERO_GET_POLY_ITER(temp_dest_Bsk, dest_size, coeff_count, base_Bsk_size, pool);
// Perform BEHZ step (4): dyadic Karatsuba-squaring on size-2 ciphertexts
// This lambda function computes the size-2 ciphertext square for BFV multiplication. Since we use the BEHZ
// approach, the multiplication of individual polynomials is done using a dyadic product where the inputs
// are already in NTT form. The arguments of the lambda function are expected to be as follows:
//
// 1. a ConstPolyIter pointing to the beginning of the input ciphertext (in NTT form)
// 2. a ConstModulusIter pointing to an array of Modulus elements for the base
// 3. the size of the base
// 4. a PolyIter pointing to the beginning of the output ciphertext
auto behz_ciphertext_square = [&](ConstPolyIter in_iter, ConstModulusIter base_iter, size_t base_size,
PolyIter out_iter) {
// Compute c0^2
dyadic_product_coeffmod(in_iter[0], in_iter[0], base_size, base_iter, out_iter[0]);
// Compute 2*c0*c1
dyadic_product_coeffmod(in_iter[0], in_iter[1], base_size, base_iter, out_iter[1]);
add_poly_coeffmod(out_iter[1], out_iter[1], base_size, base_iter, out_iter[1]);
// Compute c1^2
dyadic_product_coeffmod(in_iter[1], in_iter[1], base_size, base_iter, out_iter[2]);
};
// Perform the BEHZ ciphertext square both for base q and base Bsk
behz_ciphertext_square(encrypted_q, base_q, base_q_size, temp_dest_q);
behz_ciphertext_square(encrypted_Bsk, base_Bsk, base_Bsk_size, temp_dest_Bsk);
// Perform BEHZ step (5): transform data from NTT form
inverse_ntt_negacyclic_harvey(temp_dest_q, dest_size, base_q_ntt_tables);
inverse_ntt_negacyclic_harvey(temp_dest_Bsk, dest_size, base_Bsk_ntt_tables);