diff --git a/examples/6_serialization.jl b/examples/6_serialization.jl new file mode 100644 index 0000000..b05de59 --- /dev/null +++ b/examples/6_serialization.jl @@ -0,0 +1,180 @@ +include("utilities.jl") + +using SEAL +using Printf + + +function example_serialization() + print_example_banner("Example: Serialization") + + parms_stream = UInt8[] + data_stream1 = UInt8[] + data_stream2 = UInt8[] + data_stream3 = UInt8[] + data_stream4 = UInt8[] + sk_stream = UInt8[] + + # Use `let` to create new variable scope to mimic curly braces-delimited blocks in C++ + + # Server + let + enc_parms = EncryptionParameters(SchemeType.CKKS) + poly_modulus_degree = 8192 + set_poly_modulus_degree!(enc_parms, poly_modulus_degree) + set_coeff_modulus!(enc_parms, coeff_modulus_create(poly_modulus_degree, [50, 20, 50])) + + resize!(parms_stream, save_size(enc_parms)) + out_bytes = save!(parms_stream, enc_parms) + resize!(parms_stream, out_bytes) + + print_line(@__LINE__) + println("EncryptionParameters: wrote ", out_bytes, " bytes") + + print_line(@__LINE__) + println("EncryptionParameters: data size upper bound (compr_mode_type::none): ", + save_size(ComprModeType.none, enc_parms)) + println(" EncryptionParameters: data size upper bound (compr_mode_type::deflate): ", + save_size(ComprModeType.deflate, enc_parms)) + + byte_buffer = Vector{UInt8}(undef, save_size(enc_parms)) + save!(byte_buffer, length(byte_buffer), enc_parms) + + enc_parms2 = EncryptionParameters() + load!(enc_parms2, byte_buffer, length(byte_buffer)) + + print_line(@__LINE__) + println("EncryptionParameters: parms == parms2: ", enc_parms == enc_parms2) + end + + # Client + let + enc_parms = EncryptionParameters() + load!(enc_parms, parms_stream) + + context = SEALContext(enc_parms) + + keygen = KeyGenerator(context) + pk = public_key(keygen) + sk = secret_key(keygen) + + resize!(sk_stream, save_size(sk)) + out_bytes = save!(sk_stream, sk) + resize!(sk_stream, out_bytes) + + rlk = relin_keys(keygen) + + resize!(data_stream1, save_size(rlk)) + size_rlk = save!(data_stream1, rlk) + resize!(data_stream1, size_rlk) + + rlk_local = relin_keys_local(keygen) + resize!(data_stream2, save_size(rlk_local)) + size_rlk_local = save!(data_stream2, rlk_local) + resize!(data_stream2, size_rlk_local) + + print_line(@__LINE__) + println("Serializable: wrote ", size_rlk, " bytes") + println(" ", "RelinKeys (local): wrote ", size_rlk_local, " bytes") + + initial_scale = 2.0^20 + encoder = CKKSEncoder(context) + plain1 = Plaintext() + plain2 = Plaintext() + encode!(plain1, 2.3, initial_scale, encoder) + encode!(plain2, 4.5, initial_scale, encoder) + + encryptor = Encryptor(context, pk) + encrypted1 = Ciphertext() + encrypted2 = Ciphertext() + encrypt!(encrypted1, plain1, encryptor) + encrypt!(encrypted2, plain2, encryptor) + + sym_encryptor = Encryptor(context, sk) + sym_encrypted1 = encrypt_symmetric(plain1, sym_encryptor) + sym_encrypted2 = encrypt_symmetric(plain2, sym_encryptor) + + resize!(data_stream2, save_size(sym_encrypted1)) + size_sym_encrypted1 = save!(data_stream2, sym_encrypted1) + resize!(data_stream2, size_sym_encrypted1) + resize!(data_stream3, save_size(encrypted1)) + size_encrypted1 = save!(data_stream3, encrypted1) + resize!(data_stream3, size_encrypted1) + + print_line(@__LINE__) + println("Serializable (symmetric-key): wrote ", size_sym_encrypted1, " bytes") + println(" ", "Ciphertext (public-key): wrote ", size_encrypted1, " bytes") + + resize!(data_stream3, save_size(sym_encrypted2)) + size_sym_encrypted2 = save!(data_stream3, sym_encrypted2) + resize!(data_stream3, size_sym_encrypted2) + end + + # Server + let + enc_parms = EncryptionParameters() + load!(enc_parms, parms_stream) + context = SEALContext(enc_parms) + + evaluator = Evaluator(context) + + rlk = RelinKeys() + encrypted1 = Ciphertext() + encrypted2 = Ciphertext() + + load!(rlk, context, data_stream1) + load!(encrypted1, context, data_stream2) + load!(encrypted2, context, data_stream3) + + encrypted_prod = Ciphertext() + multiply!(encrypted_prod, encrypted1, encrypted2, evaluator) + relinearize_inplace!(encrypted_prod, rlk, evaluator) + rescale_to_next_inplace!(encrypted_prod, evaluator) + + resize!(data_stream4, save_size(encrypted_prod)) + size_encrypted_prod = save!(data_stream4, encrypted_prod) + resize!(data_stream4, size_encrypted_prod) + + print_line(@__LINE__) + println("Ciphertext (symmetric-key): wrote ", size_encrypted_prod, " bytes") + end + + # Client + let + enc_parms = EncryptionParameters() + load!(enc_parms, parms_stream) + context = SEALContext(enc_parms) + + sk = SecretKey() + load!(sk, context, sk_stream) + decryptor = Decryptor(context, sk) + encoder = CKKSEncoder(context) + + encrypted_result = Ciphertext() + load!(encrypted_result, context, data_stream4) + + plain_result = Plaintext() + decrypt!(plain_result, encrypted_result, decryptor) + slot_count_ = slot_count(encoder) + result = Vector{Float64}(undef, slot_count_) + decode!(result, plain_result, encoder) + + print_line(@__LINE__) + println("Result: ") + print_vector(result, 3, 7) + end + + pt = Plaintext("1x^2 + 3") + stream = Vector{UInt8}(undef, save_size(pt)) + data_size = save!(stream, pt) + resize!(stream, data_size) + + header = SEALHeader() + load_header!(header, stream) + + print_line(@__LINE__) + println("Size written to stream: ", data_size, " bytes") + println(" ", "Size indicated in SEALHeader: ", header.size, " bytes") + println() + + return +end diff --git a/src/SEAL.jl b/src/SEAL.jl index 2329673..bff0b56 100644 --- a/src/SEAL.jl +++ b/src/SEAL.jl @@ -30,10 +30,14 @@ export version_major, version_minor, version_patch, version include("modulus.jl") export Modulus, SecLevelType, bit_count, value, coeff_modulus_create, coeff_modulus_bfv_default +include("serialization.jl") +export ComprModeType, SEALHeader, load_header! + include("encryptionparams.jl") export EncryptionParameters, SchemeType, get_poly_modulus_degree, set_poly_modulus_degree!, set_coeff_modulus!, coeff_modulus, - scheme, plain_modulus, set_plain_modulus!, plain_modulus_batching, parms_id + scheme, plain_modulus, set_plain_modulus!, plain_modulus_batching, parms_id, save!, + save_size, load! include("context.jl") export SEALContext, first_parms_id, last_parms_id, get_context_data, key_context_data, @@ -46,25 +50,25 @@ include("publickey.jl") export PublicKey, parms_id include("secretkey.jl") -export SecretKey, parms_id +export SecretKey, parms_id, save!, load! include("galoiskeys.jl") export GaloisKeys, parms_id include("relinkeys.jl") -export RelinKeys, parms_id +export RelinKeys, parms_id, save_size, save!, load! include("keygenerator.jl") export KeyGenerator, public_key, secret_key, relin_keys_local, relin_keys, galois_keys_local include("plaintext.jl") -export Plaintext, scale, scale!, parms_id, to_string +export Plaintext, scale, scale!, parms_id, to_string, save_size, save! include("ciphertext.jl") -export Ciphertext, scale, scale!, parms_id, size, length +export Ciphertext, scale, scale!, parms_id, size, length, save_size, save!, load! include("encryptor.jl") -export Encryptor, encrypt! +export Encryptor, encrypt!, encrypt_symmetric, encrypt_symmetric! include("evaluator.jl") export Evaluator, square!, square_inplace!, relinearize!, relinearize_inplace!, rescale_to_next!, diff --git a/src/ciphertext.jl b/src/ciphertext.jl index 0189f91..ff74020 100644 --- a/src/ciphertext.jl +++ b/src/ciphertext.jl @@ -68,3 +68,41 @@ function Base.size(encrypted::Ciphertext) end Base.length(encrypted::Ciphertext) = size(encrypted)[1] +function save_size(compr_mode, encrypted::Ciphertext) + result = Ref{Int64}(0) + retval = ccall((:Ciphertext_SaveSize, libsealc), Clong, + (Ptr{Cvoid}, UInt8, Ref{Int64}), + encrypted, compr_mode, result) + @check_return_value retval + return Int(result[]) +end +save_size(encrypted::Ciphertext) = save_size(ComprModeType.default, encrypted) + +function save!(buffer::DenseVector{UInt8}, length::Integer, + compr_mode::ComprModeType.ComprModeTypeEnum, encrypted::Ciphertext) + out_bytes = Ref{Int64}(0) + retval = ccall((:Ciphertext_Save, libsealc), Clong, + (Ptr{Cvoid}, Ref{UInt8}, UInt64, UInt8, Ref{Int64}), + encrypted, buffer, length, compr_mode, out_bytes) + @check_return_value retval + return Int(out_bytes[]) +end +function save!(buffer::DenseVector{UInt8}, length::Integer, encrypted::Ciphertext) + return save!(buffer, length, ComprModeType.default, encrypted) +end +function save!(buffer::DenseVector{UInt8}, encrypted::Ciphertext) + return save!(buffer, length(buffer), encrypted) +end + +function load!(encrypted::Ciphertext, context::SEALContext, buffer::DenseVector{UInt8}, length) + in_bytes = Ref{Int64}(0) + retval = ccall((:Ciphertext_Load, libsealc), Clong, + (Ptr{Cvoid}, Ptr{Cvoid}, Ref{UInt8}, UInt64, Ref{Int64}), + encrypted, context, buffer, length, in_bytes) + @check_return_value retval + return Int(in_bytes[]) +end +function load!(encrypted::Ciphertext, context::SEALContext, buffer::DenseVector{UInt8}) + return load!(encrypted, context, buffer, length(buffer)) +end + diff --git a/src/encryptionparams.jl b/src/encryptionparams.jl index 594888a..cb53812 100644 --- a/src/encryptionparams.jl +++ b/src/encryptionparams.jl @@ -23,7 +23,7 @@ See also: [`SEALContext`](@ref) mutable struct EncryptionParameters <: SEALObject handle::Ptr{Cvoid} - function EncryptionParameters(scheme::SchemeType.SchemeTypeEnum) + function EncryptionParameters(scheme::SchemeType.SchemeTypeEnum=SchemeType.none) handleref = Ref{Ptr{Cvoid}}(C_NULL) retval = ccall((:EncParams_Create1, libsealc), Clong, (UInt8, Ref{Ptr{Cvoid}}), @@ -131,3 +131,49 @@ function parms_id(enc_param::EncryptionParameters) return parms_id_ end +function save!(buffer::DenseVector{UInt8}, length::Integer, + compr_mode::ComprModeType.ComprModeTypeEnum, enc_param::EncryptionParameters) + out_bytes = Ref{Int64}(0) + retval = ccall((:EncParams_Save, libsealc), Clong, + (Ptr{Cvoid}, Ref{UInt8}, UInt64, UInt8, Ref{Int64}), + enc_param, buffer, length, compr_mode, out_bytes) + @check_return_value retval + return Int(out_bytes[]) +end +function save!(buffer::DenseVector{UInt8}, length::Integer, enc_param::EncryptionParameters) + return save!(buffer, length, ComprModeType.default, enc_param) +end +function save!(buffer::DenseVector{UInt8}, enc_param::EncryptionParameters) + return save!(buffer, length(buffer), enc_param) +end + +function save_size(compr_mode, enc_param::EncryptionParameters) + result = Ref{Int64}(0) + retval = ccall((:EncParams_SaveSize, libsealc), Clong, + (Ptr{Cvoid}, UInt8, Ref{Int64}), + enc_param, compr_mode, result) + @check_return_value retval + return Int(result[]) +end +save_size(enc_param::EncryptionParameters) = save_size(ComprModeType.default, enc_param) + +function load!(enc_param::EncryptionParameters, buffer::DenseVector{UInt8}, length) + in_bytes = Ref{Int64}(0) + retval = ccall((:EncParams_Load, libsealc), Clong, + (Ptr{Cvoid}, Ref{UInt8}, UInt64, Ref{Int64}), + enc_param, buffer, length, in_bytes) + @check_return_value retval + return Int(in_bytes[]) +end +load!(enc_param::EncryptionParameters, buffer::DenseVector{UInt8}) = load!(enc_param, buffer, + length(buffer)) + +function Base.:(==)(enc_param1::EncryptionParameters, enc_param2::EncryptionParameters) + result = Ref{UInt8}(0) + retval = ccall((:EncParams_Equals, libsealc), Clong, + (Ptr{Cvoid}, Ptr{Cvoid}, Ref{UInt8}), + enc_param1, enc_param1, result) + @check_return_value retval + return Bool(result[]) +end + diff --git a/src/encryptor.jl b/src/encryptor.jl index 90cbbf7..207ba23 100644 --- a/src/encryptor.jl +++ b/src/encryptor.jl @@ -54,3 +54,19 @@ function encrypt!(destination::Ciphertext, plain::Plaintext, encryptor::Encrypto return destination end +function encrypt_symmetric!(destination::Ciphertext, plain::Plaintext, encryptor::Encryptor) + retval = ccall((:Encryptor_EncryptSymmetric, libsealc), Clong, + (Ptr{Cvoid}, Ptr{Cvoid}, UInt8, Ptr{Cvoid}, Ptr{Cvoid}), + encryptor, plain, false, destination, C_NULL) + @check_return_value retval + return destination +end +function encrypt_symmetric(plain::Plaintext, encryptor::Encryptor) + destination = Ciphertext() + retval = ccall((:Encryptor_EncryptSymmetric, libsealc), Clong, + (Ptr{Cvoid}, Ptr{Cvoid}, UInt8, Ptr{Cvoid}, Ptr{Cvoid}), + encryptor, plain, true, destination, C_NULL) + @check_return_value retval + return destination +end + diff --git a/src/plaintext.jl b/src/plaintext.jl index 428f684..9bb8257 100644 --- a/src/plaintext.jl +++ b/src/plaintext.jl @@ -86,3 +86,29 @@ function to_string(plain::Plaintext) # Return as String but without terminating NULL byte return String(message[1:end-1]) end + +function save_size(compr_mode, plain::Plaintext) + result = Ref{Int64}(0) + retval = ccall((:Plaintext_SaveSize, libsealc), Clong, + (Ptr{Cvoid}, UInt8, Ref{Int64}), + plain, compr_mode, result) + @check_return_value retval + return Int(result[]) +end +save_size(plain::Plaintext) = save_size(ComprModeType.default, plain) + +function save!(buffer::DenseVector{UInt8}, length::Integer, + compr_mode::ComprModeType.ComprModeTypeEnum, plain::Plaintext) + out_bytes = Ref{Int64}(0) + retval = ccall((:Plaintext_Save, libsealc), Clong, + (Ptr{Cvoid}, Ref{UInt8}, UInt64, UInt8, Ref{Int64}), + plain, buffer, length, compr_mode, out_bytes) + @check_return_value retval + return Int(out_bytes[]) +end +function save!(buffer::DenseVector{UInt8}, length::Integer, plain::Plaintext) + return save!(buffer, length, ComprModeType.default, plain) +end +function save!(buffer::DenseVector{UInt8}, plain::Plaintext) + return save!(buffer, length(buffer), plain) +end diff --git a/src/relinkeys.jl b/src/relinkeys.jl index c314507..5c671d1 100644 --- a/src/relinkeys.jl +++ b/src/relinkeys.jl @@ -38,3 +38,42 @@ function parms_id(key::RelinKeys) return parms_id end +function save_size(compr_mode, key::RelinKeys) + result = Ref{Int64}(0) + retval = ccall((:KSwitchKeys_SaveSize, libsealc), Clong, + (Ptr{Cvoid}, UInt8, Ref{Int64}), + key, compr_mode, result) + @check_return_value retval + return Int(result[]) +end +save_size(key::RelinKeys) = save_size(ComprModeType.default, key) + +function save!(buffer::DenseVector{UInt8}, length::Integer, + compr_mode::ComprModeType.ComprModeTypeEnum, key::RelinKeys) + out_bytes = Ref{Int64}(0) + retval = ccall((:KSwitchKeys_Save, libsealc), Clong, + (Ptr{Cvoid}, Ref{UInt8}, UInt64, UInt8, Ref{Int64}), + key, buffer, length, compr_mode, out_bytes) + @check_return_value retval + return Int(out_bytes[]) +end +function save!(buffer::DenseVector{UInt8}, length::Integer, key::RelinKeys) + return save!(buffer, length, ComprModeType.default, key) +end +function save!(buffer::DenseVector{UInt8}, key::RelinKeys) + return save!(buffer, length(buffer), key) +end + +function load!(key::RelinKeys, context::SEALContext, buffer::DenseVector{UInt8}, length) + in_bytes = Ref{Int64}(0) + retval = ccall((:KSwitchKeys_Load, libsealc), Clong, + (Ptr{Cvoid}, Ptr{Cvoid}, Ref{UInt8}, UInt64, Ref{Int64}), + key, context, buffer, length, in_bytes) + @check_return_value retval + return Int(in_bytes[]) +end +load!(key::RelinKeys, context::SEALContext, buffer::DenseVector{UInt8}) = load!(key, + context, + buffer, + length(buffer)) + diff --git a/src/secretkey.jl b/src/secretkey.jl index 85de6ae..f2f2971 100644 --- a/src/secretkey.jl +++ b/src/secretkey.jl @@ -37,3 +37,41 @@ function parms_id(key::SecretKey) return parms_id end +function save!(buffer::DenseVector{UInt8}, length::Integer, + compr_mode::ComprModeType.ComprModeTypeEnum, key::SecretKey) + out_bytes = Ref{Int64}(0) + retval = ccall((:SecretKey_Save, libsealc), Clong, + (Ptr{Cvoid}, Ref{UInt8}, UInt64, UInt8, Ref{Int64}), + key, buffer, length, compr_mode, out_bytes) + @check_return_value retval + return Int(out_bytes[]) +end +function save!(buffer::DenseVector{UInt8}, length::Integer, key::SecretKey) + return save!(buffer, length, ComprModeType.default, key) +end +function save!(buffer::DenseVector{UInt8}, key::SecretKey) + return save!(buffer, length(buffer), key) +end + +function save_size(compr_mode, key::SecretKey) + result = Ref{Int64}(0) + retval = ccall((:SecretKey_SaveSize, libsealc), Clong, + (Ptr{Cvoid}, UInt8, Ref{Int64}), + key, compr_mode, result) + @check_return_value retval + return Int(result[]) +end +save_size(key::SecretKey) = save_size(ComprModeType.default, key) + +function load!(key::SecretKey, context::SEALContext, buffer::DenseVector{UInt8}, length) + in_bytes = Ref{Int64}(0) + retval = ccall((:SecretKey_Load, libsealc), Clong, + (Ptr{Cvoid}, Ptr{Cvoid}, Ref{UInt8}, UInt64, Ref{Int64}), + key, context, buffer, length, in_bytes) + @check_return_value retval + return Int(in_bytes[]) +end +function load!(key::SecretKey, context::SEALContext, buffer::DenseVector{UInt8}) + return load!(key, context, buffer, length(buffer)) +end + diff --git a/src/serialization.jl b/src/serialization.jl new file mode 100644 index 0000000..bb10a53 --- /dev/null +++ b/src/serialization.jl @@ -0,0 +1,29 @@ + +module ComprModeType +@enum ComprModeTypeEnum::UInt8 none=0 deflate=1 +const default = deflate +end + +mutable struct SEALHeader + magic::UInt16 + header_size::UInt8 + version_major::UInt8 + version_minor::UInt8 + compr_mode::UInt8 + reserved::UInt16 + size::UInt64 +end + +SEALHeader() = SEALHeader(0, 0, 0, 0, 0, 0, 0) + +function load_header!(header::SEALHeader, buffer::DenseVector{UInt8}) + io = IOBuffer(buffer) + header.magic = read(io, UInt16) + header.header_size = read(io, UInt8) + header.version_major = read(io, UInt8) + header.version_minor = read(io, UInt8) + header.compr_mode = read(io, UInt8) + header.reserved = read(io, UInt16) + header.size = read(io, UInt64) + return header +end diff --git a/test/runtests.jl b/test/runtests.jl index e6318b7..1798ff9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,4 +7,5 @@ include("test_2_encoders.jl") include("test_3_levels.jl") include("test_4_ckks_basics.jl") include("test_5_rotation.jl") +include("test_6_serialization.jl") include("test_extra.jl") diff --git a/test/test_6_serialization.jl b/test/test_6_serialization.jl new file mode 100644 index 0000000..aff034c --- /dev/null +++ b/test/test_6_serialization.jl @@ -0,0 +1,219 @@ +@testset "6_serialization" begin + parms_stream = UInt8[] + data_stream1 = UInt8[] + data_stream2 = UInt8[] + data_stream3 = UInt8[] + data_stream4 = UInt8[] + sk_stream = UInt8[] + + @testset "server (part 1)" begin + enc_parms = EncryptionParameters(SchemeType.CKKS) + poly_modulus_degree = 8192 + @testset "polynomial modulus degree" begin + @test_nowarn set_poly_modulus_degree!(enc_parms, poly_modulus_degree) + end + + @testset "coefficient modulus" begin + @test_nowarn set_coeff_modulus!(enc_parms, coeff_modulus_create(poly_modulus_degree, [50, 20, 50])) + end + + @testset "save! EncryptionParameters" begin + @test save_size(enc_parms) == 146 + resize!(parms_stream, save_size(enc_parms)) + @test save!(parms_stream, enc_parms) == 60 + out_bytes = 60 + resize!(parms_stream, out_bytes) + end + + @testset "save_size comparison" begin + @test save_size(ComprModeType.none, enc_parms) == 129 + @test save_size(ComprModeType.deflate, enc_parms) == 146 + end + + @testset "save! and load! EncryptionParameters" begin + byte_buffer = Vector{UInt8}(undef, save_size(enc_parms)) + @test save!(byte_buffer, length(byte_buffer), enc_parms) == 60 + + enc_parms2 = EncryptionParameters() + @test load!(enc_parms2, byte_buffer, length(byte_buffer)) == 60 + @test enc_parms == enc_parms2 + end + end + + @testset "client (part 1)" begin + enc_parms = EncryptionParameters() + @testset "load! EncryptionParameters" begin + @test load!(enc_parms, parms_stream) == 60 + end + + context = SEALContext(enc_parms) + keygen = KeyGenerator(context) + pk = public_key(keygen) + sk = secret_key(keygen) + + @testset "save! SecretKey" begin + @test save_size(sk) == 196773 + resize!(sk_stream, save_size(sk)) + @test isapprox(save!(sk_stream, sk), 148752, rtol=0.001) + out_bytes = save!(sk_stream, sk) + resize!(sk_stream, out_bytes) + end + + rlk = relin_keys(keygen) + @testset "save! relin_keys" begin + @test save_size(rlk) == 393755 + resize!(data_stream1, save_size(rlk)) + @test isapprox(save!(data_stream1, rlk), 297521, rtol=0.001) + size_rlk = save!(data_stream1, rlk) + resize!(data_stream1, size_rlk) + end + + rlk_local = relin_keys_local(keygen) + @testset "save! relin_keys_local" begin + @test save_size(rlk_local) == 786963 + resize!(data_stream2, save_size(rlk_local)) + @test isapprox(save!(data_stream2, rlk_local), 593391, rtol=0.001) + size_rlk_local = save!(data_stream2, rlk_local) + resize!(data_stream2, size_rlk_local) + end + + initial_scale = 2.0^20 + encoder = CKKSEncoder(context) + plain1 = Plaintext() + plain2 = Plaintext() + + @testset "encode!" begin + @test_nowarn encode!(plain1, 2.3, initial_scale, encoder) + @test_nowarn encode!(plain2, 4.5, initial_scale, encoder) + end + + encryptor = Encryptor(context, pk) + encrypted1 = Ciphertext() + encrypted2 = Ciphertext() + @testset "encrypt!" begin + @test_nowarn encrypt!(encrypted1, plain1, encryptor) + @test_nowarn encrypt!(encrypted2, plain2, encryptor) + end + + @testset "symmetric encryptor" begin + @test_nowarn Encryptor(context, sk) + end + sym_encryptor = Encryptor(context, sk) + + @testset "encrypt_symmetric, encrypt_symmetric!" begin + @test encrypt_symmetric(plain1, sym_encryptor) isa Ciphertext + @test encrypt_symmetric(plain2, sym_encryptor) isa Ciphertext + c = Ciphertext() + @test encrypt_symmetric!(c, plain2, sym_encryptor) == c + end + sym_encrypted1 = encrypt_symmetric(plain1, sym_encryptor) + sym_encrypted2 = encrypt_symmetric(plain2, sym_encryptor) + + @testset "save! Ciphertext" begin + @test save_size(sym_encrypted1) == 131298 + resize!(data_stream2, save_size(sym_encrypted1)) + @test isapprox(save!(data_stream2, sym_encrypted1), 88528, rtol=0.001) + size_sym_encrypted1 = save!(data_stream2, sym_encrypted1) + resize!(data_stream2, size_sym_encrypted1) + + @test save_size(encrypted1) == 262346 + resize!(data_stream3, save_size(encrypted1)) + @test isapprox(save!(data_stream3, encrypted1), 177295, rtol=0.001) + size_encrypted1 = save!(data_stream3, encrypted1) + resize!(data_stream3, size_encrypted1) + + @test save_size(sym_encrypted2) == 131298 + resize!(data_stream3, save_size(sym_encrypted2)) + @test isapprox(save!(data_stream3, sym_encrypted2), 88467, rtol=0.001) + size_sym_encrypted2 = save!(data_stream3, sym_encrypted2) + resize!(data_stream3, size_sym_encrypted2) + end + end + + @testset "server (part 2)" begin + enc_parms = EncryptionParameters() + @testset "load! EncryptionParameters" begin + @test load!(enc_parms, parms_stream) == 60 + end + + context = SEALContext(enc_parms) + evaluator = Evaluator(context) + rlk = RelinKeys() + encrypted1 = Ciphertext() + encrypted2 = Ciphertext() + + @testset "load! RelinKeys" begin + @test isapprox(load!(rlk, context, data_stream1), 297640, rtol=0.001) + end + + @testset "load! Ciphertext" begin + @test isapprox(load!(encrypted1, context, data_stream2), 88513, rtol=0.001) + @test isapprox(load!(encrypted2, context, data_stream3), 88464, rtol=0.001) + end + + encrypted_prod = Ciphertext() + @testset "multiply, relinearize, rescale" begin + @test multiply!(encrypted_prod, encrypted1, encrypted2, evaluator) == encrypted_prod + @test relinearize_inplace!(encrypted_prod, rlk, evaluator) == encrypted_prod + @test rescale_to_next_inplace!(encrypted_prod, evaluator) == encrypted_prod + end + + @testset "save! Ciphertext" begin + @test save_size(encrypted_prod) == 131234 + resize!(data_stream4, save_size(encrypted_prod)) + @test isapprox(save!(data_stream4, encrypted_prod), 119229, rtol=0.001) + size_encrypted_prod = save!(data_stream4, encrypted_prod) + resize!(data_stream4, size_encrypted_prod) + end + end + + @testset "client (part 2)" begin + enc_parms = EncryptionParameters() + load!(enc_parms, parms_stream) + context = SEALContext(enc_parms) + + sk = SecretKey() + @testset "load! SecretKey" begin + @test isapprox(load!(sk, context, sk_stream), 148772, rtol=0.001) + end + + decryptor = Decryptor(context, sk) + encoder = CKKSEncoder(context) + encrypted_result = Ciphertext() + @testset "load! Ciphertext" begin + @test_nowarn load!(encrypted_result, context, data_stream4) + end + + plain_result = Plaintext() + @testset "decrypt!" begin + @test_nowarn decrypt!(plain_result, encrypted_result, decryptor) + end + + slot_count_ = slot_count(encoder) + result = Vector{Float64}(undef, slot_count_) + @testset "decode! and check result" begin + @test_nowarn decode!(result, plain_result, encoder) + @test isapprox(result[1], 10.35, rtol=0.001) + @test isapprox(result[2], 10.35, rtol=0.001) + @test isapprox(result[3], 10.35, rtol=0.001) + @test isapprox(result[end-2], 10.35, rtol=0.001) + @test isapprox(result[end-1], 10.35, rtol=0.001) + @test isapprox(result[end-0], 10.35, rtol=0.001) + end + end + + pt = Plaintext("1x^2 + 3") + stream = Vector{UInt8}(undef, save_size(pt)) + @testset "save! Plaintext" begin + @test save!(stream, pt) == 49 + data_size = 49 + resize!(stream, data_size) + end + + header = SEALHeader() + @testset "load_header!" begin + @test load_header!(header, stream) == header + @test header.size == 49 + end +end +