Skip to content

Commit

Permalink
feat: handle token error responses, extract some common code
Browse files Browse the repository at this point in the history
  • Loading branch information
tronghn committed Nov 25, 2024
1 parent c60e179 commit b714f76
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 111 deletions.
32 changes: 9 additions & 23 deletions wonderwalled-azure/src/main/kotlin/io/nais/Wonderwalled.kt
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
package io.nais

import io.ktor.client.plugins.ClientRequestException
import io.ktor.client.statement.readRawBytes
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.server.cio.CIO
import io.ktor.server.engine.embeddedServer
import io.ktor.server.response.respond
import io.ktor.server.response.respondBytes
import io.ktor.server.routing.get
import io.ktor.server.routing.route
import io.ktor.server.routing.routing
import io.nais.common.AppConfig
import io.nais.common.AuthClient
import io.nais.common.IdentityProvider
import io.nais.common.NaisAuth
import io.nais.common.TokenResponse
import io.nais.common.bearerToken
import io.nais.common.commonSetup
import io.nais.common.requestHeaders
import io.nais.common.server

fun main() {
val config = AppConfig()

embeddedServer(CIO, port = config.port) {
commonSetup()

server { config ->
val azure = AuthClient(config.auth, IdentityProvider.AZURE_AD)

routing {
Expand Down Expand Up @@ -63,11 +53,9 @@ fun main() {
}

val target = audience.toScope()
try {
val exchange = azure.exchange(target, token)
call.respond(exchange)
} catch (e: ClientRequestException) {
call.respondBytes(e.response.readRawBytes(), e.response.contentType(), e.response.status)
when (val response = azure.exchange(target, token)) {
is TokenResponse.Success -> call.respond(response)
is TokenResponse.Error -> call.respond(response.status, response.error)
}
}

Expand All @@ -79,11 +67,9 @@ fun main() {
}

val target = audience.toScope()
try {
val token = azure.token(target)
call.respond(token)
} catch (e: ClientRequestException) {
call.respondBytes(e.response.readRawBytes(), e.response.contentType(), e.response.status)
when (val response = azure.token(target)) {
is TokenResponse.Success -> call.respond(response)
is TokenResponse.Error -> call.respond(response.status, response.error)
}
}
}
Expand Down
68 changes: 34 additions & 34 deletions wonderwalled-common/src/main/kotlin/io/nais/common/Auth.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import com.fasterxml.jackson.annotation.JsonAnySetter
import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.annotation.JsonProperty
import com.fasterxml.jackson.annotation.JsonValue
import com.natpryce.konfig.Configuration
import com.natpryce.konfig.Key
import com.natpryce.konfig.stringType
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.ResponseException
import io.ktor.client.request.forms.submitForm
import io.ktor.http.HttpStatusCode
import io.ktor.http.parameters
import io.ktor.server.application.ApplicationCall
import io.ktor.server.application.createRouteScopedPlugin
Expand All @@ -31,23 +30,24 @@ enum class IdentityProvider(@JsonValue val alias: String) {
TOKEN_X("tokenx"),
}

data class AuthClientConfig(
val tokenEndpoint: String,
val tokenExchangeEndpoint: String,
val tokenIntrospectionEndpoint: String,
) {
constructor(config: Configuration) : this(
tokenEndpoint = config[Key("nais.token.endpoint", stringType)],
tokenExchangeEndpoint = config[Key("nais.token.exchange.endpoint", stringType)],
tokenIntrospectionEndpoint = config[Key("nais.token.introspection.endpoint", stringType)],
)
sealed class TokenResponse {
data class Success(
@JsonProperty("access_token")
val accessToken: String,
@JsonProperty("expires_in")
val expiresInSeconds: Int,
) : TokenResponse()

data class Error(
val error: TokenErrorResponse,
val status: HttpStatusCode,
) : TokenResponse()
}

data class TokenResponse(
@JsonProperty("access_token")
val accessToken: String,
@JsonProperty("expires_in")
val expiresInSeconds: Int,
data class TokenErrorResponse(
val error: String,
@JsonProperty("error_description")
val errorDescription: String,
)

data class TokenIntrospectionResponse(
Expand All @@ -59,41 +59,41 @@ data class TokenIntrospectionResponse(
)

class AuthClient(
private val config: AuthClientConfig,
private val config: Config.Auth,
private val provider: IdentityProvider,
private val httpClient: HttpClient = defaultHttpClient(),
) {
private val tracer: Tracer = GlobalOpenTelemetry.get().getTracer("io.nais.common.AuthClient")

suspend fun token(target: String) =
tracer.withSpan("AuthClient/token (${provider.alias})", parameters = {
setAllAttributes(traceAttributes(target))
}) {
suspend fun token(target: String): TokenResponse = try {
tracer.withSpan("AuthClient/token (${provider.alias})", traceAttributes(target)) {
httpClient.submitForm(config.tokenEndpoint, parameters {
set("target", target)
set("identity_provider", provider.alias)
}).body<TokenResponse>()
}).body<TokenResponse.Success>()
}
} catch (e: ResponseException) {
TokenResponse.Error(e.response.body<TokenErrorResponse>(), e.response.status)
}

suspend fun exchange(target: String, userToken: String) =
tracer.withSpan("AuthClient/exchange (${provider.alias})", parameters = {
setAllAttributes(traceAttributes(target))
}) {
suspend fun exchange(target: String, userToken: String): TokenResponse = try {
tracer.withSpan("AuthClient/exchange (${provider.alias})", traceAttributes(target)) {
httpClient.submitForm(config.tokenExchangeEndpoint, parameters {
set("target", target)
set("user_token", userToken)
set("identity_provider", provider.alias)
}).body<TokenResponse>()
}).body<TokenResponse.Success>()
}
} catch (e: ResponseException) {
TokenResponse.Error(e.response.body<TokenErrorResponse>(), e.response.status)
}

suspend fun introspect(accessToken: String) =
tracer.withSpan("AuthClient/introspect (${provider.alias})", parameters = {
setAllAttributes(traceAttributes())
}) {
suspend fun introspect(accessToken: String): TokenIntrospectionResponse =
tracer.withSpan("AuthClient/introspect (${provider.alias})", traceAttributes()) {
httpClient.submitForm(config.tokenIntrospectionEndpoint, parameters {
set("token", accessToken)
set("identity_provider", provider.alias)
}).body<TokenIntrospectionResponse>()
}).body()
}

private fun traceAttributes(target: String? = null) = Attributes.builder().apply {
Expand Down
24 changes: 21 additions & 3 deletions wonderwalled-common/src/main/kotlin/io/nais/common/Config.kt
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
package io.nais.common

import com.natpryce.konfig.Configuration
import com.natpryce.konfig.ConfigurationProperties.Companion.systemProperties
import com.natpryce.konfig.EnvironmentVariables
import com.natpryce.konfig.Key
import com.natpryce.konfig.intType
import com.natpryce.konfig.overriding
import com.natpryce.konfig.stringType
import io.opentelemetry.sdk.autoconfigure.AutoConfiguredOpenTelemetrySdk

private val config =
systemProperties() overriding
EnvironmentVariables()

data class AppConfig(
data class Config(
val port: Int = config.getOrElse(Key("application.port", intType), 8080),
val auth: AuthClientConfig = AuthClientConfig(config),
val auth: Auth = Auth(config),
// optional, generally only needed when running locally
val ingress: String = config.getOrElse(key = Key("login.ingress", stringType), default = ""),
)
) {
init {
AutoConfiguredOpenTelemetrySdk.initialize()
}

data class Auth(
val tokenEndpoint: String,
val tokenExchangeEndpoint: String,
val tokenIntrospectionEndpoint: String,
) {
constructor(config: Configuration) : this(
tokenEndpoint = config[Key("nais.token.endpoint", stringType)],
tokenExchangeEndpoint = config[Key("nais.token.exchange.endpoint", stringType)],
tokenIntrospectionEndpoint = config[Key("nais.token.introspection.endpoint", stringType)],
)
}
}
27 changes: 18 additions & 9 deletions wonderwalled-common/src/main/kotlin/io/nais/common/Http.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@ package io.nais.common
import com.fasterxml.jackson.databind.DeserializationFeature
import com.fasterxml.jackson.databind.SerializationFeature
import io.ktor.client.HttpClient
import io.ktor.client.engine.cio.CIO
import io.ktor.http.HttpHeaders
import io.ktor.serialization.jackson.jackson
import io.ktor.server.application.Application
import io.ktor.server.application.ApplicationCall
import io.ktor.server.application.install
import io.ktor.server.engine.embeddedServer
import io.ktor.server.plugins.callid.CallId
import io.ktor.server.plugins.callid.callIdMdc
import io.ktor.server.plugins.calllogging.CallLogging
import io.ktor.server.plugins.contentnegotiation.ContentNegotiation
import io.ktor.server.request.authorization
import io.ktor.server.request.path
import io.ktor.server.response.respond
Expand All @@ -24,13 +23,16 @@ import io.ktor.server.routing.routing
import io.opentelemetry.api.GlobalOpenTelemetry
import io.opentelemetry.instrumentation.ktor.v3_0.client.KtorClientTracing
import io.opentelemetry.instrumentation.ktor.v3_0.server.KtorServerTracing
import io.opentelemetry.sdk.autoconfigure.AutoConfiguredOpenTelemetrySdk
import org.slf4j.event.Level
import java.util.UUID
import io.ktor.client.engine.cio.CIO as ClientCIO
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation as ClientContentNegotiation
import io.ktor.server.cio.CIO as ServerCIO
import io.ktor.server.plugins.contentnegotiation.ContentNegotiation as ServerContentNegotiation

fun defaultHttpClient() = HttpClient(CIO) {
fun defaultHttpClient() = HttpClient(ClientCIO) {
expectSuccess = true
install(io.ktor.client.plugins.contentnegotiation.ContentNegotiation) {
install(ClientContentNegotiation) {
jackson {
deserializationConfig.apply {
configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
Expand All @@ -43,27 +45,32 @@ fun defaultHttpClient() = HttpClient(CIO) {
}
}

fun Application.commonSetup() {
AutoConfiguredOpenTelemetrySdk.initialize()

install(ContentNegotiation) {
fun server(
config: Config = Config(),
module: Application.(Config) -> Unit
) = embeddedServer(ServerCIO, port = config.port) {
install(ServerContentNegotiation) {
jackson {
enable(SerializationFeature.INDENT_OUTPUT)
}
}

install(IgnoreTrailingSlash)

install(CallId) {
header(HttpHeaders.XCorrelationId)
header(HttpHeaders.XRequestId)
generate { UUID.randomUUID().toString() }
verify { callId: String -> callId.isNotEmpty() }
}

install(CallLogging) {
level = Level.INFO
disableDefaultColors()
filter { call -> !call.request.path().startsWith("/internal") }
callIdMdc("call_id")
}

install(KtorServerTracing) {
setOpenTelemetry(GlobalOpenTelemetry.get())
}
Expand All @@ -82,6 +89,8 @@ fun Application.commonSetup() {
}
}
}

module(config)
}

fun ApplicationCall.requestHeaders(): Map<String, String> =
Expand Down
8 changes: 5 additions & 3 deletions wonderwalled-common/src/main/kotlin/io/nais/common/Tracing.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.nais.common

import io.opentelemetry.api.common.Attributes
import io.opentelemetry.api.trace.Span
import io.opentelemetry.api.trace.SpanBuilder
import io.opentelemetry.api.trace.StatusCode
import io.opentelemetry.api.trace.Tracer
import io.opentelemetry.extension.kotlin.asContextElement
Expand All @@ -10,11 +10,13 @@ import kotlin.coroutines.coroutineContext

suspend fun <T> Tracer.withSpan(
spanName: String,
parameters: (SpanBuilder.() -> Unit)? = null,
attributes: Attributes? = null,
block: suspend (span: Span) -> T
): T {
val span: Span = this.spanBuilder(spanName).run {
if (parameters != null) parameters()
if (attributes != null) {
setAllAttributes(attributes)
}
startSpan()
}

Expand Down
24 changes: 6 additions & 18 deletions wonderwalled-idporten/src/main/kotlin/io/nais/Wonderwalled.kt
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
package io.nais

import io.ktor.client.plugins.ClientRequestException
import io.ktor.client.statement.readRawBytes
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.server.cio.CIO
import io.ktor.server.engine.embeddedServer
import io.ktor.server.response.respond
import io.ktor.server.response.respondBytes
import io.ktor.server.routing.get
import io.ktor.server.routing.route
import io.ktor.server.routing.routing
import io.nais.common.AppConfig
import io.nais.common.AuthClient
import io.nais.common.IdentityProvider
import io.nais.common.NaisAuth
import io.nais.common.TokenResponse
import io.nais.common.bearerToken
import io.nais.common.commonSetup
import io.nais.common.requestHeaders
import io.nais.common.server

fun main() {
val config = AppConfig()

embeddedServer(CIO, port = config.port) {
commonSetup()

server { config ->
val tokenx = AuthClient(config.auth, IdentityProvider.TOKEN_X)
val idporten = AuthClient(config.auth, IdentityProvider.IDPORTEN)

Expand Down Expand Up @@ -63,11 +53,9 @@ fun main() {
return@get
}

try {
val exchange = tokenx.exchange(audience, token)
call.respond(exchange)
} catch (e: ClientRequestException) {
call.respondBytes(e.response.readRawBytes(), e.response.contentType(), e.response.status)
when (val response = tokenx.exchange(audience, token)) {
is TokenResponse.Success -> call.respond(response)
is TokenResponse.Error -> call.respond(response.status, response.error)
}
}
}
Expand Down
Loading

0 comments on commit b714f76

Please sign in to comment.