Skip to content

Commit

Permalink
WM-2461] Add support for running private workflows on Azure (#7373)
Browse files Browse the repository at this point in the history
  • Loading branch information
salonishah11 authored Mar 7, 2024
1 parent 82b8dc5 commit 79c2bff
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 11 deletions.
1 change: 1 addition & 0 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ services {
class = "cromwell.services.auth.impl.GithubAuthVendingActor"
config {
enabled = false
auth.azure = false
# Notes:
# - don't include the 'Bearer' before the token
# - this config value should be removed when support for fetching tokens from ECM has been added to Cromwell
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cromwell.engine.workflow.lifecycle.materialization

import akka.actor.{ActorRef, FSM, LoggingFSM, Props, Status}
import akka.pattern.pipe
import akka.util.Timeout
import cats.data.EitherT._
import cats.data.NonEmptyList
import cats.data.Validated.{Invalid, Valid}
Expand Down Expand Up @@ -36,6 +37,7 @@ import cromwell.filesystems.gcs.batch.GcsBatchCommandBuilder
import cromwell.languages.util.ImportResolver._
import cromwell.languages.util.LanguageFactoryUtil
import cromwell.languages.{LanguageFactory, ValidatedWomNamespace}
import cromwell.services.auth.GithubAuthVendingSupport
import cromwell.services.metadata.MetadataService._
import cromwell.services.metadata.{MetadataEvent, MetadataKey, MetadataValue}
import eu.timepit.refined.refineV
Expand All @@ -50,6 +52,7 @@ import wom.runtime.WomOutputRuntimeExtractor
import wom.values.{WomString, WomValue}

import scala.concurrent.Future
import scala.concurrent.duration.DurationInt
import scala.language.postfixOps
import scala.util.{Failure, Success, Try}

Expand Down Expand Up @@ -183,15 +186,16 @@ object MaterializeWorkflowDescriptorActor {
}

// TODO WOM: need to decide where to draw the line between language specific initialization and WOM
class MaterializeWorkflowDescriptorActor(serviceRegistryActor: ActorRef,
class MaterializeWorkflowDescriptorActor(override val serviceRegistryActor: ActorRef,
workflowId: WorkflowId,
cromwellBackends: => CromwellBackends,
importLocalFilesystem: Boolean,
ioActorProxy: ActorRef,
hogGroup: HogGroup
) extends LoggingFSM[MaterializeWorkflowDescriptorActorState, Unit]
with StrictLogging
with WorkflowLogging {
with WorkflowLogging
with GithubAuthVendingSupport {

import MaterializeWorkflowDescriptorActor._
val tag = self.path.name
Expand All @@ -204,6 +208,8 @@ class MaterializeWorkflowDescriptorActor(serviceRegistryActor: ActorRef,

protected val pathBuilderFactories: List[PathBuilderFactory] = EngineFilesystems.configuredPathBuilderFactories

final private val githubAuthVendingTimeout = Timeout(60.seconds)

startWith(ReadyToMaterializeState, ())

when(ReadyToMaterializeState) {
Expand Down Expand Up @@ -346,7 +352,12 @@ class MaterializeWorkflowDescriptorActor(serviceRegistryActor: ActorRef,
for {
_ <- publishLabelsToMetadata(id, labels.asMap, serviceRegistryActor)
zippedImportResolver <- zippedResolverCheck
importResolvers = zippedImportResolver.toList ++ localFilesystemResolvers :+ HttpResolver(None, Map.empty)
importAuthProviderOpt <- importAuthProvider(conf)(githubAuthVendingTimeout).toIOChecked
importResolvers = zippedImportResolver.toList ++ localFilesystemResolvers :+ HttpResolver(
None,
Map.empty,
importAuthProviderOpt.toList
)
sourceAndResolvers <- fromEither[IO](
LanguageFactoryUtil.findWorkflowSource(sourceFiles.workflowSource, sourceFiles.workflowUrl, importResolvers)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package cromwell.engine.workflow.lifecycle

import akka.actor.Props
import akka.actor.{ActorRef, Props}
import akka.testkit.TestDuration
import cats.data.Validated.{Invalid, Valid}
import com.typesafe.config.ConfigFactory
Expand All @@ -14,6 +14,7 @@ import cromwell.engine.workflow.lifecycle.materialization.MaterializeWorkflowDes
MaterializeWorkflowDescriptorFailureResponse,
MaterializeWorkflowDescriptorSuccessResponse
}
import cromwell.services.auth.GithubAuthVendingSupport
import cromwell.util.SampleWdl.HelloWorld
import cromwell.{CromwellTestKitSpec, CromwellTestKitWordSpec}
import org.scalatest.BeforeAndAfter
Expand All @@ -23,7 +24,10 @@ import wom.values.{WomInteger, WomString}

import scala.concurrent.duration._

class MaterializeWorkflowDescriptorActorSpec extends CromwellTestKitWordSpec with BeforeAndAfter {
class MaterializeWorkflowDescriptorActorSpec
extends CromwellTestKitWordSpec
with BeforeAndAfter
with GithubAuthVendingSupport {

private val ioActor = system.actorOf(SimpleIoActor.props)
private val workflowId = WorkflowId.randomId()
Expand Down Expand Up @@ -89,6 +93,8 @@ class MaterializeWorkflowDescriptorActorSpec extends CromwellTestKitWordSpec wit

private val fooHogGroup = HogGroup("foo")

override def serviceRegistryActor: ActorRef = NoBehaviorActor

"MaterializeWorkflowDescriptorActor" should {
"accept valid WDL, inputs and options files" in {
val materializeWfActor = system.actorOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ package cromwell.services.auth
import akka.actor.ActorRef
import akka.pattern.{AskSupport, AskTimeoutException}
import akka.util.Timeout
import cats.data.Validated.{Invalid, Valid}
import cats.implicits.catsSyntaxValidatedId
import com.typesafe.config.Config
import com.typesafe.scalalogging.StrictLogging
import common.validation.ErrorOr.ErrorOr
import cromwell.cloudsupport.azure.AzureCredentials
import cromwell.languages.util.ImportResolver.{GithubImportAuthProvider, ImportAuthProvider}
import cromwell.services.auth.GithubAuthVending.{
GithubAuthRequest,
Expand All @@ -13,16 +18,16 @@ import cromwell.services.auth.GithubAuthVending.{
NoGithubAuthResponse
}

import net.ceedubs.ficus.Ficus._
import scala.concurrent.{ExecutionContext, Future}

trait GithubAuthVendingSupport extends AskSupport with StrictLogging {

def serviceRegistryActor: ActorRef

implicit val timeout: Timeout
implicit val ec: ExecutionContext

def importAuthProvider(token: String): ImportAuthProvider = new GithubImportAuthProvider {
def importAuthProvider(token: String)(implicit timeout: Timeout): ImportAuthProvider = new GithubImportAuthProvider {
override def authHeader(): Future[Map[String, String]] =
serviceRegistryActor
.ask(GithubAuthRequest(token))
Expand All @@ -42,4 +47,16 @@ trait GithubAuthVendingSupport extends AskSupport with StrictLogging {
Future.failed(new Exception("Failed to resolve github auth token", error))
}
}

def importAuthProvider(config: Config)(implicit timeout: Timeout): ErrorOr[Option[ImportAuthProvider]] = {
val isAuthAzure = config.as[Boolean]("services.GithubAuthVending.config.auth.azure")

if (isAuthAzure) {
val azureToken = AzureCredentials.getAccessToken()
azureToken match {
case Valid(token) => Option(importAuthProvider(token)).validNel
case Invalid(err) => s"Failed to fetch Azure token. Error: ${err.toString}".invalidNel
}
} else None.validNel
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package cromwell.services.auth
import akka.actor.ActorRef
import akka.testkit.TestProbe
import akka.util.Timeout
import cats.data.Validated.{Invalid, Valid}
import com.typesafe.config.ConfigFactory
import cromwell.core.TestKitSuite
import cromwell.languages.util.ImportResolver.GithubImportAuthProvider
import cromwell.services.ServiceRegistryActor.ServiceRegistryFailure
import cromwell.services.auth.GithubAuthVending.GithubAuthRequest
import cromwell.services.auth.GithubAuthVendingSupportSpec.TestGithubAuthVendingSupport
Expand All @@ -16,6 +19,21 @@ import scala.concurrent.duration._

class GithubAuthVendingSupportSpec extends TestKitSuite with AnyFlatSpecLike with Matchers with Eventually {

private def azureGithubAuthVendingConfig(enabled: Boolean = true) = ConfigFactory
.parseString(
s"""
|services {
| GithubAuthVending {
| config {
| auth.azure = ${enabled}
| }
| }
|}
|""".stripMargin
)

implicit val timeout = Timeout(10.seconds)

behavior of "GithubAuthVendingSupport"

it should "send a message to the service registry and handle success response" in {
Expand Down Expand Up @@ -60,8 +78,8 @@ class GithubAuthVendingSupportSpec extends TestKitSuite with AnyFlatSpecLike wit

it should "handle timeouts" in {
val serviceRegistryActor = TestProbe()
val testSupport = new TestGithubAuthVendingSupport(serviceRegistryActor.ref, 1.millisecond)
val provider = testSupport.importAuthProvider("user-token")
val testSupport = new TestGithubAuthVendingSupport(serviceRegistryActor.ref)
val provider = testSupport.importAuthProvider("user-token")(Timeout(1.millisecond))
val authHeader: Future[Map[String, String]] = provider.authHeader()

eventually {
Expand All @@ -88,11 +106,32 @@ class GithubAuthVendingSupportSpec extends TestKitSuite with AnyFlatSpecLike wit
}
}

it should "return Github import auth provider when Azure auth is enabled" in {
val serviceRegistryActor = TestProbe()
val testSupport = new TestGithubAuthVendingSupport(serviceRegistryActor.ref)

testSupport.importAuthProvider(azureGithubAuthVendingConfig()) match {
case Valid(providerOpt) =>
providerOpt.isEmpty shouldBe false
providerOpt.get.isInstanceOf[GithubImportAuthProvider] shouldBe true
providerOpt.get.validHosts shouldBe List("github.com", "githubusercontent.com", "raw.githubusercontent.com")
case Invalid(e) => fail(s"Unexpected failure: $e")
}
}

it should "return no import auth provider when Azure auth is disabled" in {
val serviceRegistryActor = TestProbe()
val testSupport = new TestGithubAuthVendingSupport(serviceRegistryActor.ref)

testSupport.importAuthProvider(azureGithubAuthVendingConfig(false)) match {
case Valid(providerOpt) => providerOpt.isEmpty shouldBe true
case Invalid(e) => fail(s"Unexpected failure: $e")
}
}
}

object GithubAuthVendingSupportSpec {
class TestGithubAuthVendingSupport(val serviceRegistryActor: ActorRef, val timeout: Timeout = 10.seconds)
extends GithubAuthVendingSupport {
class TestGithubAuthVendingSupport(val serviceRegistryActor: ActorRef) extends GithubAuthVendingSupport {
implicit override val ec: ExecutionContext = ExecutionContext.global
}

Expand Down

0 comments on commit 79c2bff

Please sign in to comment.