Skip to content

Commit

Permalink
Support specifying a client certificate for mTLS auth (#940)
Browse files Browse the repository at this point in the history
  • Loading branch information
tinsukE authored Jan 24, 2025
1 parent ec05bdd commit 39dc3bc
Show file tree
Hide file tree
Showing 17 changed files with 242 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ class FeverSecurityKey private constructor() : SecurityKey() {
var serverUrl: String? = null
var username: String? = null
var password: String? = null
var clientCertificateAlias: String? = null

constructor(serverUrl: String?, username: String?, password: String?) : this() {
constructor(serverUrl: String?, username: String?, password: String?, clientCertificateAlias: String?) : this() {
this.serverUrl = serverUrl
this.username = username
this.password = password
this.clientCertificateAlias = clientCertificateAlias
}

constructor(value: String? = DESUtils.empty) : this() {
decode(value, FeverSecurityKey::class.java).let {
serverUrl = it.serverUrl
username = it.username
password = it.password
clientCertificateAlias = it.clientCertificateAlias
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ class FreshRSSSecurityKey private constructor() : SecurityKey() {
var serverUrl: String? = null
var username: String? = null
var password: String? = null
var clientCertificateAlias: String? = null

constructor(serverUrl: String?, username: String?, password: String?) : this() {
constructor(serverUrl: String?, username: String?, password: String?, clientCertificateAlias: String?) : this() {
this.serverUrl = serverUrl
this.username = username
this.password = password
this.clientCertificateAlias = clientCertificateAlias
}

constructor(value: String? = DESUtils.empty) : this() {
decode(value, FreshRSSSecurityKey::class.java).let {
serverUrl = it.serverUrl
username = it.username
password = it.password
clientCertificateAlias = it.clientCertificateAlias
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ class GoogleReaderSecurityKey private constructor() : SecurityKey() {
var serverUrl: String? = null
var username: String? = null
var password: String? = null
var clientCertificateAlias: String? = null

constructor(serverUrl: String?, username: String?, password: String?) : this() {
constructor(serverUrl: String?, username: String?, password: String?, clientCertificateAlias: String?) : this() {
this.serverUrl = serverUrl
this.username = username
this.password = password
this.clientCertificateAlias = clientCertificateAlias
}

constructor(value: String? = DESUtils.empty) : this() {
decode(value, GoogleReaderSecurityKey::class.java).let {
serverUrl = it.serverUrl
username = it.username
password = it.password
clientCertificateAlias = it.clientCertificateAlias
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ class FeverRssService @Inject constructor(
private suspend fun getFeverAPI() =
FeverSecurityKey(accountDao.queryById(context.currentAccountId)!!.securityKey).run {
FeverAPI.getInstance(
context = context,
serverUrl = serverUrl!!,
username = username!!,
password = password!!,
httpUsername = null,
httpPassword = null,
clientCertificateAlias = clientCertificateAlias,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@ class GoogleReaderRssService @Inject constructor(
private suspend fun getGoogleReaderAPI() =
GoogleReaderSecurityKey(accountDao.queryById(context.currentAccountId)!!.securityKey).run {
GoogleReaderAPI.getInstance(
context = context,
serverUrl = serverUrl!!,
username = username!!,
password = password!!,
httpUsername = null,
httpPassword = null,
clientCertificateAlias = clientCertificateAlias,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

package me.ash.reader.infrastructure.di

import android.annotation.SuppressLint
import android.content.Context
import android.security.KeyChain
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
Expand All @@ -31,15 +33,18 @@ import okhttp3.Cache
import okhttp3.Interceptor
import okhttp3.OkHttpClient
import okhttp3.Response
import okhttp3.internal.platform.Platform
import java.io.File
import java.net.Socket
import java.security.KeyManagementException
import java.security.NoSuchAlgorithmException
import java.security.Principal
import java.security.PrivateKey
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import javax.inject.Singleton
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManager
import javax.net.ssl.X509KeyManager
import javax.net.ssl.X509TrustManager

/**
Expand All @@ -54,18 +59,21 @@ object OkHttpClientModule {
fun provideOkHttpClient(
@ApplicationContext context: Context,
): OkHttpClient = cachingHttpClient(
context = context,
cacheDirectory = context.cacheDir.resolve("http")
).newBuilder()
.addNetworkInterceptor(UserAgentInterceptor)
.build()
}

fun cachingHttpClient(
context: Context,
cacheDirectory: File? = null,
cacheSize: Long = 10L * 1024L * 1024L,
trustAllCerts: Boolean = true,
connectTimeoutSecs: Long = 30L,
readTimeoutSecs: Long = 30L,
clientCertificateAlias: String? = null,
): OkHttpClient {
val builder: OkHttpClient.Builder = OkHttpClient.Builder()

Expand All @@ -78,31 +86,75 @@ fun cachingHttpClient(
.readTimeout(readTimeoutSecs, TimeUnit.SECONDS)
.followRedirects(true)

if (trustAllCerts) {
builder.trustAllCerts()
if (!clientCertificateAlias.isNullOrBlank() || trustAllCerts) {
builder.setupSsl(context, clientCertificateAlias, trustAllCerts)
}

return builder.build()
}

fun OkHttpClient.Builder.trustAllCerts() {
fun OkHttpClient.Builder.setupSsl(
context: Context,
clientCertificateAlias: String?,
trustAllCerts: Boolean
) {
try {
val trustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {
val clientKeyManager = clientCertificateAlias?.let { clientAlias ->
object : X509KeyManager {
override fun getClientAliases(keyType: String?, issuers: Array<Principal>?) =
throw UnsupportedOperationException("getClientAliases")

override fun chooseClientAlias(
keyType: Array<String>?,
issuers: Array<Principal>?,
socket: Socket?
) = clientCertificateAlias

override fun getServerAliases(keyType: String?, issuers: Array<Principal>?) =
throw UnsupportedOperationException("getServerAliases")

override fun chooseServerAlias(
keyType: String?,
issuers: Array<Principal>?,
socket: Socket?
) = throw UnsupportedOperationException("chooseServerAlias")

override fun getCertificateChain(alias: String?): Array<X509Certificate>? {
return if (alias == clientAlias) KeyChain.getCertificateChain(context, clientAlias) else null
}

override fun getPrivateKey(alias: String?): PrivateKey? {
return if (alias == clientAlias) KeyChain.getPrivateKey(context, clientAlias) else null
}
}
}

override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
}
val trustManager = if (trustAllCerts) {
hostnameVerifier { _, _ -> true }

@SuppressLint("CustomX509TrustManager")
object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<out X509Certificate>?,
authType: String?
) = Unit

override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
override fun checkServerTrusted(
chain: Array<out X509Certificate>?,
authType: String?
) = Unit

override fun getAcceptedIssuers(): Array<X509Certificate> = emptyArray()
}
} else {
Platform.get().platformTrustManager()
}

val sslContext = SSLContext.getInstance("TLS")
sslContext.init(null, arrayOf<TrustManager>(trustManager), null)
sslContext.init(arrayOf(clientKeyManager), arrayOf(trustManager), null)
val sslSocketFactory = sslContext.socketFactory

sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(HostnameVerifier { _, _ -> true })
} catch (e: NoSuchAlgorithmException) {
// ignore
} catch (e: KeyManagementException) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package me.ash.reader.infrastructure.rss.provider

import android.content.Context
import com.google.gson.Gson
import com.google.gson.GsonBuilder
import me.ash.reader.infrastructure.di.UserAgentInterceptor
import me.ash.reader.infrastructure.di.cachingHttpClient
import okhttp3.OkHttpClient

abstract class ProviderAPI {
abstract class ProviderAPI(context: Context, clientCertificateAlias: String?) {

protected val client: OkHttpClient = cachingHttpClient()
protected val client: OkHttpClient = cachingHttpClient(
context = context,
clientCertificateAlias = clientCertificateAlias,
)
.newBuilder()
.addNetworkInterceptor(UserAgentInterceptor)
.build()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package me.ash.reader.infrastructure.rss.provider.fever

import android.content.Context
import me.ash.reader.infrastructure.exception.FeverAPIException
import me.ash.reader.infrastructure.rss.provider.ProviderAPI
import me.ash.reader.ui.ext.encodeBase64
Expand All @@ -10,11 +11,13 @@ import okhttp3.executeAsync
import java.util.concurrent.ConcurrentHashMap

class FeverAPI private constructor(
context: Context,
private val serverUrl: String,
private val apiKey: String,
private val httpUsername: String? = null,
private val httpPassword: String? = null,
) : ProviderAPI() {
clientCertificateAlias: String? = null,
) : ProviderAPI(context, clientCertificateAlias) {

private suspend inline fun <reified T> postRequest(query: String?): T {
val response = client.newCall(
Expand Down Expand Up @@ -104,14 +107,16 @@ class FeverAPI private constructor(
private val instances: ConcurrentHashMap<String, FeverAPI> = ConcurrentHashMap()

fun getInstance(
context: Context,
serverUrl: String,
username: String,
password: String,
httpUsername: String? = null,
httpPassword: String? = null,
clientCertificateAlias: String? = null,
): FeverAPI = "$username:$password".md5().run {
instances.getOrPut("$serverUrl$this$httpUsername$httpPassword") {
FeverAPI(serverUrl, this, httpUsername, httpPassword)
instances.getOrPut("$serverUrl$this$httpUsername$httpPassword$clientCertificateAlias") {
FeverAPI(context, serverUrl, this, httpUsername, httpPassword, clientCertificateAlias)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package me.ash.reader.infrastructure.rss.provider.greader

import android.content.Context
import me.ash.reader.infrastructure.di.USER_AGENT_STRING
import me.ash.reader.infrastructure.exception.GoogleReaderAPIException
import me.ash.reader.infrastructure.exception.RetryException
Expand All @@ -10,12 +11,14 @@ import okhttp3.executeAsync
import java.util.concurrent.ConcurrentHashMap

class GoogleReaderAPI private constructor(
context: Context,
private val serverUrl: String,
private val username: String,
private val password: String,
private val httpUsername: String? = null,
private val httpPassword: String? = null,
) : ProviderAPI() {
clientCertificateAlias: String? = null,
) : ProviderAPI(context, clientCertificateAlias) {

enum class Stream(val tag: String) {
ALL_ITEMS("user/-/state/com.google/reading-list"),
Expand Down Expand Up @@ -350,13 +353,15 @@ class GoogleReaderAPI private constructor(
private val instances: ConcurrentHashMap<String, GoogleReaderAPI> = ConcurrentHashMap()

fun getInstance(
context: Context,
serverUrl: String,
username: String,
password: String,
httpUsername: String? = null,
httpPassword: String? = null,
): GoogleReaderAPI = instances.getOrPut("$serverUrl$username$password$httpUsername$httpPassword") {
GoogleReaderAPI(serverUrl, username, password, httpUsername, httpPassword)
clientCertificateAlias: String? = null
): GoogleReaderAPI = instances.getOrPut("$serverUrl$username$password$httpUsername$httpPassword$clientCertificateAlias") {
GoogleReaderAPI(context, serverUrl, username, password, httpUsername, httpPassword, clientCertificateAlias)
}

fun clearInstance() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package me.ash.reader.ui.component.base

import androidx.compose.foundation.interaction.MutableInteractionSource
import androidx.compose.foundation.interaction.PressInteraction
import androidx.compose.foundation.text.KeyboardActions
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material.icons.Icons
Expand All @@ -22,6 +24,7 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.FocusRequester
import androidx.compose.ui.focus.focusProperties
import androidx.compose.ui.focus.focusRequester
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
Expand All @@ -46,6 +49,7 @@ fun RYOutlineTextField(
errorMessage: String = "",
keyboardOptions: KeyboardOptions = KeyboardOptions.Default,
keyboardActions: KeyboardActions = KeyboardActions(),
onClick: (() -> Unit)? = null,
) {
val clipboardManager = LocalClipboardManager.current
val focusRequester = remember { FocusRequester() }
Expand All @@ -59,7 +63,11 @@ fun RYOutlineTextField(
}

OutlinedTextField(
modifier = Modifier.focusRequester(focusRequester),
modifier = if (onClick != null) {
Modifier.focusProperties { canFocus = false }
} else {
Modifier.focusRequester(focusRequester)
},
colors = TextFieldDefaults.colors(
unfocusedContainerColor = Color.Transparent,
focusedContainerColor = Color.Transparent
Expand Down Expand Up @@ -115,5 +123,18 @@ fun RYOutlineTextField(
},
keyboardOptions = keyboardOptions,
keyboardActions = keyboardActions,
readOnly = onClick != null,
interactionSource = onClick?.let {
remember { MutableInteractionSource() }
.also { interactionSource ->
LaunchedEffect(interactionSource) {
interactionSource.interactions.collect {
if (it is PressInteraction.Release) {
onClick.invoke()
}
}
}
}
}
)
}
Loading

0 comments on commit 39dc3bc

Please sign in to comment.