diff --git a/spec/std/openssl/digest_io_spec.cr b/spec/std/openssl/digest_io_spec.cr index 6b3fb3a450d8..e57035a96365 100644 --- a/spec/std/openssl/digest_io_spec.cr +++ b/spec/std/openssl/digest_io_spec.cr @@ -31,7 +31,7 @@ describe OpenSSL::DigestIO do base_io = IO::Memory.new("foo") base_digest = OpenSSL::Digest.new("SHA256") empty_digest = OpenSSL::Digest.new("SHA256").digest - io = OpenSSL::DigestIO.new(base_io, base_digest, digest_on_read: false) + io = OpenSSL::DigestIO.new(base_io, base_digest, OpenSSL::DigestIO::DigestMode::Write) slice = Bytes.new(256) io.read(slice).should eq(3) slice[0, 3].should eq("foo".to_slice) @@ -41,7 +41,7 @@ describe OpenSSL::DigestIO do it "calculates digest from writing" do base_io = IO::Memory.new base_digest = OpenSSL::Digest.new("SHA256") - io = OpenSSL::DigestIO.new(base_io, base_digest) + io = OpenSSL::DigestIO.new(base_io, base_digest, OpenSSL::DigestIO::DigestMode::Write) io.write("foo".to_slice) base_io.to_slice[0, 3].should eq("foo".to_slice) @@ -51,7 +51,7 @@ describe OpenSSL::DigestIO do it "calculates digest from writing a string" do base_io = IO::Memory.new base_digest = OpenSSL::Digest.new("SHA256") - io = OpenSSL::DigestIO.new(base_io, base_digest) + io = OpenSSL::DigestIO.new(base_io, base_digest, OpenSSL::DigestIO::DigestMode::Write) io.print("foo") base_io.to_slice[0, 3].should eq("foo".to_slice) @@ -61,7 +61,7 @@ describe OpenSSL::DigestIO do it "calculates digest from multiple writes" do base_io = IO::Memory.new base_digest = OpenSSL::Digest.new("SHA256") - io = OpenSSL::DigestIO.new(base_io, base_digest) + io = OpenSSL::DigestIO.new(base_io, base_digest, OpenSSL::DigestIO::DigestMode::Write) io.write("fo".to_slice) io.write("o".to_slice) base_io.to_slice[0, 3].should eq("foo".to_slice) @@ -73,7 +73,7 @@ describe OpenSSL::DigestIO do base_io = IO::Memory.new base_digest = OpenSSL::Digest.new("SHA256") empty_digest = OpenSSL::Digest.new("SHA256").digest - io = OpenSSL::DigestIO.new(base_io, base_digest, digest_on_write: false) + io = OpenSSL::DigestIO.new(base_io, base_digest, OpenSSL::DigestIO::DigestMode::Read) io.write("foo".to_slice) base_io.to_slice[0, 3].should eq("foo".to_slice) diff --git a/src/openssl/digest/digest_io.cr b/src/openssl/digest/digest_io.cr index e464581aa7eb..9eb6bd216feb 100644 --- a/src/openssl/digest/digest_io.cr +++ b/src/openssl/digest/digest_io.cr @@ -1,6 +1,6 @@ require "./digest_base" -# Wraps an IO by calculating a specified digest on read and/or write operations +# Wraps an IO by calculating a specified digest on read or write operations # # ### Example # @@ -20,30 +20,33 @@ module OpenSSL getter io : IO getter digest_algorithm : OpenSSL::Digest - getter digest_on_read : Bool - getter digest_on_write : Bool + getter mode : DigestMode delegate close, closed?, flush, peek, tty?, rewind, to: @io - delegate digest, to: @digest_algorithm delegate digest, hexdigest, base64digest, to: @digest_algorithm - def initialize(@io : IO, @digest_algorithm : OpenSSL::Digest, *, @digest_on_read = true, @digest_on_write = true) + enum DigestMode + Read, + Write end - def initialize(@io : IO, algorithm : String, *, @digest_on_read = true, @digest_on_write = true) + def initialize(@io : IO, @digest_algorithm : OpenSSL::Digest, @mode = DigestMode::Read) + end + + def initialize(@io : IO, algorithm : String, @mode = DigestMode::Read) @digest_algorithm = OpenSSL::Digest.new(algorithm) end def read(slice : Bytes) read_bytes = io.read(slice) - if @digest_on_read + if @mode == DigestMode::Read digest_algorithm.update(slice[0, read_bytes]) end read_bytes end def write(slice : Bytes) - if @digest_on_write + if @mode == DigestMode::Write digest_algorithm.update(slice) end io.write(slice)