Skip to content

Commit

Permalink
WX-1078 ACR support (#7192)
Browse files Browse the repository at this point in the history
  • Loading branch information
aednichols authored Aug 9, 2023
1 parent d616608 commit f469ecb
Show file tree
Hide file tree
Showing 12 changed files with 197 additions and 8 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ lazy val databaseMigration = (project in file("database/migration"))

lazy val dockerHashing = project
.withLibrarySettings("cromwell-docker-hashing", dockerHashingDependencies)
.dependsOn(cloudSupport)
.dependsOn(core)
.dependsOn(core % "test->test")
.dependsOn(common % "test->test")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package cromwell.filesystems.blob
package cromwell.cloudsupport.azure

import cats.implicits.catsSyntaxValidatedId
import com.azure.core.credential.TokenRequestContext
Expand All @@ -9,7 +9,6 @@ import common.validation.ErrorOr.ErrorOr

import scala.concurrent.duration._
import scala.jdk.DurationConverters._

import scala.util.{Failure, Success, Try}

/**
Expand Down
9 changes: 9 additions & 0 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,15 @@ docker {
max-retries = 3

// Supported registries (Azure, Docker Hub, Google, Quay) can have additional configuration set separately
// Settings for the `azure` registry entry (Azure Container Registry hosts, e.g. `terrabatchdev.azurecr.io`)
azure {
// Worst case `ReadOps per minute` value from official docs
// https://github.com/MicrosoftDocs/azure-docs/blob/main/includes/container-registry-limits.md
throttle {
// i.e. 1000 requests per 60 seconds
number-of-requests = 1000
per = 60 seconds
}
// NOTE(review): presumably the number of threads used for requests to this registry - confirm against the code reading this config
num-threads = 10
}
google {
// Example of how to configure throttling, available for all supported registries
throttle {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package cromwell.docker

import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry

import scala.util.{Failure, Success, Try}

sealed trait DockerImageIdentifier {
Expand All @@ -14,7 +16,14 @@ sealed trait DockerImageIdentifier {
lazy val name = repository map { r => s"$r/$image" } getOrElse image
// The name of the image with a repository prefix if a repository was specified, or with a default repository prefix of
// "library" if no repository was specified.
lazy val nameWithDefaultRepository = repository.getOrElse("library") + s"/$image"
lazy val nameWithDefaultRepository = {
  // In ACR, the registry is identified by the host (e.g. `terrabatchdev.azurecr.io`), so the
  // Docker Hub style default prefix of "library" must not be added.
  // An explicitly specified repository namespace is still part of the ACR repository name
  // (e.g. `terrabatchdev.azurecr.io/samples/nginx` -> `samples/nginx`), so use `name`, which keeps
  // the repository prefix when one was given and is just the image otherwise.
  if (host.exists(_.contains(AzureContainerRegistry.domain)))
    name
  else
    repository.getOrElse("library") + s"/$image"
}
lazy val hostAsString = host map { h => s"$h/" } getOrElse ""
// The full name of this image, including a repository prefix only if a repository was explicitly specified.
lazy val fullName = s"$hostAsString$name:$reference"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import cromwell.core.actor.StreamIntegration.{BackPressure, StreamContext}
import cromwell.core.{Dispatcher, DockerConfiguration}
import cromwell.docker.DockerInfoActor._
import cromwell.docker.registryv2.DockerRegistryV2Abstract
import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry
import cromwell.docker.registryv2.flows.dockerhub.DockerHubRegistry
import cromwell.docker.registryv2.flows.google.GoogleRegistry
import cromwell.docker.registryv2.flows.quay.QuayRegistry
Expand Down Expand Up @@ -232,6 +233,7 @@ object DockerInfoActor {

// To add a new registry, simply add it to that list
List(
("azure", { c: DockerRegistryConfig => new AzureContainerRegistry(c) }),
("dockerhub", { c: DockerRegistryConfig => new DockerHubRegistry(c) }),
("google", { c: DockerRegistryConfig => new GoogleRegistry(c) }),
("quay", { c: DockerRegistryConfig => new QuayRegistry(c) })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi
}

// Execute a request. No retries because they're expected to already be handled by the client
private def executeRequest[A](request: IO[Request[IO]], handler: Response[IO] => IO[A])(implicit client: Client[IO]): IO[A] = {
protected def executeRequest[A](request: IO[Request[IO]], handler: Response[IO] => IO[A])(implicit client: Client[IO]): IO[A] = {
  // Materialize the request, run it through the client, and hand the response to `handler`
  // while the response resource is still open.
  for {
    builtRequest <- request
    handled <- client.run(builtRequest).use[IO, A](handler)
  } yield handled
}

Expand Down Expand Up @@ -188,7 +188,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi
/**
* Builds the token request
*/
private def buildTokenRequest(dockerInfoContext: DockerInfoContext): IO[Request[IO]] = {
protected def buildTokenRequest(dockerInfoContext: DockerInfoContext): IO[Request[IO]] = {
val request = Method.GET(
buildTokenRequestUri(dockerInfoContext.dockerImageID),
buildTokenRequestHeaders(dockerInfoContext): _*
Expand Down Expand Up @@ -220,7 +220,7 @@ abstract class DockerRegistryV2Abstract(override val config: DockerRegistryConfi
* Request to get the manifest, using the auth token if provided
*/
private def manifestRequest(token: Option[String], imageId: DockerImageIdentifier, manifestHeader: Accept): IO[Request[IO]] = {
val authorizationHeader = token.map(t => Authorization(Credentials.Token(AuthScheme.Bearer, t)))
val authorizationHeader: Option[Authorization] = token.map(t => Authorization(Credentials.Token(AuthScheme.Bearer, t)))
val request = Method.GET(
buildManifestUri(imageId),
List(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package cromwell.docker.registryv2.flows.azure

/**
  * Body of the ACR response that carries the registry access token.
  *
  * The field is deliberately snake_case: `AzureContainerRegistry` decodes this class with circe's
  * automatic derivation, so the field name must match the JSON key `access_token`.
  *
  * @param access_token the ACR access token, sent as the bearer token on registry requests
  */
final case class AcrAccessToken(access_token: String)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package cromwell.docker.registryv2.flows.azure

/**
  * Body of the ACR response that carries the refresh token obtained in exchange for an AAD access token.
  *
  * The field is deliberately snake_case: `AzureContainerRegistry` decodes this class with circe's
  * automatic derivation, so the field name must match the JSON key `refresh_token`.
  *
  * @param refresh_token the ACR refresh token, later exchanged for an ACR access token
  */
final case class AcrRefreshToken(refresh_token: String)
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package cromwell.docker.registryv2.flows.azure

import cats.data.Validated.{Invalid, Valid}
import cats.effect.IO
import com.typesafe.scalalogging.LazyLogging
import common.validation.ErrorOr.ErrorOr
import cromwell.cloudsupport.azure.AzureCredentials
import cromwell.docker.DockerInfoActor.DockerInfoContext
import cromwell.docker.{DockerImageIdentifier, DockerRegistryConfig}
import cromwell.docker.registryv2.DockerRegistryV2Abstract
import org.http4s.{Header, Request, Response, Status}
import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry.domain
import org.http4s.circe.jsonOf
import org.http4s.client.Client
import io.circe.generic.auto._
import org.http4s._


/**
  * Docker registry flow for Azure Container Registry (ACR).
  *
  * ACR registries are user-defined hosts (e.g. `terrabatchdev.azurecr.io`), so the registry host, the
  * authorization host and the token "service" are all derived from the image identifier rather than
  * being fixed values. Authentication is a three step exchange performed in [[getToken]]:
  * AAD access token -> ACR refresh token -> ACR access token.
  */
class AzureContainerRegistry(config: DockerRegistryConfig) extends DockerRegistryV2Abstract(config) with LazyLogging {

  /**
    * Host name of the registry, taken from the image identifier (e.g. `terrabatchdev.azurecr.io`).
    * Falls back to the empty string when the identifier carries no host.
    */
  override protected def registryHostName(dockerImageIdentifier: DockerImageIdentifier): String =
    dockerImageIdentifier.host.getOrElse("")

  /**
    * An image is handled by this registry when its host contains the ACR domain.
    */
  override def accepts(dockerImageIdentifier: DockerImageIdentifier): Boolean =
    dockerImageIdentifier.hostAsString.contains(domain)

  /**
    * The OAuth endpoints are requested on the same host as the registry itself, so this is the
    * image's host as well (empty string when the identifier carries no host).
    */
  override protected def authorizationServerHostName(dockerImageIdentifier: DockerImageIdentifier): String =
    dockerImageIdentifier.host.getOrElse("")

  /**
    * In Azure, service name does not exist at the registry level: it is the host of the user-defined
    * registry, e.g. `terrabatchdev.azurecr.io`. This member therefore always throws; the token flow in
    * [[getToken]] derives the service from the image identifier instead.
    */
  override def serviceName: Option[String] =
    throw new Exception("ACR service name is host of user-defined registry, must derive from `DockerImageIdentifier`")

  /**
    * Builds the list of headers for the token request
    */
  override protected def buildTokenRequestHeaders(dockerInfoContext: DockerInfoContext): List[Header] = {
    List(contentTypeHeader)
  }

  // Both token endpoints are called with a form-encoded body.
  private val contentTypeHeader: Header = {
    import org.http4s.headers.`Content-Type`
    import org.http4s.MediaType

    `Content-Type`(MediaType.application.`x-www-form-urlencoded`)
  }

  /** Builds `https://<hostname><endpointPath>` with an empty query. */
  private def oauthEndpoint(hostname: String, endpointPath: String): Uri = {
    import org.http4s.Uri.{Authority, Scheme}

    Uri.apply(
      scheme = Option(Scheme.https),
      authority = Option(Authority(host = Uri.RegName(hostname))),
      path = endpointPath,
      query = Query.empty
    )
  }

  /** Builds a form-encoded POST request to `uri`. */
  private def postForm(uri: Uri, form: UrlForm): IO[Request[IO]] = {
    import org.http4s.client.dsl.io._

    org.http4s.Method.POST(form, uri, List(contentTypeHeader): _*)
  }

  /**
    * Request exchanging an AAD access token for an ACR refresh token (`POST /oauth2/exchange`).
    *
    * @param authServerHostname host of the registry, used both as request host and as `service`
    * @param defaultAccessToken AAD access token to exchange
    */
  private def getRefreshToken(authServerHostname: String, defaultAccessToken: String): IO[Request[IO]] =
    postForm(
      oauthEndpoint(authServerHostname, "/oauth2/exchange"),
      UrlForm(
        "service" -> authServerHostname,
        "access_token" -> defaultAccessToken,
        "grant_type" -> "access_token"
      )
    )

  /*
    Unlike other repositories, Azure reserves `GET /oauth2/token` for Basic Authentication [0]
    In order to use Oauth we must `POST /oauth2/token` [1]
    [0] https://github.com/Azure/acr/blob/main/docs/Token-BasicAuth.md#using-the-token-api
    [1] https://github.com/Azure/acr/blob/main/docs/AAD-OAuth.md#calling-post-oauth2token-to-get-an-acr-access-token
   */
  private def getDockerAccessToken(hostname: String, repository: String, refreshToken: String): IO[Request[IO]] =
    postForm(
      oauthEndpoint(hostname, "/oauth2/token"),
      UrlForm(
        // Tricky behavior - invalid `repository` values return a 200 with a valid-looking token.
        // However, the token will cause 401s on all subsequent requests.
        "scope" -> s"repository:$repository:pull",
        "service" -> hostname,
        "refresh_token" -> refreshToken,
        "grant_type" -> "refresh_token"
      )
    )

  /**
    * Obtains an ACR access token for pulling the image described by `dockerInfoContext`.
    *
    * Top-level flow: AAD access token -> refresh token -> ACR access token.
    * The returned IO fails when no AAD token can be obtained or when either exchange request fails.
    */
  override protected def getToken(dockerInfoContext: DockerInfoContext)(implicit client: Client[IO]): IO[Option[String]] = {
    val imageId = dockerInfoContext.dockerImageID
    val hostname = authorizationServerHostName(imageId)
    // The ACR repository is the image name including its namespace when one was specified
    // (e.g. `samples/nginx`); `name` equals the bare image when there is no namespace.
    val repository = imageId.name

    // Acquiring the AAD token is a side effect, so suspend it: it then runs each time the returned IO
    // is executed, and anything it throws surfaces through the IO error channel.
    val maybeAadAccessToken: IO[ErrorOr[String]] =
      IO(AzureCredentials.getAccessToken(None)) // AAD token suitable for get-refresh-token request

    maybeAadAccessToken.flatMap {
      case Valid(accessToken) =>
        for {
          refreshToken <- executeRequest(getRefreshToken(hostname, accessToken), parseRefreshToken)
          dockerToken <- executeRequest(getDockerAccessToken(hostname, repository, refreshToken), parseAccessToken)
        } yield Option(dockerToken)
      case Invalid(errors) =>
        IO.raiseError[Option[String]](
          new Exception(s"Could not obtain AAD token to exchange for ACR refresh token. Error(s): ${errors}")
        )
    }
  }

  implicit val refreshTokenDecoder: EntityDecoder[IO, AcrRefreshToken] = jsonOf[IO, AcrRefreshToken]
  implicit val accessTokenDecoder: EntityDecoder[IO, AcrAccessToken] = jsonOf[IO, AcrAccessToken]

  /** Fails with the status code and body of a response that was not successful. */
  private def raiseRequestFailure[A](response: Response[IO]): IO[A] =
    response.as[String].flatMap { body =>
      IO.raiseError[A](new Exception(s"Request failed with status ${response.status.code} and body $body"))
    }

  /** Extracts the refresh token from a successful exchange response, fails otherwise. */
  private def parseRefreshToken(response: Response[IO]): IO[String] = response match {
    case Status.Successful(ok) => ok.as[AcrRefreshToken].map(_.refresh_token)
    case failed => raiseRequestFailure[String](failed)
  }

  /** Extracts the access token from a successful token response, fails otherwise. */
  private def parseAccessToken(response: Response[IO]): IO[String] = response match {
    case Status.Successful(ok) => ok.as[AcrAccessToken].map(_.access_token)
    case failed => raiseRequestFailure[String](failed)
  }

}

object AzureContainerRegistry {

  // Kept private so that `domain` remains the only public accessor.
  private val acrDomain: String = "azurecr.io"

  /** Domain found in Azure Container Registry hosts, e.g. `terrabatchdev.azurecr.io`. */
  def domain: String = acrDomain

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ class DockerImageIdentifierSpec extends AnyFlatSpec with CromwellTimeoutSpec wit
("broad/cromwell/submarine", None, Option("broad/cromwell"), "submarine", "latest"),
("gcr.io/google/slim", Option("gcr.io"), Option("google"), "slim", "latest"),
("us-central1-docker.pkg.dev/google/slim", Option("us-central1-docker.pkg.dev"), Option("google"), "slim", "latest"),
("terrabatchdev.azurecr.io/postgres", Option("terrabatchdev.azurecr.io"), None, "postgres", "latest"),
// With tags
("ubuntu:latest", None, None, "ubuntu", "latest"),
("ubuntu:1235-SNAP", None, None, "ubuntu", "1235-SNAP"),
("ubuntu:V3.8-5_1", None, None, "ubuntu", "V3.8-5_1"),
("index.docker.io:9999/ubuntu:170904", Option("index.docker.io:9999"), None, "ubuntu", "170904"),
("localhost:5000/capture/transwf:170904", Option("localhost:5000"), Option("capture"), "transwf", "170904"),
("quay.io/biocontainers/platypus-variant:0.8.1.1--htslib1.5_0", Option("quay.io"), Option("biocontainers"), "platypus-variant", "0.8.1.1--htslib1.5_0"),
("terrabatchdev.azurecr.io/postgres:latest", Option("terrabatchdev.azurecr.io"), None, "postgres", "latest"),
// Very long tags with trailing spaces cause problems for the re engine
("someuser/someimage:supercalifragilisticexpialidociouseventhoughthesoundofitissomethingquiteatrociousifyousayitloudenoughyoullalwayssoundprecocious ", None, Some("someuser"), "someimage", "supercalifragilisticexpialidociouseventhoughthesoundofitissomethingquiteatrociousifyousayitloudenoughyoullalwayssoundprecocious")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cromwell.docker

import cromwell.core.Tags.IntegrationTest
import cromwell.docker.DockerInfoActor._
import cromwell.docker.registryv2.flows.azure.AzureContainerRegistry
import cromwell.docker.registryv2.flows.dockerhub.DockerHubRegistry
import cromwell.docker.registryv2.flows.google.GoogleRegistry
import cromwell.docker.registryv2.flows.quay.QuayRegistry
Expand All @@ -18,7 +19,8 @@ class DockerInfoActorSpec extends DockerRegistrySpec with AnyFlatSpecLike with M
override protected lazy val registryFlows = List(
new DockerHubRegistry(DockerRegistryConfig.default),
new GoogleRegistry(DockerRegistryConfig.default),
new QuayRegistry(DockerRegistryConfig.default)
new QuayRegistry(DockerRegistryConfig.default),
new AzureContainerRegistry(DockerRegistryConfig.default)
)

it should "retrieve a public docker hash" taggedAs IntegrationTest in {
Expand Down Expand Up @@ -50,6 +52,16 @@ class DockerInfoActorSpec extends DockerRegistrySpec with AnyFlatSpecLike with M
hash should not be empty
}
}

// Integration test resolving the hash of an image hosted on an Azure Container Registry.
// NOTE(review): presumably requires Azure credentials with pull access to `terrabatchdev.azurecr.io`
// to be available in the environment - confirm before enabling in CI.
it should "retrieve a private docker hash on acr" taggedAs IntegrationTest in {
dockerActor ! makeRequest("terrabatchdev.azurecr.io/postgres:latest")

// Expect a success response carrying a non-empty sha256 hash within 15 seconds
expectMsgPF(15 second) {
case DockerInfoSuccessResponse(DockerInformation(DockerHashResult(alg, hash), _), _) =>
alg shouldBe "sha256"
hash should not be empty
}
}

it should "send image not found message back if the image does not exist" taggedAs IntegrationTest in {
val notFound = makeRequest("ubuntu:nonexistingtag")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import com.azure.storage.blob.sas.{BlobContainerSasPermission, BlobServiceSasSig
import com.typesafe.config.Config
import com.typesafe.scalalogging.LazyLogging
import common.validation.Validation._
import cromwell.cloudsupport.azure.AzureUtils
import cromwell.cloudsupport.azure.{AzureCredentials, AzureUtils}

import java.net.URI
import java.nio.file.{FileSystem, FileSystemNotFoundException, FileSystems}
Expand Down

0 comments on commit f469ecb

Please sign in to comment.