Skip to content

Commit

Permalink
[WM-2500][WM-2502] Fetch Github token from ECM for importing and runn…
Browse files Browse the repository at this point in the history
…ing private workflows (#7392)
  • Loading branch information
salonishah11 authored Apr 2, 2024
1 parent 2f8c46d commit d7def8d
Show file tree
Hide file tree
Showing 11 changed files with 330 additions and 29 deletions.
6 changes: 2 additions & 4 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,8 @@ services {
config {
enabled = false
auth.azure = false
# Notes:
# - don't include the 'Bearer' before the token
# - this config value should be removed when support for fetching tokens from ECM has been added to Cromwell
access-token = "dummy-token"
# Set this to the service that Cromwell should retrieve Github access token associated with user's token.
# ecm.base-url = ""
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ object Dependencies {
val servicesDependencies: List[ModuleID] = List(
"com.google.api" % "gax-grpc" % googleGaxGrpcV,
"org.apache.commons" % "commons-csv" % commonsCsvV,
) ++ testDatabaseDependencies
) ++ testDatabaseDependencies ++ akkaHttpDependencies

val serverDependencies: List[ModuleID] = slf4jBindingDependencies

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ object GithubAuthVending {
override def serviceName: String = "GithubAuthVending"
}

case class GithubAuthRequest(terraToken: String) extends GithubAuthVendingMessage
// types of tokens
case class TerraToken(value: String)
case class GithubToken(value: String)

case class GithubAuthRequest(terraToken: TerraToken) extends GithubAuthVendingMessage

sealed trait GithubAuthVendingResponse extends GithubAuthVendingMessage
case class GithubAuthTokenResponse(accessToken: String) extends GithubAuthVendingResponse
case class GithubAuthTokenResponse(githubAccessToken: GithubToken) extends GithubAuthVendingResponse
case object NoGithubAuthResponse extends GithubAuthVendingResponse
case class GithubAuthVendingFailure(error: Exception) extends GithubAuthVendingResponse
case class GithubAuthVendingFailure(errorMsg: String) extends GithubAuthVendingResponse

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ import cromwell.services.auth.GithubAuthVending.{
GithubAuthTokenResponse,
GithubAuthVendingFailure,
GithubAuthVendingResponse,
NoGithubAuthResponse
NoGithubAuthResponse,
TerraToken
}

import net.ceedubs.ficus.Ficus._

import scala.concurrent.{ExecutionContext, Future}

trait GithubAuthVendingSupport extends AskSupport with StrictLogging {
Expand All @@ -30,7 +31,7 @@ trait GithubAuthVendingSupport extends AskSupport with StrictLogging {
def importAuthProvider(token: String)(implicit timeout: Timeout): ImportAuthProvider = new GithubImportAuthProvider {
override def authHeader(): Future[Map[String, String]] =
serviceRegistryActor
.ask(GithubAuthRequest(token))
.ask(GithubAuthRequest(TerraToken(token)))
.mapTo[GithubAuthVendingResponse]
.recoverWith {
case e: AskTimeoutException =>
Expand All @@ -41,10 +42,11 @@ trait GithubAuthVendingSupport extends AskSupport with StrictLogging {
Future.failed(new Exception("Failed to resolve github auth token", e))
}
.flatMap {
case GithubAuthTokenResponse(token) => Future.successful(Map("Authorization" -> s"Bearer ${token}"))
case GithubAuthTokenResponse(githubToken) =>
Future.successful(Map("Authorization" -> s"Bearer ${githubToken.value}"))
case NoGithubAuthResponse => Future.successful(Map.empty)
case GithubAuthVendingFailure(error) =>
Future.failed(new Exception("Failed to resolve github auth token", error))
Future.failed(new Exception(s"Failed to resolve GitHub auth token. Error: $error"))
}
}

Expand Down
10 changes: 10 additions & 0 deletions services/src/main/scala/cromwell/services/auth/ecm/EcmConfig.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package cromwell.services.auth.ecm

import com.typesafe.config.Config
import net.ceedubs.ficus.Ficus._

final case class EcmConfig(baseUrl: Option[String])

object EcmConfig {
def apply(config: Config): EcmConfig = EcmConfig(config.as[Option[String]]("ecm.base-url"))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package cromwell.services.auth.ecm

import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.RawHeader
import akka.util.ByteString
import cromwell.services.auth.GithubAuthVending.{GithubToken, TerraToken}
import spray.json._

import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Success, Try}

class EcmService(baseEcmUrl: String) {
private val getGithubAccessTokenApiPath = "api/oauth/v1/github/access-token"

/*
ECM does generally return standard JSON error response, but for 401 status code it seems some other layer in
between (like the apache proxies, etc) returns HTML pages. This helper method returns custom error message for 401
status code as it contains HTML tags. For all other status code, the response format is generally of ErrorReport
schema and this method tries to extract the actual message from the JSON object and return it. In case it fails
to parse JSON, it returns the original error response body.
ErrorReport schema: {"message":"<actual_error_msg>", "statusCode":<code>}
*/
def extractErrorMessage(errorCode: StatusCode, responseBodyAsStr: String): String =
errorCode match {
case StatusCodes.Unauthorized => "Invalid or missing authentication credentials."
case _ =>
Try(responseBodyAsStr.parseJson) match {
case Success(JsObject(fields)) =>
fields.get("message").map(_.toString().replaceAll("\"", "")).getOrElse(responseBodyAsStr)
case _ => responseBodyAsStr
}
}

def getGithubAccessToken(
userToken: TerraToken
)(implicit actorSystem: ActorSystem, ec: ExecutionContext): Future[GithubToken] = {

def responseEntityToFutureStr(responseEntity: ResponseEntity): Future[String] =
responseEntity.dataBytes.runFold(ByteString(""))(_ ++ _).map(_.utf8String)

val headers: HttpHeader = RawHeader("Authorization", s"Bearer ${userToken.value}")
val httpRequest =
HttpRequest(method = HttpMethods.GET, uri = s"$baseEcmUrl/$getGithubAccessTokenApiPath").withHeaders(headers)

Http()
.singleRequest(httpRequest)
.flatMap((response: HttpResponse) =>
if (response.status.isFailure()) {
responseEntityToFutureStr(response.entity) flatMap { errorBody =>
val errorMessage = extractErrorMessage(response.status, errorBody)
Future.failed(new RuntimeException(s"HTTP ${response.status.value}. $errorMessage"))
}
} else responseEntityToFutureStr(response.entity).map(GithubToken)
)
}
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,57 @@
package cromwell.services.auth.impl

import akka.actor.{Actor, ActorRef, Props}
import akka.actor.{Actor, ActorRef, ActorSystem, Props}
import com.typesafe.config.Config
import com.typesafe.scalalogging.LazyLogging
import common.util.StringUtil.EnhancedToStringable
import cromwell.core.Dispatcher.ServiceDispatcher
import cromwell.services.auth.GithubAuthVending.{GithubAuthRequest, GithubAuthTokenResponse, NoGithubAuthResponse}
import cromwell.services.auth.GithubAuthVending.{
GithubAuthRequest,
GithubAuthTokenResponse,
GithubAuthVendingFailure,
NoGithubAuthResponse
}
import cromwell.services.auth.ecm.{EcmConfig, EcmService}
import cromwell.util.GracefulShutdownHelper.ShutdownCommand

import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success}

class GithubAuthVendingActor(serviceConfig: Config, globalConfig: Config, serviceRegistryActor: ActorRef)
extends Actor
with LazyLogging {

lazy val enabled = serviceConfig.getBoolean("enabled")
implicit val system: ActorSystem = context.system
implicit val ec: ExecutionContext = context.dispatcher

lazy val enabled: Boolean = serviceConfig.getBoolean("enabled")

lazy val ecmConfigOpt: EcmConfig = EcmConfig(serviceConfig)
lazy val ecmServiceOpt: Option[EcmService] = ecmConfigOpt.baseUrl.map(url => new EcmService(url))

override def receive: Receive = {
case GithubAuthRequest(_) if enabled =>
sender() ! GithubAuthTokenResponse(serviceConfig.getString("access-token"))
case GithubAuthRequest(terraToken) if enabled =>
val respondTo = sender()
ecmServiceOpt match {
case Some(ecmService) =>
ecmService.getGithubAccessToken(terraToken) onComplete {
case Success(githubToken) => respondTo ! GithubAuthTokenResponse(githubToken)
case Failure(e) => respondTo ! GithubAuthVendingFailure(e.getMessage)
}
case None =>
respondTo ! GithubAuthVendingFailure(
"Invalid configuration for service 'GithubAuthVending': missing 'ecm.base-url' value."
)
}
case GithubAuthRequest(_) => sender() ! NoGithubAuthResponse
// This service currently doesn't do any work on shutdown but the service registry pattern requires it
// (see https://github.com/broadinstitute/cromwell/issues/2575)
case ShutdownCommand => context stop self
case _ =>
sender() ! NoGithubAuthResponse
case other =>
logger.error(
s"Programmer Error: Unexpected message ${other.toPrettyElidedString(1000)} received by ${this.self.path.name}."
)
sender() ! GithubAuthVendingFailure(s"Received unexpected message ${other.toPrettyElidedString(1000)}.")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import com.typesafe.config.ConfigFactory
import cromwell.core.TestKitSuite
import cromwell.languages.util.ImportResolver.GithubImportAuthProvider
import cromwell.services.ServiceRegistryActor.ServiceRegistryFailure
import cromwell.services.auth.GithubAuthVending.GithubAuthRequest
import cromwell.services.auth.GithubAuthVending.{GithubAuthRequest, GithubToken, TerraToken}
import cromwell.services.auth.GithubAuthVendingSupportSpec.TestGithubAuthVendingSupport
import org.scalatest.concurrent.Eventually
import org.scalatest.flatspec.AnyFlatSpecLike
Expand Down Expand Up @@ -42,8 +42,8 @@ class GithubAuthVendingSupportSpec extends TestKitSuite with AnyFlatSpecLike wit
val provider = testSupport.importAuthProvider("user-token")
val authHeader: Future[Map[String, String]] = provider.authHeader()

serviceRegistryActor.expectMsg(GithubAuthRequest("user-token"))
serviceRegistryActor.reply(GithubAuthVending.GithubAuthTokenResponse("github-token"))
serviceRegistryActor.expectMsg(GithubAuthRequest(TerraToken("user-token")))
serviceRegistryActor.reply(GithubAuthVending.GithubAuthTokenResponse(GithubToken("github-token")))

Await.result(authHeader, 10.seconds) should be(Map("Authorization" -> "Bearer github-token"))
}
Expand All @@ -54,7 +54,7 @@ class GithubAuthVendingSupportSpec extends TestKitSuite with AnyFlatSpecLike wit
val provider = testSupport.importAuthProvider("user-token")
val authHeader: Future[Map[String, String]] = provider.authHeader()

serviceRegistryActor.expectMsg(GithubAuthRequest("user-token"))
serviceRegistryActor.expectMsg(GithubAuthRequest(TerraToken("user-token")))
serviceRegistryActor.reply(GithubAuthVending.NoGithubAuthResponse)

Await.result(authHeader, 10.seconds) should be(Map.empty)
Expand All @@ -66,13 +66,12 @@ class GithubAuthVendingSupportSpec extends TestKitSuite with AnyFlatSpecLike wit
val provider = testSupport.importAuthProvider("user-token")
val authHeader: Future[Map[String, String]] = provider.authHeader()

serviceRegistryActor.expectMsg(GithubAuthRequest("user-token"))
serviceRegistryActor.reply(GithubAuthVending.GithubAuthVendingFailure(new Exception("BOOM")))
serviceRegistryActor.expectMsg(GithubAuthRequest(TerraToken("user-token")))
serviceRegistryActor.reply(GithubAuthVending.GithubAuthVendingFailure("BOOM"))

eventually {
authHeader.isCompleted should be(true)
authHeader.value.get.failed.get.getMessage should be("Failed to resolve github auth token")
authHeader.value.get.failed.get.getCause.getMessage should be("BOOM")
authHeader.value.get.failed.get.getMessage should be("Failed to resolve GitHub auth token. Error: BOOM")
}
}

Expand All @@ -95,7 +94,7 @@ class GithubAuthVendingSupportSpec extends TestKitSuite with AnyFlatSpecLike wit
val provider = testSupport.importAuthProvider("user-token")
val authHeader: Future[Map[String, String]] = provider.authHeader()

serviceRegistryActor.expectMsg(GithubAuthRequest("user-token"))
serviceRegistryActor.expectMsg(GithubAuthRequest(TerraToken("user-token")))
serviceRegistryActor.reply(ServiceRegistryFailure("GithubAuthVending"))

eventually {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package cromwell.services.auth.ecm

import com.typesafe.config.ConfigFactory
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class EcmConfigSpec extends AnyFlatSpec with Matchers {

it should "parse ECM base url when present" in {
val config = ConfigFactory.parseString(s"""
|enabled = true
|auth.azure = true
|ecm.base-url = "https://mock-ecm-url.org"
""".stripMargin)

val actualEcmConfig = EcmConfig(config)

actualEcmConfig.baseUrl shouldBe defined
actualEcmConfig.baseUrl.get shouldBe "https://mock-ecm-url.org"
}

it should "return None when ECM base url is absent" in {
val config = ConfigFactory.parseString(s"""
|enabled = true
|auth.azure = true
""".stripMargin)

EcmConfig(config).baseUrl shouldBe None
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package cromwell.services.auth.ecm

import akka.http.scaladsl.model.StatusCodes
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.prop.TableDrivenPropertyChecks

class EcmServiceSpec extends AnyFlatSpec with Matchers with TableDrivenPropertyChecks {

private val ecmService = new EcmService("https://mock-ecm-url.org")

private val ecm400ErrorMsg = "No enum constant bio.terra.externalcreds.generated.model.Provider.MyOwnProvider"
private val ecm404ErrorMsg =
"No linked account found for user ID: 123 and provider: github. Please go to the Terra Profile page External Identities tab to link your account for this provider"

private val testCases = Table(
("test name", "response status code", "response body string", "expected error message"),
("return custom 401 error when status code is 401",
StatusCodes.Unauthorized,
"<h2>could be anything</h2>",
"Invalid or missing authentication credentials."
),
("extract message from valid ErrorReport JSON if status code is 400",
StatusCodes.BadRequest,
s"""{ "message" : "$ecm400ErrorMsg", "statusCode" : 400}""",
ecm400ErrorMsg
),
("extract message from valid ErrorReport JSON if status code is 404",
StatusCodes.NotFound,
s"""{ "message" : "$ecm404ErrorMsg", "statusCode" : 404}""",
ecm404ErrorMsg
),
("extract message from valid ErrorReport JSON if status code is 500",
StatusCodes.InternalServerError,
"""{ "message" : "Internal error", "statusCode" : 500}""",
"Internal error"
),
("return response error body if it fails to parse JSON",
StatusCodes.InternalServerError,
"Response error - not a JSON",
"Response error - not a JSON"
),
("return response error body if JSON doesn't contain 'message' key",
StatusCodes.BadRequest,
"""{"non-message-key" : "error message"}""",
"""{"non-message-key" : "error message"}"""
)
)

behavior of "extractErrorMessage in EcmService"

forAll(testCases) { (testName, statusCode, responseBodyAsStr, expectedErrorMsg) =>
it should testName in {
assert(ecmService.extractErrorMessage(statusCode, responseBodyAsStr) == expectedErrorMsg)
}
}
}
Loading

0 comments on commit d7def8d

Please sign in to comment.