diff --git a/app/src/main/java/com/duckduckgo/app/browser/BrowserWebViewClient.kt b/app/src/main/java/com/duckduckgo/app/browser/BrowserWebViewClient.kt index 1397de8adb38..5bc6cbae99cf 100644 --- a/app/src/main/java/com/duckduckgo/app/browser/BrowserWebViewClient.kt +++ b/app/src/main/java/com/duckduckgo/app/browser/BrowserWebViewClient.kt @@ -129,10 +129,6 @@ class BrowserWebViewClient @Inject constructor( private var shouldOpenDuckPlayerInNewTab: Boolean = true - private val confirmationCallback: (isMalicious: Boolean) -> Unit = { - // TODO (cbarreiro): Handle site blocked asynchronously - } - init { appCoroutineScope.launch { duckPlayer.observeShouldOpenInNewTab().collect { @@ -165,7 +161,7 @@ class BrowserWebViewClient @Inject constructor( try { Timber.v("shouldOverride webViewUrl: ${webView.url} URL: $url") webViewClientListener?.onShouldOverride() - if (requestInterceptor.shouldOverrideUrlLoading(url, isForMainFrame)) { + if (requestInterceptor.shouldOverrideUrlLoading(webView, url, isForMainFrame)) { return true } diff --git a/app/src/main/java/com/duckduckgo/app/browser/WebViewRequestInterceptor.kt b/app/src/main/java/com/duckduckgo/app/browser/WebViewRequestInterceptor.kt index 5c3ec224118e..8c70907848d4 100644 --- a/app/src/main/java/com/duckduckgo/app/browser/WebViewRequestInterceptor.kt +++ b/app/src/main/java/com/duckduckgo/app/browser/WebViewRequestInterceptor.kt @@ -39,6 +39,7 @@ import com.duckduckgo.httpsupgrade.api.HttpsUpgrader import com.duckduckgo.privacy.config.api.Gpc import com.duckduckgo.request.filterer.api.RequestFilterer import com.duckduckgo.user.agent.api.UserAgentProvider +import com.google.android.material.snackbar.Snackbar import kotlinx.coroutines.withContext import timber.log.Timber @@ -62,6 +63,7 @@ interface RequestInterceptor { @WorkerThread fun shouldOverrideUrlLoading( + webView: WebView, url: Uri, isForMainFrame: Boolean, ): Boolean @@ -105,10 +107,12 @@ class WebViewRequestInterceptor( ): WebResourceResponse? { val url: Uri? = request.url - maliciousSiteBlockerWebViewIntegration.shouldIntercept(request, documentUri) { - handleSiteBlocked() + maliciousSiteBlockerWebViewIntegration.shouldIntercept(request, documentUri) { isMalicious -> + if (isMalicious) { + handleSiteBlocked(webView) + } }?.let { - handleSiteBlocked() + handleSiteBlocked(webView) return it } @@ -177,22 +181,24 @@ class WebViewRequestInterceptor( return getWebResourceResponse(request, documentUrl, null) } - override fun shouldOverrideUrlLoading(url: Uri, isForMainFrame: Boolean): Boolean { + override fun shouldOverrideUrlLoading(webView: WebView, url: Uri, isForMainFrame: Boolean): Boolean { if (maliciousSiteBlockerWebViewIntegration.shouldOverrideUrlLoading( url, isForMainFrame, - ) { - handleSiteBlocked() + ) { isMalicious -> + if (isMalicious) { + handleSiteBlocked(webView) + } } ) { - handleSiteBlocked() + handleSiteBlocked(webView) return true } return false } - private fun handleSiteBlocked() { - // TODO (cbarreiro): Handle site blocked + private fun handleSiteBlocked(webView: WebView) { + Snackbar.make(webView, "Site blocked", Snackbar.LENGTH_SHORT).show() } private fun getWebResourceResponse( diff --git a/app/src/main/java/com/duckduckgo/app/browser/webview/MaliciousSiteBlockerWebViewIntegration.kt b/app/src/main/java/com/duckduckgo/app/browser/webview/MaliciousSiteBlockerWebViewIntegration.kt index ab00b5d78f2a..0eb3a918df5e 100644 --- a/app/src/main/java/com/duckduckgo/app/browser/webview/MaliciousSiteBlockerWebViewIntegration.kt +++ b/app/src/main/java/com/duckduckgo/app/browser/webview/MaliciousSiteBlockerWebViewIntegration.kt @@ -32,6 +32,7 @@ import com.duckduckgo.privacy.config.api.PrivacyConfigCallbackPlugin import com.squareup.anvil.annotations.ContributesBinding import com.squareup.anvil.annotations.ContributesMultibinding import java.net.URLDecoder +import java.util.concurrent.atomic.AtomicInteger import javax.inject.Inject import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.launch @@ -68,6 +69,7 @@ class RealMaliciousSiteBlockerWebViewIntegration @Inject constructor( @VisibleForTesting(otherwise = VisibleForTesting.PRIVATE) val processedUrls = mutableListOf() private var isFeatureEnabled = false + private var currentCheckId = AtomicInteger(0) init { if (isMainProcess) { @@ -109,16 +111,13 @@ class RealMaliciousSiteBlockerWebViewIntegration @Inject constructor( return null } - if (request.isForMainFrame) { - if (maliciousSiteProtection.isMalicious(decodedUrl.toUri(), confirmationCallback) == MALICIOUS) { - return WebResourceResponse(null, null, null) - } - processedUrls.add(decodedUrl) - } else if (isForIframe(request) && documentUri?.host == request.requestHeaders["Referer"]?.toUri()?.host) { - if (maliciousSiteProtection.isMalicious(decodedUrl.toUri(), confirmationCallback) == MALICIOUS) { + val belongsToCurrentPage = documentUri?.host == request.requestHeaders["Referer"]?.toUri()?.host + if (request.isForMainFrame || (isForIframe(request) && belongsToCurrentPage)) { + if (checkMaliciousUrl(decodedUrl, confirmationCallback)) { return WebResourceResponse(null, null, null) + } else { + processedUrls.add(decodedUrl) } - processedUrls.add(decodedUrl) } return null } @@ -142,15 +141,33 @@ class RealMaliciousSiteBlockerWebViewIntegration @Inject constructor( // iframes always go through the shouldIntercept method, so we only need to check the main frame here if (isForMainFrame) { - if (maliciousSiteProtection.isMalicious(decodedUrl.toUri(), confirmationCallback) == MALICIOUS) { + if (checkMaliciousUrl(decodedUrl, confirmationCallback)) { return@runBlocking true + } else { + processedUrls.add(decodedUrl) } - processedUrls.add(decodedUrl) } false } } + private suspend fun checkMaliciousUrl( + url: String, + confirmationCallback: (isMalicious: Boolean) -> Unit, + ): Boolean { + val checkId = currentCheckId.incrementAndGet() + return maliciousSiteProtection.isMalicious(url.toUri()) { + // if another load has started, we should ignore the result + val isMalicious = if (checkId == currentCheckId.get()) { + it + } else { + false + } + processedUrls.clear() + confirmationCallback(isMalicious) + } == MALICIOUS + } + private fun isForIframe(request: WebResourceRequest) = request.requestHeaders["Sec-Fetch-Dest"] == "iframe" || request.url.path?.contains("/embed/") == true || request.url.path?.contains("/iframe/") == true || diff --git a/app/src/test/java/com/duckduckgo/app/browser/webview/RealMaliciousSiteBlockerWebViewIntegrationTest.kt b/app/src/test/java/com/duckduckgo/app/browser/webview/RealMaliciousSiteBlockerWebViewIntegrationTest.kt index b4141551c217..8341067ed880 100644 --- a/app/src/test/java/com/duckduckgo/app/browser/webview/RealMaliciousSiteBlockerWebViewIntegrationTest.kt +++ b/app/src/test/java/com/duckduckgo/app/browser/webview/RealMaliciousSiteBlockerWebViewIntegrationTest.kt @@ -2,6 +2,7 @@ package com.duckduckgo.app.browser.webview import android.webkit.WebResourceRequest import androidx.core.net.toUri +import androidx.test.core.app.ActivityScenario.launch import androidx.test.ext.junit.runners.AndroidJUnit4 import com.duckduckgo.app.pixels.remoteconfig.AndroidBrowserConfigFeature import com.duckduckgo.common.test.CoroutineTestRule @@ -9,7 +10,12 @@ import com.duckduckgo.feature.toggles.api.FakeFeatureToggleFactory import com.duckduckgo.feature.toggles.api.Toggle.State import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection.IsMaliciousResult.MALICIOUS +import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection.IsMaliciousResult.WAIT_FOR_CONFIRMATION +import junit.framework.TestCase.assertEquals import junit.framework.TestCase.assertTrue +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest import org.junit.Assert.assertFalse import org.junit.Assert.assertNotNull @@ -129,6 +135,17 @@ class RealMaliciousSiteBlockerWebViewIntegrationTest { assertFalse(result) } + @Test + fun `shouldIntercept returns null when feature is enabled, is malicious, and is mainframe but webView has different host`() = runTest { + whenever(maliciousSiteProtection.isMalicious(any(), any())).thenReturn(MALICIOUS) + val request = mock(WebResourceRequest::class.java) + whenever(request.url).thenReturn(maliciousUri) + whenever(request.isForMainFrame).thenReturn(false) + + val result = testee.shouldIntercept(request, exampleUri) {} + assertNull(result) + } + @Test fun `onPageLoadStarted clears processedUrls`() = runTest { testee.processedUrls.add(exampleUri.toString()) @@ -136,6 +153,83 @@ class RealMaliciousSiteBlockerWebViewIntegrationTest { assertTrue(testee.processedUrls.isEmpty()) } + @Test + fun `if a new page load triggering is malicious is started, isMalicious callback result should be ignored for the first page`() = runTest { + val request = mock(WebResourceRequest::class.java) + whenever(request.url).thenReturn(maliciousUri) + whenever(request.isForMainFrame).thenReturn(true) + + val callbackChannel = Channel() + val firstCallbackDeferred = CompletableDeferred() + val secondCallbackDeferred = CompletableDeferred() + + whenever(maliciousSiteProtection.isMalicious(any(), any())).thenAnswer { invocation -> + val callback = invocation.getArgument<(Boolean) -> Unit>(1) + + launch { + callbackChannel.receive() + callback(true) + } + WAIT_FOR_CONFIRMATION + } + + testee.shouldOverrideUrlLoading(maliciousUri, true) { isMalicious -> + firstCallbackDeferred.complete(isMalicious) + } + + testee.shouldOverrideUrlLoading(exampleUri, true) { isMalicious -> + secondCallbackDeferred.complete(isMalicious) + } + + callbackChannel.send(Unit) + callbackChannel.send(Unit) + + val firstCallbackResult = firstCallbackDeferred.await() + val secondCallbackResult = secondCallbackDeferred.await() + + assertEquals(false, firstCallbackResult) + assertEquals(true, secondCallbackResult) + } + + @Test + fun `isMalicious callback result should be processed if no new page loads triggering isMalicious have started`() = runTest { + val request = mock(WebResourceRequest::class.java) + whenever(request.url).thenReturn(maliciousUri) + whenever(request.isForMainFrame).thenReturn(true) + + val callbackChannel = Channel() + val firstCallbackDeferred = CompletableDeferred() + val secondCallbackDeferred = CompletableDeferred() + + whenever(maliciousSiteProtection.isMalicious(any(), any())).thenAnswer { invocation -> + val callback = invocation.getArgument<(Boolean) -> Unit>(1) + + launch { + callbackChannel.receive() + callback(true) + } + WAIT_FOR_CONFIRMATION + } + + testee.shouldOverrideUrlLoading(maliciousUri, true) { isMalicious -> + firstCallbackDeferred.complete(isMalicious) + } + + callbackChannel.send(Unit) + + testee.shouldOverrideUrlLoading(exampleUri, true) { isMalicious -> + secondCallbackDeferred.complete(isMalicious) + } + + callbackChannel.send(Unit) + + val firstCallbackResult = firstCallbackDeferred.await() + val secondCallbackResult = secondCallbackDeferred.await() + + assertEquals(true, firstCallbackResult) + assertEquals(true, secondCallbackResult) + } + private fun updateFeatureEnabled(enabled: Boolean) { fakeAndroidBrowserConfigFeature.enableMaliciousSiteProtection().setRawStoredState(State(enabled)) testee.onPrivacyConfigDownloaded() diff --git a/malicious-site-protection/malicious-site-protection-impl/build.gradle b/malicious-site-protection/malicious-site-protection-impl/build.gradle index c9454a21f7aa..19aa8196ef5a 100644 --- a/malicious-site-protection/malicious-site-protection-impl/build.gradle +++ b/malicious-site-protection/malicious-site-protection-impl/build.gradle @@ -30,9 +30,12 @@ dependencies { implementation project(path: ':anvil-annotations') implementation project(path: ':di') ksp AndroidX.room.compiler + implementation AndroidX.room.runtime + implementation AndroidX.room.ktx implementation KotlinX.coroutines.android implementation AndroidX.core.ktx + implementation AndroidX.work.runtimeKtx implementation Google.dagger implementation project(path: ':common-utils') @@ -43,6 +46,7 @@ dependencies { implementation Google.android.material testImplementation AndroidX.test.ext.junit + testImplementation 'org.jetbrains.kotlinx:kotlinx-coroutines-test:1.5.2' testImplementation Testing.junit4 testImplementation "org.mockito.kotlin:mockito-kotlin:_" testImplementation project(path: ':common-test') @@ -54,6 +58,7 @@ dependencies { // conflicts with mockito due to direct inclusion of byte buddy exclude group: "org.jetbrains.kotlinx", module: "kotlinx-coroutines-debug" } + testImplementation AndroidX.work.testing coreLibraryDesugaring Android.tools.desugarJdkLibs } diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteModule.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteModule.kt new file mode 100644 index 000000000000..8682a39687c4 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteModule.kt @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl + +import android.content.Context +import androidx.room.Room +import com.duckduckgo.di.scopes.AppScope +import com.duckduckgo.malicioussiteprotection.impl.data.db.MaliciousSiteDao +import com.duckduckgo.malicioussiteprotection.impl.data.db.MaliciousSitesDatabase +import com.duckduckgo.malicioussiteprotection.impl.data.db.MaliciousSitesDatabase.Companion.ALL_MIGRATIONS +import com.squareup.anvil.annotations.ContributesTo +import dagger.Module +import dagger.Provides +import dagger.SingleInstanceIn +import java.security.MessageDigest + +@Module +@ContributesTo(AppScope::class) +class MaliciousSiteModule { + + @Provides + @SingleInstanceIn(AppScope::class) + fun provideMaliciousSiteProtectionDatabase(context: Context): MaliciousSitesDatabase { + return Room.databaseBuilder(context, MaliciousSitesDatabase::class.java, "malicious_sites.db") + .addMigrations(*ALL_MIGRATIONS) + .fallbackToDestructiveMigration() + .build() + } + + @Provides + @SingleInstanceIn(AppScope::class) + fun provideMaliciousSiteDao(database: MaliciousSitesDatabase): MaliciousSiteDao { + return database.maliciousSiteDao() + } + + @Provides + @SingleInstanceIn(AppScope::class) + fun provideMessageDigest(): MessageDigest { + return MessageDigest.getInstance("SHA-256") + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionFiltersUpdateWorker.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionFiltersUpdateWorker.kt new file mode 100644 index 000000000000..0971c00c7827 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionFiltersUpdateWorker.kt @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2025 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl + +import android.content.Context +import androidx.lifecycle.LifecycleOwner +import androidx.work.BackoffPolicy +import androidx.work.CoroutineWorker +import androidx.work.ExistingPeriodicWorkPolicy +import androidx.work.PeriodicWorkRequestBuilder +import androidx.work.WorkManager +import androidx.work.WorkerParameters +import com.duckduckgo.anvil.annotations.ContributesWorker +import com.duckduckgo.app.lifecycle.MainProcessLifecycleObserver +import com.duckduckgo.common.utils.DispatcherProvider +import com.duckduckgo.di.scopes.AppScope +import com.duckduckgo.malicioussiteprotection.impl.data.MaliciousSiteRepository +import com.squareup.anvil.annotations.ContributesMultibinding +import dagger.SingleInstanceIn +import java.util.concurrent.TimeUnit +import javax.inject.Inject +import kotlinx.coroutines.withContext + +@ContributesWorker(AppScope::class) +class MaliciousSiteProtectionFiltersUpdateWorker( + context: Context, + workerParameters: WorkerParameters, +) : CoroutineWorker(context, workerParameters) { + @Inject + lateinit var maliciousSiteRepository: MaliciousSiteRepository + + @Inject + lateinit var dispatcherProvider: DispatcherProvider + + @Inject + lateinit var maliciousSiteProtectionFeature: MaliciousSiteProtectionRCFeature + + override suspend fun doWork(): Result { + return withContext(dispatcherProvider.io()) { + if (maliciousSiteProtectionFeature.isFeatureEnabled().not()) { + return@withContext Result.success() + } + return@withContext if (maliciousSiteRepository.loadFilters().isSuccess) { + Result.success() + } else { + Result.retry() + } + } + } +} + +@ContributesMultibinding( + scope = AppScope::class, + boundType = MainProcessLifecycleObserver::class, +) +@SingleInstanceIn(AppScope::class) +class MaliciousSiteProtectionFiltersUpdateWorkerScheduler @Inject constructor( + private val workManager: WorkManager, + private val maliciousSiteProtectionFeature: MaliciousSiteProtectionRCFeature, + +) : MainProcessLifecycleObserver { + + override fun onCreate(owner: LifecycleOwner) { + val workerRequest = PeriodicWorkRequestBuilder( + maliciousSiteProtectionFeature.getFilterSetUpdateFrequency(), + TimeUnit.MINUTES, + ) + .addTag(MALICIOUS_SITE_PROTECTION_FILTERS_UPDATE_WORKER_TAG) + .setBackoffCriteria(BackoffPolicy.EXPONENTIAL, 1, TimeUnit.MINUTES) + .build() + workManager.enqueueUniquePeriodicWork(MALICIOUS_SITE_PROTECTION_FILTERS_UPDATE_WORKER_TAG, ExistingPeriodicWorkPolicy.UPDATE, workerRequest) + } + + companion object { + private const val MALICIOUS_SITE_PROTECTION_FILTERS_UPDATE_WORKER_TAG = "MALICIOUS_SITE_PROTECTION_FILTERS_UPDATE_WORKER_TAG" + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionHashPrefixesUpdateWorker.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionHashPrefixesUpdateWorker.kt new file mode 100644 index 000000000000..fe111bcb225a --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionHashPrefixesUpdateWorker.kt @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2025 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl + +import android.content.Context +import androidx.lifecycle.LifecycleOwner +import androidx.work.BackoffPolicy +import androidx.work.CoroutineWorker +import androidx.work.ExistingPeriodicWorkPolicy +import androidx.work.PeriodicWorkRequestBuilder +import androidx.work.WorkManager +import androidx.work.WorkerParameters +import com.duckduckgo.anvil.annotations.ContributesWorker +import com.duckduckgo.app.lifecycle.MainProcessLifecycleObserver +import com.duckduckgo.common.utils.DispatcherProvider +import com.duckduckgo.di.scopes.AppScope +import com.duckduckgo.malicioussiteprotection.impl.data.MaliciousSiteRepository +import com.squareup.anvil.annotations.ContributesMultibinding +import dagger.SingleInstanceIn +import java.util.concurrent.TimeUnit +import javax.inject.Inject +import kotlinx.coroutines.withContext + +@ContributesWorker(AppScope::class) +class MaliciousSiteProtectionHashPrefixesUpdateWorker( + context: Context, + workerParameters: WorkerParameters, +) : CoroutineWorker(context, workerParameters) { + @Inject + lateinit var maliciousSiteRepository: MaliciousSiteRepository + + @Inject + lateinit var dispatcherProvider: DispatcherProvider + + @Inject + lateinit var maliciousSiteProtectionFeature: MaliciousSiteProtectionRCFeature + + override suspend fun doWork(): Result { + return withContext(dispatcherProvider.io()) { + if (maliciousSiteProtectionFeature.isFeatureEnabled().not()) { + return@withContext Result.success() + } + return@withContext if (maliciousSiteRepository.loadHashPrefixes().isSuccess) { + Result.success() + } else { + Result.retry() + } + } + } +} + +@ContributesMultibinding( + scope = AppScope::class, + boundType = MainProcessLifecycleObserver::class, +) +@SingleInstanceIn(AppScope::class) +class MaliciousSiteProtectionHashPrefixesUpdateWorkerScheduler @Inject constructor( + private val workManager: WorkManager, + private val maliciousSiteProtectionFeature: MaliciousSiteProtectionRCFeature, + +) : MainProcessLifecycleObserver { + + override fun onCreate(owner: LifecycleOwner) { + val workerRequest = PeriodicWorkRequestBuilder( + maliciousSiteProtectionFeature.getHashPrefixUpdateFrequency(), + TimeUnit.MINUTES, + ) + .addTag(MALICIOUS_SITE_PROTECTION_HASH_PREFIXES_UPDATE_WORKER_TAG) + .setBackoffCriteria(BackoffPolicy.EXPONENTIAL, 1, TimeUnit.MINUTES) + .build() + workManager.enqueueUniquePeriodicWork( + MALICIOUS_SITE_PROTECTION_HASH_PREFIXES_UPDATE_WORKER_TAG, + ExistingPeriodicWorkPolicy.UPDATE, + workerRequest, + ) + } + + companion object { + private const val MALICIOUS_SITE_PROTECTION_HASH_PREFIXES_UPDATE_WORKER_TAG = "MALICIOUS_SITE_PROTECTION_HASH_PREFIXES_UPDATE_WORKER_TAG" + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/RealMaliciousSiteProtection.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/RealMaliciousSiteProtection.kt deleted file mode 100644 index 7ff0dd28d5d5..000000000000 --- a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/RealMaliciousSiteProtection.kt +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2024 DuckDuckGo - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.duckduckgo.malicioussiteprotection.impl - -import android.net.Uri -import com.duckduckgo.di.scopes.AppScope -import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection -import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection.IsMaliciousResult -import com.squareup.anvil.annotations.ContributesBinding -import javax.inject.Inject -import timber.log.Timber - -@ContributesBinding(AppScope::class, MaliciousSiteProtection::class) -class RealMaliciousSiteProtection @Inject constructor( - maliciousSiteProtectionRCFeature: MaliciousSiteProtectionRCFeature, -) : MaliciousSiteProtection { - - override suspend fun isMalicious(url: Uri, confirmationCallback: (isMalicious: Boolean) -> Unit): IsMaliciousResult { - Timber.tag("MaliciousSiteProtection").d("isMalicious $url") - // TODO (cbarreiro): Implement the logic to check if the URL is malicious - return IsMaliciousResult.SAFE - } -} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/MaliciousSiteRepository.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/MaliciousSiteRepository.kt new file mode 100644 index 000000000000..d6682d3ba872 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/MaliciousSiteRepository.kt @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2024 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl.data + +import com.duckduckgo.common.utils.DispatcherProvider +import com.duckduckgo.di.scopes.AppScope +import com.duckduckgo.malicioussiteprotection.impl.data.db.MaliciousSiteDao +import com.duckduckgo.malicioussiteprotection.impl.data.db.RevisionEntity +import com.duckduckgo.malicioussiteprotection.impl.data.network.FilterResponse +import com.duckduckgo.malicioussiteprotection.impl.data.network.FilterSetResponse +import com.duckduckgo.malicioussiteprotection.impl.data.network.HashPrefixResponse +import com.duckduckgo.malicioussiteprotection.impl.data.network.MaliciousSiteService +import com.duckduckgo.malicioussiteprotection.impl.models.Feed +import com.duckduckgo.malicioussiteprotection.impl.models.Feed.MALWARE +import com.duckduckgo.malicioussiteprotection.impl.models.Feed.PHISHING +import com.duckduckgo.malicioussiteprotection.impl.models.Filter +import com.duckduckgo.malicioussiteprotection.impl.models.FilterSetWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.FilterSetWithRevision.MalwareFilterSetWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.FilterSetWithRevision.PhishingFilterSetWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.HashPrefixesWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.HashPrefixesWithRevision.MalwareHashPrefixesWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.HashPrefixesWithRevision.PhishingHashPrefixesWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.Match +import com.duckduckgo.malicioussiteprotection.impl.models.Type +import com.duckduckgo.malicioussiteprotection.impl.models.Type.FILTER_SET +import com.duckduckgo.malicioussiteprotection.impl.models.Type.HASH_PREFIXES +import com.squareup.anvil.annotations.ContributesBinding +import dagger.SingleInstanceIn +import javax.inject.Inject +import kotlinx.coroutines.withContext + +interface MaliciousSiteRepository { + suspend fun containsHashPrefix(hashPrefix: String): Boolean + suspend fun getFilters(hash: String): List? + suspend fun matches(hashPrefix: String): List + suspend fun loadFilters(): Result + suspend fun loadHashPrefixes(): Result +} + +@ContributesBinding(AppScope::class) +@SingleInstanceIn(AppScope::class) +class RealMaliciousSiteRepository @Inject constructor( + private val maliciousSiteDao: MaliciousSiteDao, + private val maliciousSiteService: MaliciousSiteService, + private val dispatcherProvider: DispatcherProvider, +) : MaliciousSiteRepository { + + override suspend fun containsHashPrefix(hashPrefix: String): Boolean { + return maliciousSiteDao.getHashPrefix(hashPrefix) != null + } + + override suspend fun getFilters(hash: String): List? { + return maliciousSiteDao.getFilter(hash)?.let { + it.map { + Filter(it.hash, it.regex) + } + } + } + + override suspend fun matches(hashPrefix: String): List { + return try { + maliciousSiteService.getMatches(hashPrefix).matches.map { + Match(it.hostname, it.url, it.regex, it.hash) + } + } catch (e: Exception) { + listOf() + } + } + + override suspend fun loadFilters(): Result { + return loadDataOfType(FILTER_SET) { latestRevision, networkRevision, feed -> loadFilters(latestRevision, networkRevision, feed) } + } + + override suspend fun loadHashPrefixes(): Result { + return loadDataOfType(HASH_PREFIXES) { latestRevision, networkRevision, feed -> loadHashPrefixes(latestRevision, networkRevision, feed) } + } + + private suspend fun loadDataOfType( + type: Type, + loadData: suspend (revisions: List, networkRevision: Int, feed: Feed) -> Unit, + ): Result { + return withContext(dispatcherProvider.io()) { + val networkRevision = maliciousSiteService.getRevision().revision + + val localRevisions = getLocalRevisions(type) + + val result = Feed.entries.fold(Result.success(Unit)) { acc, feed -> + try { + loadData(localRevisions, networkRevision, feed) + acc + } catch (e: Exception) { + Result.failure(e) + } + } + result + } + } + + private suspend fun loadAndUpdateData( + latestRevision: List, + networkRevision: Int, + feed: Feed, + getFunction: suspend (Int) -> T?, + updateFunction: suspend (T?) -> Unit, + ) { + val revision = latestRevision.getRevisionForFeed(feed) + val data: T? = if (networkRevision > revision) { + getFunction(revision) + } else { + null + } + + updateFunction(data) + } + + private suspend fun loadFilters( + latestRevision: List, + networkRevision: Int, + feed: Feed, + ) { + loadAndUpdateData( + latestRevision, + networkRevision, + feed, + when (feed) { + PHISHING -> maliciousSiteService::getPhishingFilterSet + MALWARE -> maliciousSiteService::getMalwareFilterSet + }, + ) { maliciousSiteDao.updateFilters(it?.toFilterSetWithRevision(feed)) } + } + + private suspend fun loadHashPrefixes( + latestRevision: List, + networkRevision: Int, + feed: Feed, + ) { + loadAndUpdateData( + latestRevision, + networkRevision, + feed, + when (feed) { + PHISHING -> maliciousSiteService::getPhishingHashPrefixes + MALWARE -> maliciousSiteService::getMalwareHashPrefixes + }, + ) { maliciousSiteDao.updateHashPrefixes(it?.toHashPrefixesWithRevision(feed)) } + } + + private fun FilterSetResponse.toFilterSetWithRevision(feed: Feed): FilterSetWithRevision { + val insert = insert.toFilterSet() + val delete = delete.toFilterSet() + return when (feed) { + PHISHING -> PhishingFilterSetWithRevision(insert, delete, revision, replace) + MALWARE -> MalwareFilterSetWithRevision(insert, delete, revision, replace) + } + } + + private fun HashPrefixResponse.toHashPrefixesWithRevision(feed: Feed): HashPrefixesWithRevision { + return when (feed) { + PHISHING -> PhishingHashPrefixesWithRevision(insert, delete, revision, replace) + MALWARE -> MalwareHashPrefixesWithRevision(insert, delete, revision, replace) + } + } + + private suspend fun getLocalRevisions(type: Type) = (maliciousSiteDao.getLatestRevision()?.filter { it.type == type.name } ?: listOf()) + + private fun Set.toFilterSet(): Set { + return map { Filter(it.hash, it.regex) }.toSet() + } + + private fun List.getRevisionForFeed(feed: Feed): Int { + return firstOrNull { it.feed == feed.name }?.revision ?: 0 + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDao.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDao.kt new file mode 100644 index 000000000000..7c0dfa2bf7de --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDao.kt @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2024 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl.data.db + +import androidx.room.Dao +import androidx.room.Insert +import androidx.room.OnConflictStrategy +import androidx.room.Query +import androidx.room.Transaction +import com.duckduckgo.malicioussiteprotection.impl.models.Feed +import com.duckduckgo.malicioussiteprotection.impl.models.FilterSetWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.HashPrefixesWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.Type + +@Dao +interface MaliciousSiteDao { + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertRevision(revision: RevisionEntity) + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertHashPrefixes(items: List) + + @Query("DELETE FROM hash_prefixes") + suspend fun deleteHashPrefixes() + + @Query("DELETE FROM filters") + suspend fun deleteFilters() + + @Query("DELETE FROM hash_prefixes WHERE type = :type") + suspend fun deleteHashPrefixes(type: String) + + @Query("DELETE FROM hash_prefixes WHERE hashPrefix = :hashPrefix AND type = :type") + suspend fun deleteHashPrefix( + hashPrefix: String, + type: String, + ) + + @Query("DELETE FROM filters WHERE hash = :hash AND type = :type") + suspend fun deleteFilter( + hash: String, + type: String, + ) + + @Query("DELETE FROM filters WHERE type = :type") + suspend fun deleteFilters(type: String) + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertFilters(items: Set) + + @Query("SELECT * FROM revisions") + suspend fun getLatestRevision(): List? + + @Query("SELECT * FROM revisions WHERE feed = :feed AND type = :type") + suspend fun getLatestRevision(feed: String, type: String): RevisionEntity? + + @Query("DELETE FROM revisions") + suspend fun deleteRevisions() + + @Query("SELECT * FROM hash_prefixes WHERE hashPrefix = :hashPrefix") + suspend fun getHashPrefix(hashPrefix: String): HashPrefixEntity? + + @Query("SELECT * FROM filters WHERE hash = :hash") + suspend fun getFilter(hash: String): List? + + @Transaction + suspend fun updateHashPrefixes( + hashPrefixes: HashPrefixesWithRevision?, + ) { + hashPrefixes ?: return + + updateData( + type = hashPrefixes.type, + feed = hashPrefixes.feed, + updateWithRevision = { localRevision: Int -> updateHashPrefixes(hashPrefixes, localRevision) }, + ) + } + + @Transaction + suspend fun updateFilters( + filterSet: FilterSetWithRevision?, + ) { + filterSet ?: return + + updateData( + type = filterSet.type, + feed = filterSet.feed, + updateWithRevision = { localRevision: Int -> updateFilters(filterSet, localRevision) }, + ) + } + + private suspend fun updateFilters(filterSet: FilterSetWithRevision, currentLocalRevision: Int): Int { + return filterSet.takeIf { isNewRevisionNewerThanLocal(it.revision, currentLocalRevision) }?.apply { + updateData( + replace = filterSet.replace, + deleteAll = { deleteFilters(filterSet.feed.name) }, + deleteItem = { deleteFilter(it.hash, filterSet.feed.name) }, + insertItems = { insertFilters(filterSet.insert.map { FilterEntity(it.hash, it.regex, type = filterSet.feed.name) }.toSet()) }, + itemsToDelete = filterSet.delete, + ) + }?.revision ?: currentLocalRevision + } + + private suspend fun updateHashPrefixes(hashPrefixes: HashPrefixesWithRevision, currentLocalRevision: Int): Int { + return hashPrefixes.takeIf { isNewRevisionNewerThanLocal(it.revision, currentLocalRevision) }?.apply { + updateData( + replace = hashPrefixes.replace, + deleteAll = { deleteHashPrefixes(hashPrefixes.feed.name) }, + deleteItem = { deleteHashPrefix(it, hashPrefixes.feed.name) }, + insertItems = { insertHashPrefixes(hashPrefixes.insert.map { HashPrefixEntity(hashPrefix = it, type = hashPrefixes.feed.name) }) }, + itemsToDelete = hashPrefixes.delete, + ) + }?.revision ?: currentLocalRevision + } + + private suspend fun updateData( + type: Type, + feed: Feed, + updateWithRevision: suspend (Int) -> Int, + ) { + val currentLocalRevision = getLatestRevision(feed = feed, type = type) + val newRevision = updateWithRevision(currentLocalRevision) + if (currentLocalRevision != newRevision) { + insertRevision(RevisionEntity(feed = feed.name, type = type.name, revision = newRevision)) + } + } + + fun isNewRevisionNewerThanLocal( + newRevision: Int, + currentLocalRevision: Int, + ) = newRevision > currentLocalRevision + + private suspend fun updateData( + replace: Boolean, + deleteAll: suspend () -> Unit, + deleteItem: suspend (T) -> Unit, + insertItems: suspend () -> Unit, + itemsToDelete: Set, + ) { + if (replace) { + deleteAll() + } else { + itemsToDelete.forEach { deleteItem(it) } + } + insertItems() + } + + private suspend fun getLatestRevision(feed: Feed, type: Type): Int { + return getLatestRevision(feed = feed.name, type = type.name)?.revision ?: 0 + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSitesDataEntities.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSitesDataEntities.kt new file mode 100644 index 000000000000..5cb1d1fb79c6 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSitesDataEntities.kt @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2024 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl.data.db + +import androidx.room.Entity +import androidx.room.PrimaryKey + +@Entity(tableName = "revisions", primaryKeys = ["feed", "type"]) +data class RevisionEntity( + val feed: String, + val type: String, + val revision: Int, +) + +@Entity( + tableName = "hash_prefixes", +) +data class HashPrefixEntity( + @PrimaryKey + val hashPrefix: String, + val type: String, +) + +@Entity( + tableName = "filters", +) +data class FilterEntity( + @PrimaryKey + val hash: String, + val regex: String, + val type: String, +) diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSitesDatabase.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSitesDatabase.kt new file mode 100644 index 000000000000..0c6ee83cd9aa --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSitesDatabase.kt @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl.data.db + +import androidx.room.Database +import androidx.room.RoomDatabase +import androidx.room.migration.Migration + +@Database( + exportSchema = true, + entities = [RevisionEntity::class, HashPrefixEntity::class, FilterEntity::class], + version = 1, +) +abstract class MaliciousSitesDatabase : RoomDatabase() { + abstract fun maliciousSiteDao(): MaliciousSiteDao + + companion object { + val ALL_MIGRATIONS = arrayOf() + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/network/MaliciousSiteService.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/network/MaliciousSiteService.kt new file mode 100644 index 000000000000..ba44b64ac2c2 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/network/MaliciousSiteService.kt @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2024 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl.data.network + +import com.duckduckgo.anvil.annotations.ContributesServiceApi +import com.duckduckgo.common.utils.AppUrl.Url.API +import com.duckduckgo.di.scopes.AppScope +import retrofit2.http.GET +import retrofit2.http.Query + +private const val BASE_URL = "$API/api/protection/v1/android" +private const val HASH_PREFIX_PATH = "/hashPrefix" +private const val FILTER_SET_PATH = "/filterSet" +private const val CATEGORY = "category" +private const val PHISHING = "phishing" +private const val MALWARE = "malware" + +@ContributesServiceApi(AppScope::class) +interface MaliciousSiteService { + @GET("$BASE_URL$HASH_PREFIX_PATH?$CATEGORY=$PHISHING") + suspend fun getPhishingHashPrefixes(@Query("revision") revision: Int): HashPrefixResponse + + @GET("$BASE_URL$HASH_PREFIX_PATH?$CATEGORY=$MALWARE") + suspend fun getMalwareHashPrefixes(@Query("revision") revision: Int): HashPrefixResponse + + @GET("$BASE_URL$FILTER_SET_PATH?$CATEGORY=$PHISHING") + suspend fun getPhishingFilterSet(@Query("revision") revision: Int): FilterSetResponse + + @GET("$BASE_URL$FILTER_SET_PATH?$CATEGORY=$MALWARE") + suspend fun getMalwareFilterSet(@Query("revision") revision: Int): FilterSetResponse + + @GET("$BASE_URL/matches") + suspend fun getMatches(@Query("hashPrefix") hashPrefix: String): MatchesResponse + + @GET("$BASE_URL/revision") + suspend fun getRevision(): RevisionResponse +} + +data class MatchesResponse( + val matches: List, +) + +data class HashPrefixResponse( + val insert: Set, + val delete: Set, + val revision: Int, + val replace: Boolean, +) + +data class FilterSetResponse( + val insert: Set, + val delete: Set, + val revision: Int, + val replace: Boolean, +) + +data class FilterResponse( + val hash: String, + val regex: String, +) + +data class MatchResponse( + val hostname: String, + val url: String, + val regex: String, + val hash: String, +) + +data class RevisionResponse( + val revision: Int, +) diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtection.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtection.kt new file mode 100644 index 000000000000..f73d7215aaf1 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtection.kt @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2024 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl.domain + +import android.net.Uri +import com.duckduckgo.app.di.AppCoroutineScope +import com.duckduckgo.common.utils.DispatcherProvider +import com.duckduckgo.di.scopes.AppScope +import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection +import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection.IsMaliciousResult +import com.duckduckgo.malicioussiteprotection.impl.MaliciousSiteProtectionRCFeature +import com.duckduckgo.malicioussiteprotection.impl.data.MaliciousSiteRepository +import com.squareup.anvil.annotations.ContributesBinding +import java.security.MessageDigest +import java.util.regex.Pattern +import javax.inject.Inject +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.launch +import timber.log.Timber + +@ContributesBinding(AppScope::class, MaliciousSiteProtection::class) +class RealMaliciousSiteProtection @Inject constructor( + private val dispatchers: DispatcherProvider, + @AppCoroutineScope private val appCoroutineScope: CoroutineScope, + private val maliciousSiteRepository: MaliciousSiteRepository, + private val messageDigest: MessageDigest, + private val maliciousSiteProtectionRCFeature: MaliciousSiteProtectionRCFeature, +) : MaliciousSiteProtection { + + override suspend fun isMalicious(url: Uri, confirmationCallback: (isMalicious: Boolean) -> Unit): IsMaliciousResult { + Timber.tag("MaliciousSiteProtection").d("isMalicious $url") + + if (!maliciousSiteProtectionRCFeature.isFeatureEnabled()) { + Timber.d("\uD83D\uDFE2 Cris: should not block (feature disabled) $url") + return IsMaliciousResult.SAFE + } + + val hostname = url.host ?: return IsMaliciousResult.SAFE + val hash = messageDigest + .digest(hostname.toByteArray(Charsets.UTF_8)) + .joinToString("") { "%02x".format(it) } + val hashPrefix = hash.substring(0, 8) + + if (!maliciousSiteRepository.containsHashPrefix(hashPrefix)) { + Timber.d("\uD83D\uDFE2 Cris: should not block (no hash) $hashPrefix, $url") + return IsMaliciousResult.SAFE + } + maliciousSiteRepository.getFilters(hash)?.let { + for (filter in it) { + if (Pattern.compile(filter.regex).matcher(url.toString()).find()) { + Timber.d("\uD83D\uDFE2 Cris: shouldBlock $url") + return IsMaliciousResult.MALICIOUS + } + } + } + appCoroutineScope.launch(dispatchers.io()) { + confirmationCallback(matches(hashPrefix, url, hostname, hash)) + } + return IsMaliciousResult.WAIT_FOR_CONFIRMATION + } + + private suspend fun matches( + hashPrefix: String, + url: Uri, + hostname: String, + hash: String, + ): Boolean { + val matches = maliciousSiteRepository.matches(hashPrefix.substring(0, 4)) + return matches.any { match -> + Pattern.compile(match.regex).matcher(url.toString()).find() && + (hostname == match.hostname) && + (hash == match.hash) + }.also { matched -> + Timber.d("\uD83D\uDFE2 Cris: should block $matched") + } + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/models/MaliciousSiteProtectionDataModels.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/models/MaliciousSiteProtectionDataModels.kt new file mode 100644 index 000000000000..55543ef51687 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/models/MaliciousSiteProtectionDataModels.kt @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2025 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl.models + +import com.duckduckgo.malicioussiteprotection.impl.models.Feed.MALWARE +import com.duckduckgo.malicioussiteprotection.impl.models.Feed.PHISHING + +data class Match( + val hostname: String, + val url: String, + val regex: String, + val hash: String, +) + +data class Filter( + val hash: String, + val regex: String, +) + +sealed class FilterSetWithRevision( + open val insert: Set, + open val delete: Set, + open val revision: Int, + open val replace: Boolean, + val feed: Feed, + val type: Type = Type.FILTER_SET, +) { + data class PhishingFilterSetWithRevision( + override val insert: Set, + override val delete: Set, + override val revision: Int, + override val replace: Boolean, + ) : FilterSetWithRevision(insert, delete, revision, replace, PHISHING) + + data class MalwareFilterSetWithRevision( + override val insert: Set, + override val delete: Set, + override val revision: Int, + override val replace: Boolean, + ) : FilterSetWithRevision(insert, delete, revision, replace, MALWARE) +} + +sealed class HashPrefixesWithRevision( + open val insert: Set, + open val delete: Set, + open val revision: Int, + open val replace: Boolean, + val feed: Feed, + val type: Type = Type.HASH_PREFIXES, +) { + data class PhishingHashPrefixesWithRevision( + override val insert: Set, + override val delete: Set, + override val revision: Int, + override val replace: Boolean, + ) : HashPrefixesWithRevision(insert, delete, revision, replace, PHISHING) + + data class MalwareHashPrefixesWithRevision( + override val insert: Set, + override val delete: Set, + override val revision: Int, + override val replace: Boolean, + ) : HashPrefixesWithRevision(insert, delete, revision, replace, MALWARE) +} + +enum class Feed { + PHISHING, + MALWARE, +} + +enum class Type { + HASH_PREFIXES, + FILTER_SET, +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionFiltersUpdateWorkerTest.kt b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionFiltersUpdateWorkerTest.kt new file mode 100644 index 000000000000..1e2c670c4484 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionFiltersUpdateWorkerTest.kt @@ -0,0 +1,103 @@ +package com.duckduckgo.malicioussiteprotection.impl + +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import androidx.work.ExistingPeriodicWorkPolicy +import androidx.work.ListenableWorker.Result.retry +import androidx.work.ListenableWorker.Result.success +import androidx.work.PeriodicWorkRequest +import androidx.work.WorkManager +import androidx.work.testing.TestListenableWorkerBuilder +import com.duckduckgo.common.test.CoroutineTestRule +import com.duckduckgo.malicioussiteprotection.impl.data.MaliciousSiteRepository +import java.util.concurrent.TimeUnit +import kotlinx.coroutines.test.runTest +import org.junit.Assert.assertEquals +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.ArgumentCaptor +import org.mockito.Mockito.mock +import org.mockito.Mockito.verify +import org.mockito.kotlin.capture +import org.mockito.kotlin.eq +import org.mockito.kotlin.whenever + +@RunWith(AndroidJUnit4::class) +class MaliciousSiteProtectionFiltersUpdateWorkerTest { + + @get:Rule + var coroutineRule = CoroutineTestRule() + + private val context = InstrumentationRegistry.getInstrumentation().targetContext + private val maliciousSiteRepository: MaliciousSiteRepository = mock() + private val dispatcherProvider = coroutineRule.testDispatcherProvider + private val maliciousSiteProtectionFeature: MaliciousSiteProtectionRCFeature = mock() + private val worker = TestListenableWorkerBuilder(context = context).build() + + @Before + fun setup() { + worker.maliciousSiteRepository = maliciousSiteRepository + worker.dispatcherProvider = dispatcherProvider + worker.maliciousSiteProtectionFeature = maliciousSiteProtectionFeature + } + + @Test + fun doWork_returnsSuccessWhenFeatureIsDisabled() = runTest { + whenever(maliciousSiteProtectionFeature.isFeatureEnabled()).thenReturn(false) + + val result = worker.doWork() + + assertEquals(success(), result) + } + + @Test + fun doWork_returnsSuccessWhenLoadFiltersSucceeds() = runTest { + whenever(maliciousSiteProtectionFeature.isFeatureEnabled()).thenReturn(true) + + val result = worker.doWork() + + assertEquals(success(), result) + verify(maliciousSiteRepository).loadFilters() + } + + @Test + fun doWork_returnsRetryWhenLoadFiltersFails() = runTest { + whenever(maliciousSiteProtectionFeature.isFeatureEnabled()).thenReturn(true) + whenever(maliciousSiteRepository.loadFilters()).thenReturn(Result.failure(Exception())) + + val result = worker.doWork() + + assertEquals(retry(), result) + } +} + +class MaliciousSiteProtectionFiltersUpdateWorkerSchedulerTest { + + private val workManager: WorkManager = mock() + private val maliciousSiteProtectionFeature: MaliciousSiteProtectionRCFeature = mock() + private val scheduler = MaliciousSiteProtectionFiltersUpdateWorkerScheduler(workManager, maliciousSiteProtectionFeature) + + @Test + fun onCreate_schedulesWorkerWithUpdateFrequencyFromRCFlag() { + val updateFrequencyMinutes = 15L + + whenever(maliciousSiteProtectionFeature.getFilterSetUpdateFrequency()).thenReturn(updateFrequencyMinutes) + + scheduler.onCreate(mock()) + + val workRequestCaptor = ArgumentCaptor.forClass(PeriodicWorkRequest::class.java) + verify(workManager).enqueueUniquePeriodicWork( + eq("MALICIOUS_SITE_PROTECTION_FILTERS_UPDATE_WORKER_TAG"), + eq(ExistingPeriodicWorkPolicy.UPDATE), + capture(workRequestCaptor), + ) + + val capturedWorkRequest = workRequestCaptor.value + val repeatInterval = capturedWorkRequest.workSpec.intervalDuration + val expectedInterval = TimeUnit.MINUTES.toMillis(updateFrequencyMinutes) + + assertEquals(expectedInterval, repeatInterval) + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionHashPrefixesUpdateWorkerTest.kt b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionHashPrefixesUpdateWorkerTest.kt new file mode 100644 index 000000000000..9046320a79e9 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/MaliciousSiteProtectionHashPrefixesUpdateWorkerTest.kt @@ -0,0 +1,103 @@ +package com.duckduckgo.malicioussiteprotection.impl + +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import androidx.work.ExistingPeriodicWorkPolicy +import androidx.work.ListenableWorker.Result.retry +import androidx.work.ListenableWorker.Result.success +import androidx.work.PeriodicWorkRequest +import androidx.work.WorkManager +import androidx.work.testing.TestListenableWorkerBuilder +import com.duckduckgo.common.test.CoroutineTestRule +import com.duckduckgo.malicioussiteprotection.impl.data.MaliciousSiteRepository +import java.util.concurrent.TimeUnit +import kotlinx.coroutines.test.runTest +import org.junit.Assert.assertEquals +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.ArgumentCaptor +import org.mockito.Mockito.mock +import org.mockito.Mockito.verify +import org.mockito.kotlin.capture +import org.mockito.kotlin.eq +import org.mockito.kotlin.whenever + +@RunWith(AndroidJUnit4::class) +class MaliciousSiteProtectionHashPrefixesUpdateWorkerTest { + + @get:Rule + var coroutineRule = CoroutineTestRule() + + private val context = InstrumentationRegistry.getInstrumentation().targetContext + private val maliciousSiteRepository: MaliciousSiteRepository = mock() + private val dispatcherProvider = coroutineRule.testDispatcherProvider + private val maliciousSiteProtectionFeature: MaliciousSiteProtectionRCFeature = mock() + private val worker = TestListenableWorkerBuilder(context = context).build() + + @Before + fun setup() { + worker.maliciousSiteRepository = maliciousSiteRepository + worker.dispatcherProvider = dispatcherProvider + worker.maliciousSiteProtectionFeature = maliciousSiteProtectionFeature + } + + @Test + fun doWork_returnsSuccessWhenFeatureIsDisabled() = runTest { + whenever(maliciousSiteProtectionFeature.isFeatureEnabled()).thenReturn(false) + + val result = worker.doWork() + + assertEquals(success(), result) + } + + @Test + fun doWork_returnsSuccessWhenLoadHashPrefixesSucceeds() = runTest { + whenever(maliciousSiteProtectionFeature.isFeatureEnabled()).thenReturn(true) + + val result = worker.doWork() + + assertEquals(success(), result) + verify(maliciousSiteRepository).loadHashPrefixes() + } + + @Test + fun doWork_returnsRetryWhenLoadHashPrefixesFails() = runTest { + whenever(maliciousSiteProtectionFeature.isFeatureEnabled()).thenReturn(true) + whenever(maliciousSiteRepository.loadHashPrefixes()).thenReturn(Result.failure(Exception())) + + val result = worker.doWork() + + assertEquals(retry(), result) + } +} + +class MaliciousSiteProtectionHashPrefixesUpdateWorkerSchedulerTest { + + private val workManager: WorkManager = mock() + private val maliciousSiteProtectionFeature: MaliciousSiteProtectionRCFeature = mock() + private val scheduler = MaliciousSiteProtectionHashPrefixesUpdateWorkerScheduler(workManager, maliciousSiteProtectionFeature) + + @Test + fun onCreate_schedulesWorkerWithUpdateFrequencyFromRCFlag() { + val updateFrequencyMinutes = 15L + + whenever(maliciousSiteProtectionFeature.getHashPrefixUpdateFrequency()).thenReturn(updateFrequencyMinutes) + + scheduler.onCreate(mock()) + + val workRequestCaptor = ArgumentCaptor.forClass(PeriodicWorkRequest::class.java) + verify(workManager).enqueueUniquePeriodicWork( + eq("MALICIOUS_SITE_PROTECTION_HASH_PREFIXES_UPDATE_WORKER_TAG"), + eq(ExistingPeriodicWorkPolicy.UPDATE), + capture(workRequestCaptor), + ) + + val capturedWorkRequest = workRequestCaptor.value + val repeatInterval = capturedWorkRequest.workSpec.intervalDuration + val expectedInterval = TimeUnit.MINUTES.toMillis(updateFrequencyMinutes) + + assertEquals(expectedInterval, repeatInterval) + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/RealMaliciousSiteRepositoryTest.kt b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/RealMaliciousSiteRepositoryTest.kt new file mode 100644 index 000000000000..68c12669e3d8 --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/RealMaliciousSiteRepositoryTest.kt @@ -0,0 +1,132 @@ +package com.duckduckgo.malicioussiteprotection.impl.data + +import com.duckduckgo.malicioussiteprotection.impl.data.db.FilterEntity +import com.duckduckgo.malicioussiteprotection.impl.data.db.HashPrefixEntity +import com.duckduckgo.malicioussiteprotection.impl.data.db.MaliciousSiteDao +import com.duckduckgo.malicioussiteprotection.impl.data.db.RevisionEntity +import com.duckduckgo.malicioussiteprotection.impl.data.network.FilterSetResponse +import com.duckduckgo.malicioussiteprotection.impl.data.network.HashPrefixResponse +import com.duckduckgo.malicioussiteprotection.impl.data.network.MaliciousSiteService +import com.duckduckgo.malicioussiteprotection.impl.data.network.MatchResponse +import com.duckduckgo.malicioussiteprotection.impl.data.network.MatchesResponse +import com.duckduckgo.malicioussiteprotection.impl.data.network.RevisionResponse +import com.duckduckgo.malicioussiteprotection.impl.models.Feed.PHISHING +import com.duckduckgo.malicioussiteprotection.impl.models.Filter +import com.duckduckgo.malicioussiteprotection.impl.models.FilterSetWithRevision.PhishingFilterSetWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.HashPrefixesWithRevision.PhishingHashPrefixesWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.Match +import com.duckduckgo.malicioussiteprotection.impl.models.Type +import kotlinx.coroutines.test.runTest +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Test +import org.mockito.Mockito.mock +import org.mockito.Mockito.verify +import org.mockito.kotlin.any +import org.mockito.kotlin.never +import org.mockito.kotlin.whenever + +class RealMaliciousSiteRepositoryTest { + + @get:org.junit.Rule + var coroutineRule = com.duckduckgo.common.test.CoroutineTestRule() + + private val maliciousSiteDao: MaliciousSiteDao = mock() + private val maliciousSiteService: MaliciousSiteService = mock() + private val repository = RealMaliciousSiteRepository(maliciousSiteDao, maliciousSiteService, coroutineRule.testDispatcherProvider) + + @Test + fun loadFilters_updatesFiltersWhenNetworkRevisionIsHigher() = runTest { + val networkRevision = 2 + val latestRevision = listOf(RevisionEntity(PHISHING.name, Type.FILTER_SET.name, 1)) + val phishingFilterSetResponse = FilterSetResponse(setOf(), setOf(), networkRevision, false) + + whenever(maliciousSiteService.getRevision()).thenReturn(RevisionResponse(networkRevision)) + whenever(maliciousSiteDao.getLatestRevision()).thenReturn(latestRevision) + whenever(maliciousSiteService.getPhishingFilterSet(any())).thenReturn(phishingFilterSetResponse) + + repository.loadFilters() + + verify(maliciousSiteService).getPhishingFilterSet(latestRevision.first().revision) + verify(maliciousSiteDao).updateFilters(any()) + } + + @Test + fun loadFilters_doesNotUpdateFiltersWhenNetworkRevisionIsNotHigher() = runTest { + val networkRevision = 1 + val latestRevision = listOf(RevisionEntity(PHISHING.name, Type.FILTER_SET.name, 1)) + + whenever(maliciousSiteService.getRevision()).thenReturn(RevisionResponse(networkRevision)) + whenever(maliciousSiteDao.getLatestRevision()).thenReturn(latestRevision) + + repository.loadFilters() + + verify(maliciousSiteService, never()).getPhishingFilterSet(any()) + verify(maliciousSiteDao, never()).updateFilters(any()) + } + + @Test + fun loadHashPrefixes_updatesHashPrefixesWhenNetworkRevisionIsHigher() = runTest { + val networkRevision = 2 + val latestRevision = listOf(RevisionEntity(PHISHING.name, Type.HASH_PREFIXES.name, 1)) + val phishingHashPrefixResponse = HashPrefixResponse(setOf(), setOf(), networkRevision, false) + + whenever(maliciousSiteService.getRevision()).thenReturn(RevisionResponse(networkRevision)) + whenever(maliciousSiteDao.getLatestRevision()).thenReturn(latestRevision) + whenever(maliciousSiteService.getPhishingHashPrefixes(any())).thenReturn(phishingHashPrefixResponse) + + repository.loadHashPrefixes() + + verify(maliciousSiteService).getPhishingHashPrefixes(latestRevision.first().revision) + verify(maliciousSiteDao).updateHashPrefixes(any()) + } + + @Test + fun loadHashPrefixes_doesNotUpdateHashPrefixesWhenNetworkRevisionIsNotHigher() = runTest { + val networkRevision = 1 + val latestRevision = listOf(RevisionEntity(PHISHING.name, Type.HASH_PREFIXES.name, 1)) + + whenever(maliciousSiteService.getRevision()).thenReturn(RevisionResponse(networkRevision)) + whenever(maliciousSiteDao.getLatestRevision()).thenReturn(latestRevision) + + repository.loadHashPrefixes() + + verify(maliciousSiteService, never()).getPhishingHashPrefixes(any()) + verify(maliciousSiteDao, never()).updateHashPrefixes(any()) + } + + @Test + fun containsHashPrefix_returnsTrueWhenHashPrefixExists() = runTest { + val hashPrefix = "testPrefix" + + whenever(maliciousSiteDao.getHashPrefix(hashPrefix)).thenReturn(HashPrefixEntity(hashPrefix, PHISHING.name)) + + val result = repository.containsHashPrefix(hashPrefix) + + assertTrue(result) + } + + @Test + fun getFilters_returnsFiltersWhenHashExists() = runTest { + val hash = "testHash" + val filters = listOf(FilterEntity(hash, "regex", Type.FILTER_SET.name)) + + whenever(maliciousSiteDao.getFilter(hash)).thenReturn(filters) + + val result = repository.getFilters(hash) + + assertEquals(filters.map { Filter(it.hash, it.regex) }, result) + } + + @Test + fun matches_returnsMatchesWhenHashPrefixExists() = runTest { + val hashPrefix = "testPrefix" + val matchesResponse = MatchesResponse(listOf(MatchResponse("hostname", "url", "regex", "hash"))) + + whenever(maliciousSiteService.getMatches(hashPrefix)).thenReturn(matchesResponse) + + val result = repository.matches(hashPrefix) + + assertEquals(matchesResponse.matches.map { Match(it.hostname, it.url, it.regex, it.hash) }, result) + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDaoTest.kt b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDaoTest.kt new file mode 100644 index 000000000000..cb16995890ca --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDaoTest.kt @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2025 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl.data.db + +import androidx.room.Room +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import com.duckduckgo.malicioussiteprotection.impl.models.Feed +import com.duckduckgo.malicioussiteprotection.impl.models.Filter +import com.duckduckgo.malicioussiteprotection.impl.models.FilterSetWithRevision.PhishingFilterSetWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.HashPrefixesWithRevision.PhishingHashPrefixesWithRevision +import com.duckduckgo.malicioussiteprotection.impl.models.Type +import kotlinx.coroutines.test.runTest +import org.junit.After +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertNull +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith + +@RunWith(AndroidJUnit4::class) +class MaliciousSiteDaoTest { + + private lateinit var database: MaliciousSitesDatabase + private lateinit var dao: MaliciousSiteDao + + @Before + fun setup() { + val context = InstrumentationRegistry.getInstrumentation().targetContext + database = Room.inMemoryDatabaseBuilder(context, MaliciousSitesDatabase::class.java).allowMainThreadQueries().build() + dao = database.maliciousSiteDao() + } + + @After + fun teardown() { + database.close() + } + + @Test + fun testUpdateFiltersWithReplaceDeletesOldEntries() = runTest { + dao.insertFilters(setOf(FilterEntity(hash = "hash1", regex = "regex1", type = Feed.PHISHING.name))) + val filterSet = PhishingFilterSetWithRevision( + revision = 1, + replace = true, + delete = setOf(), + insert = setOf(Filter(hash = "hash2", regex = "regex2")), + ) + + dao.updateFilters(filterSet) + + assertTrue(dao.getFilter("hash2")?.isEmpty() == false) + assertTrue(dao.getFilter("hash1")?.isEmpty() == true) + } + + @Test + fun testUpdateFiltersWithoutReplaceDoesNotDeleteOldEntries() = runTest { + dao.insertFilters(setOf(FilterEntity(hash = "hash1", regex = "regex1", type = Feed.PHISHING.name))) + val filterSet = PhishingFilterSetWithRevision( + revision = 1, + replace = false, + delete = setOf(), + insert = setOf(Filter(hash = "hash2", regex = "regex2")), + ) + + dao.updateFilters(filterSet) + + assertTrue(dao.getFilter("hash2")?.isEmpty() == false) + assertTrue(dao.getFilter("hash1")?.isEmpty() == false) + } + + @Test + fun testUpdateHashPrefixesWithReplaceDeletesOldEntries() = runTest { + dao.insertHashPrefixes(listOf(HashPrefixEntity(hashPrefix = "prefix1", type = Feed.PHISHING.name))) + val hashPrefixes = PhishingHashPrefixesWithRevision( + revision = 1, + replace = true, + delete = setOf(), + insert = setOf("prefix2"), + ) + + dao.updateHashPrefixes(hashPrefixes) + + assertNotNull(dao.getHashPrefix("prefix2")) + assertNull(dao.getHashPrefix("prefix1")) + } + + @Test + fun testUpdateHashPrefixesWithoutReplaceDoesNotDeleteOldEntries() = runTest { + dao.insertHashPrefixes(listOf(HashPrefixEntity(hashPrefix = "prefix1", type = Feed.PHISHING.name))) + val hashPrefixes = PhishingHashPrefixesWithRevision( + revision = 1, + replace = false, + delete = setOf(), + insert = setOf("prefix2"), + ) + + dao.updateHashPrefixes(hashPrefixes) + + assertNotNull(dao.getHashPrefix("prefix2")) + assertNotNull(dao.getHashPrefix("prefix1")) + } + + @Test + fun testFiltersNotUpdatedIfRevisionIsLowerThanCurrent() = runTest { + dao.insertFilters(setOf(FilterEntity(hash = "hash1", regex = "regex1", type = Feed.PHISHING.name))) + dao.insertRevision(RevisionEntity(feed = Feed.PHISHING.name, type = Type.FILTER_SET.name, revision = 2)) + val filterSet = PhishingFilterSetWithRevision( + revision = 1, + replace = true, + delete = setOf(), + insert = setOf(Filter(hash = "hash2", regex = "regex2")), + ) + + dao.updateFilters(filterSet) + + assertTrue(dao.getFilter("hash1")?.isEmpty() == false) + assertTrue(dao.getFilter("hash2")?.isEmpty() == true) + } + + @Test + fun testFiltersUpdatedIfRevisionIsHigherThanCurrent() = runTest { + dao.insertFilters(setOf(FilterEntity(hash = "hash1", regex = "regex1", type = Feed.PHISHING.name))) + dao.insertRevision(RevisionEntity(feed = Feed.PHISHING.name, type = Type.FILTER_SET.name, revision = 1)) + val filterSet = PhishingFilterSetWithRevision( + revision = 2, + replace = true, + delete = setOf(), + insert = setOf(Filter(hash = "hash2", regex = "regex2")), + ) + + dao.updateFilters(filterSet) + + assertTrue(dao.getFilter("hash1")?.isEmpty() == true) + assertTrue(dao.getFilter("hash2")?.isEmpty() == false) + } + + @Test + fun testHashPrefixesNotUpdatedIfRevisionIsLowerThanCurrent() = runTest { + dao.insertHashPrefixes(listOf(HashPrefixEntity(hashPrefix = "prefix1", type = Feed.PHISHING.name))) + dao.insertRevision(RevisionEntity(feed = Feed.PHISHING.name, type = Type.HASH_PREFIXES.name, revision = 2)) + val hashPrefixes = PhishingHashPrefixesWithRevision( + revision = 1, + replace = true, + delete = setOf(), + insert = setOf("prefix2"), + ) + + dao.updateHashPrefixes(hashPrefixes) + + assertNotNull(dao.getHashPrefix("prefix1")) + assertNull(dao.getHashPrefix("prefix2")) + } + + @Test + fun testHashPrefixesUpdatedIfRevisionIsHigherThanCurrent() = runTest { + dao.insertHashPrefixes(listOf(HashPrefixEntity(hashPrefix = "prefix1", type = Feed.PHISHING.name))) + dao.insertRevision(RevisionEntity(feed = Feed.PHISHING.name, type = Type.HASH_PREFIXES.name, revision = 1)) + val hashPrefixes = PhishingHashPrefixesWithRevision( + revision = 2, + replace = true, + delete = setOf(), + insert = setOf("prefix2"), + ) + + dao.updateHashPrefixes(hashPrefixes) + + assertNull(dao.getHashPrefix("prefix1")) + assertNotNull(dao.getHashPrefix("prefix2")) + } +} diff --git a/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtectionTest.kt b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtectionTest.kt new file mode 100644 index 000000000000..bfb205ce03ba --- /dev/null +++ b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtectionTest.kt @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 DuckDuckGo + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.duckduckgo.malicioussiteprotection.impl.domain + +import android.net.Uri +import androidx.test.ext.junit.runners.AndroidJUnit4 +import com.duckduckgo.common.test.CoroutineTestRule +import com.duckduckgo.malicioussiteprotection.api.MaliciousSiteProtection +import com.duckduckgo.malicioussiteprotection.impl.MaliciousSiteProtectionRCFeature +import com.duckduckgo.malicioussiteprotection.impl.data.MaliciousSiteRepository +import com.duckduckgo.malicioussiteprotection.impl.models.Filter +import com.duckduckgo.malicioussiteprotection.impl.models.Match +import java.security.MessageDigest +import kotlinx.coroutines.test.runTest +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.Mockito.mock +import org.mockito.kotlin.whenever + +@RunWith(AndroidJUnit4::class) +class RealMaliciousSiteProtectionTest { + + @get:Rule + var coroutinesTestRule = CoroutineTestRule() + + private lateinit var realMaliciousSiteProtection: RealMaliciousSiteProtection + private val maliciousSiteRepository: MaliciousSiteRepository = mock() + private val messageDigest: MessageDigest = MessageDigest.getInstance("SHA-256") + private val mockMaliciousSiteProtectionRCFeature: MaliciousSiteProtectionRCFeature = mock() + + @Before + fun setup() { + realMaliciousSiteProtection = RealMaliciousSiteProtection( + coroutinesTestRule.testDispatcherProvider, + coroutinesTestRule.testScope, + maliciousSiteRepository, + messageDigest, + mockMaliciousSiteProtectionRCFeature, + ) + whenever(mockMaliciousSiteProtectionRCFeature.isFeatureEnabled()).thenReturn(true) + } + + @Test + fun isMalicious_returnsSafe_whenUrlIsNotMalicious() = runTest { + val url = Uri.parse("https://example.com") + val hostname = url.host!! + val hash = messageDigest.digest(hostname.toByteArray()).joinToString("") { "%02x".format(it) } + val hashPrefix = hash.substring(0, 8) + + whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(false) + + val result = realMaliciousSiteProtection.isMalicious(url) {} + + assertEquals(MaliciousSiteProtection.IsMaliciousResult.SAFE, result) + } + + @Test + fun isMalicious_returnsMalicious_whenUrlIsMalicious() = runTest { + val url = Uri.parse("https://malicious.com") + val hostname = url.host!! + val hash = messageDigest.digest(hostname.toByteArray()).joinToString("") { "%02x".format(it) } + val hashPrefix = hash.substring(0, 8) + val filter = Filter(hash, ".*malicious.*") + + whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true) + whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter)) + + val result = realMaliciousSiteProtection.isMalicious(url) {} + + assertEquals(MaliciousSiteProtection.IsMaliciousResult.MALICIOUS, result) + } + + @Test + fun isMalicious_returnsSafe_whenUrlIsMaliciousButRCFeatureDisabled() = runTest { + val url = Uri.parse("https://malicious.com") + val hostname = url.host!! + val hash = messageDigest.digest(hostname.toByteArray()).joinToString("") { "%02x".format(it) } + val hashPrefix = hash.substring(0, 8) + val filter = Filter(hash, ".*malicious.*") + + whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true) + whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter)) + whenever(mockMaliciousSiteProtectionRCFeature.isFeatureEnabled()).thenReturn(false) + + val result = realMaliciousSiteProtection.isMalicious(url) {} + + assertEquals(MaliciousSiteProtection.IsMaliciousResult.SAFE, result) + } + + @Test + fun isMalicious_returnsWaitForConfirmation_whenUrlDoesNotMatchFilter() = runTest { + val url = Uri.parse("https://safe.com") + val hostname = url.host!! + val hash = messageDigest.digest(hostname.toByteArray()).joinToString("") { "%02x".format(it) } + val hashPrefix = hash.substring(0, 8) + val filter = Filter(hash, ".*unsafe.*") + + whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true) + whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter)) + + val result = realMaliciousSiteProtection.isMalicious(url) {} + + assertEquals(MaliciousSiteProtection.IsMaliciousResult.WAIT_FOR_CONFIRMATION, result) + } + + @Test + fun isMalicious_invokesOnSiteBlockedAsync_whenUrlIsMaliciousAndNeedsToGoToNetwork() = runTest { + val url = Uri.parse("https://malicious.com") + val hostname = url.host!! + val hash = messageDigest.digest(hostname.toByteArray()).joinToString("") { "%02x".format(it) } + val hashPrefix = hash.substring(0, 8) + val filter = Filter(hash, ".*whatever.*") + var onSiteBlockedAsyncCalled = false + + whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true) + whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter)) + whenever(maliciousSiteRepository.matches(hashPrefix.substring(0, 4))) + .thenReturn(listOf(Match(hostname, url.toString(), ".*malicious.*", hash))) + + realMaliciousSiteProtection.isMalicious(url) { + onSiteBlockedAsyncCalled = true + } + + assertTrue(onSiteBlockedAsyncCalled) + } +}