diff --git a/build.gradle b/build.gradle index 0ae62539fc..12ba2d94c2 100644 --- a/build.gradle +++ b/build.gradle @@ -64,7 +64,7 @@ plugins { id 'maven-publish' id 'com.diffplug.spotless' version '6.25.0' id 'checkstyle' - id 'com.netflix.nebula.ospackage' version "11.10.0" + id 'com.netflix.nebula.ospackage' version "11.10.1" id "org.gradle.test-retry" version "1.6.0" id 'eclipse' id "com.github.spotbugs" version "5.2.5" @@ -497,7 +497,7 @@ configurations { force "org.apache.httpcomponents:httpcore:4.4.16" force "com.google.errorprone:error_prone_annotations:2.36.0" force "org.checkerframework:checker-qual:3.48.3" - force "ch.qos.logback:logback-classic:1.5.12" + force "ch.qos.logback:logback-classic:1.5.15" force "commons-io:commons-io:2.18.0" } } @@ -585,7 +585,7 @@ dependencies { implementation 'commons-cli:commons-cli:1.9.0' implementation "org.bouncycastle:bcprov-jdk18on:${versions.bouncycastle}" implementation 'org.ldaptive:ldaptive:1.2.3' - implementation 'com.nimbusds:nimbus-jose-jwt:9.47' + implementation 'com.nimbusds:nimbus-jose-jwt:9.48' implementation 'com.rfksystems:blake2b:2.0.0' implementation 'com.password4j:password4j:1.8.2' @@ -641,7 +641,7 @@ dependencies { implementation "com.nulab-inc:zxcvbn:1.9.0" runtimeOnly 'com.google.guava:failureaccess:1.0.2' - runtimeOnly 'org.apache.commons:commons-text:1.12.0' + runtimeOnly 'org.apache.commons:commons-text:1.13.0' runtimeOnly "org.glassfish.jaxb:jaxb-runtime:${jaxb_version}" runtimeOnly 'com.google.j2objc:j2objc-annotations:2.8' compileOnly 'com.google.code.findbugs:jsr305:3.0.2' @@ -688,8 +688,8 @@ dependencies { testImplementation 'commons-validator:commons-validator:1.9.0' testImplementation 'org.springframework.kafka:spring-kafka-test:3.3.0' testImplementation "org.springframework:spring-beans:${spring_version}" - testImplementation 'org.junit.jupiter:junit-jupiter:5.11.3' - testImplementation 'org.junit.jupiter:junit-jupiter-api:5.11.3' + testImplementation 'org.junit.jupiter:junit-jupiter:5.11.4' + testImplementation 'org.junit.jupiter:junit-jupiter-api:5.11.4' testImplementation('org.awaitility:awaitility:4.2.2') { exclude(group: 'org.hamcrest', module: 'hamcrest') } @@ -749,7 +749,7 @@ dependencies { integrationTestImplementation "org.mockito:mockito-core:5.14.2" //spotless - implementation('com.google.googlejavaformat:google-java-format:1.25.0') { + implementation('com.google.googlejavaformat:google-java-format:1.25.2') { exclude group: 'com.google.guava' } } diff --git a/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java index b58e4afd35..de55334a99 100644 --- a/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java +++ b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java @@ -17,12 +17,11 @@ import java.net.SocketAddress; import java.util.Collection; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Pattern; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.opensearch.security.auth.UserInjector; @@ -57,7 +56,7 @@ public final class SafeSerializationUtils { LdapAttribute.class ); - private static final List> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of( + private static final Set> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableSet.of( InetAddress.class, Number.class, Collection.class, @@ -66,12 +65,23 @@ public final class SafeSerializationUtils { ); private static final Set SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues"); + static final Map, Boolean> safeClassCache = new ConcurrentHashMap<>(); static boolean isSafeClass(Class cls) { - return cls.isArray() - || SAFE_CLASSES.contains(cls) - || SAFE_CLASS_NAMES.contains(cls.getName()) - || SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls)); + return safeClassCache.computeIfAbsent(cls, SafeSerializationUtils::computeIsSafeClass); + } + + static boolean computeIsSafeClass(Class cls) { + return cls.isArray() || SAFE_CLASSES.contains(cls) || SAFE_CLASS_NAMES.contains(cls.getName()) || isAssignableFromSafeClass(cls); + } + + private static boolean isAssignableFromSafeClass(Class cls) { + for (Class safeClass : SAFE_ASSIGNABLE_FROM_CLASSES) { + if (safeClass.isAssignableFrom(cls)) { + return true; + } + } + return false; } static void prohibitUnsafeClasses(Class clazz) throws IOException { @@ -79,5 +89,4 @@ static void prohibitUnsafeClasses(Class clazz) throws IOException { throw new IOException("Unauthorized serialization attempt " + clazz.getName()); } } - } diff --git a/src/test/java/org/opensearch/security/support/SafeSerializationUtilsTest.java b/src/test/java/org/opensearch/security/support/SafeSerializationUtilsTest.java new file mode 100644 index 0000000000..f69d4e0291 --- /dev/null +++ b/src/test/java/org/opensearch/security/support/SafeSerializationUtilsTest.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.security.support; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.regex.Pattern; + +import org.junit.Test; + +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.User; + +import com.amazon.dlic.auth.ldap.LdapUser; +import org.ldaptive.AbstractLdapBean; +import org.ldaptive.LdapAttribute; +import org.ldaptive.LdapEntry; +import org.ldaptive.SearchEntry; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class SafeSerializationUtilsTest { + + @Test + public void testSafeClasses() { + assertTrue(SafeSerializationUtils.isSafeClass(String.class)); + assertTrue(SafeSerializationUtils.isSafeClass(InetSocketAddress.class)); + assertTrue(SafeSerializationUtils.isSafeClass(Pattern.class)); + assertTrue(SafeSerializationUtils.isSafeClass(User.class)); + assertTrue(SafeSerializationUtils.isSafeClass(UserInjector.InjectedUser.class)); + assertTrue(SafeSerializationUtils.isSafeClass(SourceFieldsContext.class)); + assertTrue(SafeSerializationUtils.isSafeClass(LdapUser.class)); + assertTrue(SafeSerializationUtils.isSafeClass(SearchEntry.class)); + assertTrue(SafeSerializationUtils.isSafeClass(LdapEntry.class)); + assertTrue(SafeSerializationUtils.isSafeClass(AbstractLdapBean.class)); + assertTrue(SafeSerializationUtils.isSafeClass(LdapAttribute.class)); + } + + @Test + public void testSafeAssignableClasses() { + assertTrue(SafeSerializationUtils.isSafeClass(InetAddress.class)); + assertTrue(SafeSerializationUtils.isSafeClass(Integer.class)); + assertTrue(SafeSerializationUtils.isSafeClass(ArrayList.class)); + assertTrue(SafeSerializationUtils.isSafeClass(HashMap.class)); + assertTrue(SafeSerializationUtils.isSafeClass(Enum.class)); + } + + @Test + public void testArraysAreSafe() { + assertTrue(SafeSerializationUtils.isSafeClass(String[].class)); + assertTrue(SafeSerializationUtils.isSafeClass(int[].class)); + assertTrue(SafeSerializationUtils.isSafeClass(Object[].class)); + } + + @Test + public void testUnsafeClasses() { + assertFalse(SafeSerializationUtils.isSafeClass(SafeSerializationUtilsTest.class)); + assertFalse(SafeSerializationUtils.isSafeClass(Runtime.class)); + } + + @Test + public void testProhibitUnsafeClasses() { + try { + SafeSerializationUtils.prohibitUnsafeClasses(String.class); + } catch (IOException e) { + fail("Should not throw exception for safe class"); + } + + try { + SafeSerializationUtils.prohibitUnsafeClasses(SafeSerializationUtilsTest.class); + fail("Should throw exception for unsafe class"); + } catch (IOException e) { + assertEquals("Unauthorized serialization attempt " + SafeSerializationUtilsTest.class.getName(), e.getMessage()); + } + } + + @Test + public void testInheritance() { + class CustomArrayList extends ArrayList {} + assertTrue(SafeSerializationUtils.isSafeClass(CustomArrayList.class)); + + class CustomMap extends HashMap {} + assertTrue(SafeSerializationUtils.isSafeClass(CustomMap.class)); + } + + @Test + public void testCaching() { + // First call should compute the result + boolean result1 = SafeSerializationUtils.isSafeClass(String.class); + assertTrue(result1); + + // Second call should use cached result + boolean result2 = SafeSerializationUtils.isSafeClass(String.class); + assertTrue(result2); + + // Verify that the cache was used (size should be 1) + assertEquals(1, SafeSerializationUtils.safeClassCache.size()); + + // Third call for a different class + boolean result3 = SafeSerializationUtils.isSafeClass(Integer.class); + assertTrue(result3); + // Verify that the cache was updated + assertEquals(2, SafeSerializationUtils.safeClassCache.size()); + } +}