Skip to content

Commit

Permalink
BT-732 Checksum validation for blobs read by engine (#6838)
Browse files Browse the repository at this point in the history
* Draft support for optional FileHash

* Draft getMd5 for BlobPath

* Resolve non-parallel IO to fix tests

* Checksum validation for BlobPath

* Nicer error message

* Test for missing Blob hash

* Break attr acquisition into separate method

* Cleanup, comments

* In-progress tests of blob hash command

* Remove test

* Remove unused import
  • Loading branch information
jgainerdewar authored Sep 9, 2022
1 parent 54fed3e commit f289382
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 14 deletions.
47 changes: 36 additions & 11 deletions engine/src/main/scala/cromwell/engine/io/nio/NioFlow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import cromwell.core.path.Path
import cromwell.engine.io.IoActor._
import cromwell.engine.io.RetryableRequestSupport.{isInfinitelyRetryable, isRetryable}
import cromwell.engine.io.{IoAttempts, IoCommandContext, IoCommandStalenessBackpressuring}
import cromwell.filesystems.blob.BlobPath
import cromwell.filesystems.drs.DrsPath
import cromwell.filesystems.gcs.GcsPath
import cromwell.filesystems.s3.S3Path
Expand Down Expand Up @@ -128,21 +129,33 @@ class NioFlow(parallelism: Int,

def readFileAndChecksum: IO[String] = {
for {
fileHash <- getHash(command.file)
fileHash <- getStoredHash(command.file)
uncheckedValue <- readFile
checksumResult <- checkHash(uncheckedValue, fileHash)
checksumResult <- fileHash match {
case Some(hash) => checkHash(uncheckedValue, hash)
// If there is no stored checksum, don't attempt to validate.
// If the missing checksum is itself an error condition, that
// should be detected by the code that gets the FileHash.
case None => IO.pure(ChecksumSkipped())
}
verifiedValue <- checksumResult match {
case _: ChecksumSkipped => IO.pure(uncheckedValue)
case _: ChecksumSuccess => IO.pure(uncheckedValue)
case failure: ChecksumFailure => IO.raiseError(
ChecksumFailedException(
s"Failed checksum for '${command.file}'. Expected '${fileHash.hashType}' hash of '${fileHash.hash}'. Calculated hash '${failure.calculatedHash}'"))
fileHash match {
case Some(hash) => s"Failed checksum for '${command.file}'. Expected '${hash.hashType}' hash of '${hash.hash}'. Calculated hash '${failure.calculatedHash}'"
case None => s"Failed checksum for '${command.file}'. Couldn't find stored file hash." // This should never happen
}
)
)
}
} yield verifiedValue
}

val fileContentIo = command.file match {
case _: DrsPath => readFileAndChecksum
case _: DrsPath => readFileAndChecksum
case _: BlobPath => readFileAndChecksum
case _ => readFile
}
fileContentIo.map(_.replaceAll("\\r\\n", "\\\n"))
Expand All @@ -153,19 +166,27 @@ class NioFlow(parallelism: Int,
}

private def hash(hash: IoHashCommand): IO[String] = {
getHash(hash.file).map(_.hash)
// If there is no hash accessible from the file storage system,
// we'll read the file and generate the hash ourselves.
getStoredHash(hash.file).flatMap {
case Some(storedHash) => IO.pure(storedHash)
case None => generateMd5FileHashForPath(hash.file)
}.map(_.hash)
}

private def getHash(file: Path): IO[FileHash] = {
private def getStoredHash(file: Path): IO[Option[FileHash]] = {
file match {
case gcsPath: GcsPath => getFileHashForGcsPath(gcsPath)
case gcsPath: GcsPath => getFileHashForGcsPath(gcsPath).map(Option(_))
case blobPath: BlobPath => getFileHashForBlobPath(blobPath)
case drsPath: DrsPath => IO {
// We assume all DRS files have a stored hash; this will throw
// if the file does not.
drsPath.getFileHash
}
}.map(Option(_))
case s3Path: S3Path => IO {
FileHash(HashType.S3Etag, s3Path.eTag)
Option(FileHash(HashType.S3Etag, s3Path.eTag))
}
case path => getMd5FileHashForPath(path)
case _ => IO.pure(None)
}
}

Expand Down Expand Up @@ -201,7 +222,11 @@ class NioFlow(parallelism: Int,
gcsPath.objectBlobId.map(id => FileHash(HashType.GcsCrc32c, gcsPath.cloudStorage.get(id).getCrc32c))
}

