diff --git a/spec/std/openssl/digest_io_spec.cr b/spec/std/openssl/digest_io_spec.cr new file mode 100644 index 000000000000..6b3fb3a450d8 --- /dev/null +++ b/spec/std/openssl/digest_io_spec.cr @@ -0,0 +1,82 @@ +require "spec" +require "../src/openssl" + +describe OpenSSL::DigestIO do + it "calculates digest from reading" do + base_io = IO::Memory.new("foo") + base_digest = OpenSSL::Digest.new("SHA256") + io = OpenSSL::DigestIO.new(base_io, base_digest) + slice = Bytes.new(256) + io.read(slice).should eq(3) + + slice[0, 3].should eq("foo".to_slice) + io.digest.should eq("2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae".hexbytes) + end + + it "calculates digest from multiple reads" do + base_io = IO::Memory.new("foo") + base_digest = OpenSSL::Digest.new("SHA256") + io = OpenSSL::DigestIO.new(base_io, base_digest) + slice = Bytes.new(2) + io.read(slice).should eq(2) + slice[0, 2].should eq("fo".to_slice) + + io.read(slice).should eq(1) + slice[0, 1].should eq("o".to_slice) + + io.digest.should eq("2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae".hexbytes) + end + + it "does not calculate digest on read" do + base_io = IO::Memory.new("foo") + base_digest = OpenSSL::Digest.new("SHA256") + empty_digest = OpenSSL::Digest.new("SHA256").digest + io = OpenSSL::DigestIO.new(base_io, base_digest, digest_on_read: false) + slice = Bytes.new(256) + io.read(slice).should eq(3) + slice[0, 3].should eq("foo".to_slice) + io.digest.should eq(empty_digest) + end + + it "calculates digest from writing" do + base_io = IO::Memory.new + base_digest = OpenSSL::Digest.new("SHA256") + io = OpenSSL::DigestIO.new(base_io, base_digest) + io.write("foo".to_slice) + + base_io.to_slice[0, 3].should eq("foo".to_slice) + io.digest.should eq("2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae".hexbytes) + end + + it "calculates digest from writing a string" do + base_io = IO::Memory.new + base_digest = OpenSSL::Digest.new("SHA256") + io = OpenSSL::DigestIO.new(base_io, base_digest) + io.print("foo") + + base_io.to_slice[0, 3].should eq("foo".to_slice) + io.digest.should eq("2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae".hexbytes) + end + + it "calculates digest from multiple writes" do + base_io = IO::Memory.new + base_digest = OpenSSL::Digest.new("SHA256") + io = OpenSSL::DigestIO.new(base_io, base_digest) + io.write("fo".to_slice) + io.write("o".to_slice) + base_io.to_slice[0, 3].should eq("foo".to_slice) + + io.digest.should eq("2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae".hexbytes) + end + + it "does not calculate digest on write" do + base_io = IO::Memory.new + base_digest = OpenSSL::Digest.new("SHA256") + empty_digest = OpenSSL::Digest.new("SHA256").digest + io = OpenSSL::DigestIO.new(base_io, base_digest, digest_on_write: false) + io.write("foo".to_slice) + + base_io.to_slice[0, 3].should eq("foo".to_slice) + io.digest.should eq(empty_digest) + end +end diff --git a/src/openssl/digest/digest_io.cr b/src/openssl/digest/digest_io.cr new file mode 100644 index 000000000000..e464581aa7eb --- /dev/null +++ b/src/openssl/digest/digest_io.cr @@ -0,0 +1,52 @@ +require "./digest_base" + +# Wraps an IO by calculating a specified digest on read and/or write operations +# +# ### Example +# +# ``` +# require "openssl" +# +# underlying_io = IO::Memory.new("foo") +# io = OpenSSL::DigestIO.new(underlying_io, "SHA256") +# buffer = Bytes.new(256) +# io.read(buffer) +# io.digest # => 2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae +# ``` +# +module OpenSSL + class DigestIO + include IO + + getter io : IO + getter digest_algorithm : OpenSSL::Digest + getter digest_on_read : Bool + getter digest_on_write : Bool + + delegate close, closed?, flush, peek, tty?, rewind, to: @io + delegate digest, to: @digest_algorithm + delegate digest, hexdigest, base64digest, to: @digest_algorithm + + def initialize(@io : IO, @digest_algorithm : OpenSSL::Digest, *, @digest_on_read = true, @digest_on_write = true) + end + + def initialize(@io : IO, algorithm : String, *, @digest_on_read = true, @digest_on_write = true) + @digest_algorithm = OpenSSL::Digest.new(algorithm) + end + + def read(slice : Bytes) + read_bytes = io.read(slice) + if @digest_on_read + digest_algorithm.update(slice[0, read_bytes]) + end + read_bytes + end + + def write(slice : Bytes) + if @digest_on_write + digest_algorithm.update(slice) + end + io.write(slice) + end + end +end