private def getMd5FileHashForPath(path: Path): IO[FileHash] = delayedIoFromTry {
private def getFileHashForBlobPath(blobPath: BlobPath): IO[Option[FileHash]] = delayedIoFromTry {
blobPath.md5HexString.map(md5 => md5.map(FileHash(HashType.Md5, _)))
}

private def generateMd5FileHashForPath(path: Path): IO[FileHash] = delayedIoFromTry {
tryWithResource(() => path.newInputStream) { inputStream =>
FileHash(HashType.Md5, org.apache.commons.codec.digest.DigestUtils.md5Hex(inputStream))
}
Expand Down
39 changes: 38 additions & 1 deletion engine/src/test/scala/cromwell/engine/io/nio/NioFlowSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ import org.mockito.Mockito.{times, verify, when}
import org.scalatest.flatspec.AsyncFlatSpecLike
import org.scalatest.matchers.should.Matchers
import common.mock.MockSugar
import cromwell.filesystems.blob.BlobPath

import java.nio.file.NoSuchFileException
import java.util.UUID
import scala.concurrent.ExecutionContext
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Failure
import scala.util.{Failure, Success, Try}
import scala.util.control.NoStackTrace

class NioFlowSpec extends TestKitSuite with AsyncFlatSpecLike with Matchers with MockSugar {
Expand Down Expand Up @@ -127,6 +128,23 @@ class NioFlowSpec extends TestKitSuite with AsyncFlatSpecLike with Matchers with
}
}

it should "get hash from a BlobPath when stored hash exists" in {
val testPath = mock[BlobPath]
val hashString = "2d01d5d9c24034d54fe4fba0ede5182d" // echo "hello there" | md5sum
testPath.md5HexString returns Try(Option(hashString))

val context = DefaultCommandContext(hashCommand(testPath).get, replyTo)
val testSource = Source.single(context)

val stream = testSource.via(flow).toMat(readSink)(Keep.right)

stream.run() map {
case (success: IoSuccess[_], _) => assert(success.result.asInstanceOf[String] == hashString)
case (ack, _) =>
fail(s"read returned an unexpected message:\n$ack\n\n")
}
}

it should "fail if DrsPath hash doesn't match checksum" in {
val testPath = mock[DrsPath]
when(testPath.limitFileContent(any[Option[Int]], any[Boolean])(any[ExecutionContext])).thenReturn("hello".getBytes)
Expand Down Expand Up @@ -171,6 +189,25 @@ class NioFlowSpec extends TestKitSuite with AsyncFlatSpecLike with Matchers with
}
}

it should "succeed if a BlobPath is missing a stored hash" in {
val testPath = mock[BlobPath]
when(testPath.limitFileContent(any[Option[Int]], any[Boolean])(any[ExecutionContext]))
.thenReturn("hello there".getBytes)
when(testPath.md5HexString)
.thenReturn(Success(None))

val context = DefaultCommandContext(contentAsStringCommand(testPath, Option(100), failOnOverflow = true).get, replyTo)
val testSource = Source.single(context)

val stream = testSource.via(flow).toMat(readSink)(Keep.right)

stream.run() map {
case (success: IoSuccess[_], _) => assert(success.result.asInstanceOf[String] == "hello there")
case (ack, _) =>
fail(s"read returned an unexpected message:\n$ack\n\n")
}
}

it should "copy Nio paths" in {
val testPath = DefaultPathBuilder.createTempFile()
val testCopyPath = testPath.sibling(UUID.randomUUID().toString)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package cromwell.filesystems.blob

import com.azure.core.credential.AzureSasCredential
import com.azure.storage.blob.nio.AzureFileSystem
import com.azure.storage.blob.nio.{AzureBlobFileAttributes, AzureFileSystem}
import com.google.common.net.UrlEscapers
import cromwell.core.path.{NioPath, Path, PathBuilder}
import cromwell.filesystems.blob.BlobPathBuilder._

import java.net.{MalformedURLException, URI}
import java.nio.file.{FileSystem, FileSystemNotFoundException, FileSystems}
import java.nio.file.{FileSystem, FileSystemNotFoundException, FileSystems, Files}
import scala.jdk.CollectionConverters._
import scala.language.postfixOps
import scala.util.{Failure, Try}
Expand Down Expand Up @@ -90,4 +90,19 @@ case class BlobPath private[blob](nioPath: NioPath, endpoint: String, container:
override def pathAsString: String = List(endpoint, container, nioPath.toString()).mkString("/")

override def pathWithoutScheme: String = parseURI(endpoint).getHost + "/" + container + "/" + nioPath.toString()

def blobFileAttributes: Try[AzureBlobFileAttributes] =
Try(Files.readAttributes(nioPath, classOf[AzureBlobFileAttributes]))

def md5HexString: Try[Option[String]] = {
blobFileAttributes.map(h =>
Option(h.blobHttpHeaders().getContentMd5) match {
case None => None
case Some(arr) if arr.isEmpty => None
// Convert the bytes to a hex-encoded string. Note that this value
// is rendered in base64 in the Azure web portal.
case Some(bytes) => Option(bytes.map("%02x".format(_)).mkString)
}
)
}
}

0 comments on commit f289382

Please sign in to comment.