Skip to content

Commit

Permalink
Token based authentication integration with core extension (#4011)
Browse files Browse the repository at this point in the history
* tba draft

* - stop authxmanager on pool close
- swith to long dates

* drop use of authxmanager and authenticatedconnection from core

* -update submodule ref
-change exception message

* - remove submodule

- update dependency

* back to current version

* - move autxhmanager creation to user space
- introduce authenticationeventlisteners
- clenaup in connectionpool
- add entraidtestcontext
- add redisintegrationtests
- fix failing tokenbasedauthentication unit&integ tests

* - prevent use of pubsub with TBA+RESP2 combination
- fix flaky test

* - support tba with clusters
- add cluster+tba tests

* - remove onerror from authxmanager
- fix flaky tests

* - fix flaky test

* fix renewalDuringOperationsTest

* -reviews from @sazzad16

* - fix config for managedIdentity
- set audiences with scopes
- managed identity tests

* review from @ggivo
- use getuser instead oid from Token

* handle and propogate from unsuccessful AUTH response

* adding reauth support for both pubsub and shardedpubsub

* fix ping issue with pubsub

* - review from @sazzad16 : make JedisSafeAuthenticator protected
- fix failing unit tests

* update authx version

* - remove workaround for standalone endpoint
  • Loading branch information
atakavci authored Dec 20, 2024
1 parent 90583d0 commit d4a569c
Show file tree
Hide file tree
Showing 21 changed files with 1,902 additions and 86 deletions.
13 changes: 13 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@
<version>2.11.0</version>
</dependency>

<dependency>
<groupId>redis.clients.authentication</groupId>
<artifactId>redis-authx-core</artifactId>
<version>0.1.1-beta1</version>
</dependency>

<!-- Optional dependencies -->

<!-- UNIX socket connection support -->
Expand Down Expand Up @@ -150,6 +156,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>redis.clients.authentication</groupId>
<artifactId>redis-authx-entraid</artifactId>
<version>0.1.1-beta1</version>
<scope>test</scope>
</dependency>

<!-- circuit breaker / failover -->
<dependency>
<groupId>io.github.resilience4j</groupId>
Expand Down
51 changes: 40 additions & 11 deletions src/main/java/redis/clients/jedis/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.concurrent.atomic.AtomicReference;

import redis.clients.jedis.Protocol.Command;
import redis.clients.jedis.Protocol.Keyword;
import redis.clients.jedis.annots.Experimental;
import redis.clients.jedis.args.ClientAttributeOption;
import redis.clients.jedis.args.Rawable;
import redis.clients.jedis.authentication.AuthXManager;
import redis.clients.jedis.commands.ProtocolCommand;
import redis.clients.jedis.exceptions.JedisConnectionException;
import redis.clients.jedis.exceptions.JedisDataException;
Expand All @@ -44,6 +46,8 @@ public class Connection implements Closeable {
private String strVal;
protected String server;
protected String version;
private AtomicReference<RedisCredentials> currentCredentials = new AtomicReference<>(null);
private AuthXManager authXManager;

public Connection() {
this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT);
Expand All @@ -63,6 +67,7 @@ public Connection(final HostAndPort hostAndPort, final JedisClientConfig clientC

public Connection(final JedisSocketFactory socketFactory) {
this.socketFactory = socketFactory;
this.authXManager = null;
}

public Connection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig) {
Expand Down Expand Up @@ -93,8 +98,8 @@ public String toIdentityString() {
SocketAddress remoteAddr = socket.getRemoteSocketAddress();
SocketAddress localAddr = socket.getLocalSocketAddress();
if (remoteAddr != null) {
strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id,
localAddr, (broken ? '!' : '-'), remoteAddr);
strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id, localAddr,
(broken ? '!' : '-'), remoteAddr);
} else if (localAddr != null) {
strVal = String.format("%s{id: 0x%X, L:%s}", className, id, localAddr);
} else {
Expand Down Expand Up @@ -438,8 +443,8 @@ private static boolean validateClientInfo(String info) {
for (int i = 0; i < info.length(); i++) {
char c = info.charAt(i);
if (c < '!' || c > '~') {
throw new JedisValidationException("client info cannot contain spaces, "
+ "newlines or special characters.");
throw new JedisValidationException(
"client info cannot contain spaces, " + "newlines or special characters.");
}
}
return true;
Expand All @@ -451,7 +456,13 @@ protected void initializeFromClientConfig(final JedisClientConfig config) {

protocol = config.getRedisProtocol();

final Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();
Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();

authXManager = config.getAuthXManager();
if (authXManager != null) {
credentialsProvider = authXManager;
}

if (credentialsProvider instanceof RedisCredentialsProvider) {
final RedisCredentialsProvider redisCredentialsProvider = (RedisCredentialsProvider) credentialsProvider;
try {
Expand All @@ -469,7 +480,8 @@ protected void initializeFromClientConfig(final JedisClientConfig config) {

String clientName = config.getClientName();
if (clientName != null && validateClientInfo(clientName)) {
fireAndForgetMsg.add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName));
fireAndForgetMsg
.add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName));
}

ClientSetInfoConfig setInfoConfig = config.getClientSetInfoConfig();
Expand Down Expand Up @@ -525,12 +537,13 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c
if (protocol != null && credentials != null && credentials.getUser() != null) {
byte[] rawPass = encodeToBytes(credentials.getPassword());
try {
helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(), encode(credentials.getUser()), rawPass);
helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(),
encode(credentials.getUser()), rawPass);
} finally {
Arrays.fill(rawPass, (byte) 0); // clear sensitive data
}
} else {
auth(credentials);
authenticate(credentials);
helloResult = protocol == null ? null : hello(encode(protocol.version()));
}
if (helloResult != null) {
Expand All @@ -542,9 +555,13 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c
// handled in RedisCredentialsProvider.cleanUp()
}

private void auth(RedisCredentials credentials) {
public void setCredentials(RedisCredentials credentials) {
currentCredentials.set(credentials);
}

private String authenticate(RedisCredentials credentials) {
if (credentials == null || credentials.getPassword() == null) {
return;
return null;
}
byte[] rawPass = encodeToBytes(credentials.getPassword());
try {
Expand All @@ -556,7 +573,11 @@ private void auth(RedisCredentials credentials) {
} finally {
Arrays.fill(rawPass, (byte) 0); // clear sensitive data
}
getStatusCodeReply();
return getStatusCodeReply();
}

public String reAuthenticate() {
return authenticate(currentCredentials.getAndSet(null));
}

protected Map<String, Object> hello(byte[]... args) {
Expand Down Expand Up @@ -585,4 +606,12 @@ public boolean ping() {
}
return true;
}

protected boolean isTokenBasedAuthenticationEnabled() {
return authXManager != null;
}

protected AuthXManager getAuthXManager() {
return authXManager;
}
}
79 changes: 65 additions & 14 deletions src/main/java/redis/clients/jedis/ConnectionFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.function.Supplier;

import redis.clients.jedis.annots.Experimental;
import redis.clients.jedis.authentication.AuthXManager;
import redis.clients.jedis.authentication.JedisAuthenticationException;
import redis.clients.jedis.authentication.AuthXEventListener;
import redis.clients.jedis.csc.Cache;
import redis.clients.jedis.csc.CacheConnection;
import redis.clients.jedis.exceptions.JedisException;
Expand All @@ -20,28 +25,52 @@ public class ConnectionFactory implements PooledObjectFactory<Connection> {

private final JedisSocketFactory jedisSocketFactory;
private final JedisClientConfig clientConfig;
private Cache clientSideCache = null;
private final Cache clientSideCache;
private final Supplier<Connection> objectMaker;

private final AuthXEventListener authXEventListener;

public ConnectionFactory(final HostAndPort hostAndPort) {
this.clientConfig = DefaultJedisClientConfig.builder().build();
this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort);
this(hostAndPort, DefaultJedisClientConfig.builder().build(), null);
}

public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig) {
this.clientConfig = clientConfig;
this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig);
this(hostAndPort, clientConfig, null);
}

@Experimental
public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, Cache csCache) {
this.clientConfig = clientConfig;
this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig);
this.clientSideCache = csCache;
public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig,
Cache csCache) {
this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache);
}

public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, final JedisClientConfig clientConfig) {
this.clientConfig = clientConfig;
public ConnectionFactory(final JedisSocketFactory jedisSocketFactory,
final JedisClientConfig clientConfig) {
this(jedisSocketFactory, clientConfig, null);
}

private ConnectionFactory(final JedisSocketFactory jedisSocketFactory,
final JedisClientConfig clientConfig, Cache csCache) {

this.jedisSocketFactory = jedisSocketFactory;
this.clientSideCache = csCache;
this.clientConfig = clientConfig;

AuthXManager authXManager = clientConfig.getAuthXManager();
if (authXManager == null) {
this.objectMaker = connectionSupplier();
this.authXEventListener = AuthXEventListener.NOOP_LISTENER;
} else {
Supplier<Connection> supplier = connectionSupplier();
this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get());
this.authXEventListener = authXManager.getListener();
authXManager.start();
}
}

private Supplier<Connection> connectionSupplier() {
return clientSideCache == null ? () -> new Connection(jedisSocketFactory, clientConfig)
: () -> new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache);
}

@Override
Expand All @@ -64,8 +93,7 @@ public void destroyObject(PooledObject<Connection> pooledConnection) throws Exce
@Override
public PooledObject<Connection> makeObject() throws Exception {
try {
Connection jedis = clientSideCache == null ? new Connection(jedisSocketFactory, clientConfig)
: new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache);
Connection jedis = objectMaker.get();
return new DefaultPooledObject<>(jedis);
} catch (JedisException je) {
logger.debug("Error while makeObject", je);
Expand All @@ -76,17 +104,40 @@ public PooledObject<Connection> makeObject() throws Exception {
@Override
public void passivateObject(PooledObject<Connection> pooledConnection) throws Exception {
// TODO maybe should select db 0? Not sure right now.
Connection jedis = pooledConnection.getObject();
reAuthenticate(jedis);
}

@Override
public boolean validateObject(PooledObject<Connection> pooledConnection) {
final Connection jedis = pooledConnection.getObject();
try {
// check HostAndPort ??
return jedis.isConnected() && jedis.ping();
if (!jedis.isConnected()) {
return false;
}
reAuthenticate(jedis);
return jedis.ping();
} catch (final Exception e) {
logger.warn("Error while validating pooled Connection object.", e);
return false;
}
}

private void reAuthenticate(Connection jedis) throws Exception {
try {
String result = jedis.reAuthenticate();
if (result != null && !result.equals("OK")) {
String msg = "Re-authentication failed with server response: " + result;
Exception failedAuth = new JedisAuthenticationException(msg);
logger.error(failedAuth.getMessage(), failedAuth);
authXEventListener.onConnectionAuthenticationError(failedAuth);
return;
}
} catch (Exception e) {
logger.error("Error while re-authenticating connection", e);
authXEventListener.onConnectionAuthenticationError(e);
throw e;
}
}
}
41 changes: 38 additions & 3 deletions src/main/java/redis/clients/jedis/ConnectionPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,27 @@

import org.apache.commons.pool2.PooledObjectFactory;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;

import redis.clients.jedis.annots.Experimental;
import redis.clients.jedis.authentication.AuthXManager;
import redis.clients.jedis.csc.Cache;
import redis.clients.jedis.exceptions.JedisException;
import redis.clients.jedis.util.Pool;

public class ConnectionPool extends Pool<Connection> {

private AuthXManager authXManager;

public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) {
this(new ConnectionFactory(hostAndPort, clientConfig));
attachAuthenticationListener(clientConfig.getAuthXManager());
}

@Experimental
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache) {
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
Cache clientSideCache) {
this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache));
attachAuthenticationListener(clientConfig.getAuthXManager());
}

public ConnectionPool(PooledObjectFactory<Connection> factory) {
Expand All @@ -24,12 +32,14 @@ public ConnectionPool(PooledObjectFactory<Connection> factory) {
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
GenericObjectPoolConfig<Connection> poolConfig) {
this(new ConnectionFactory(hostAndPort, clientConfig), poolConfig);
attachAuthenticationListener(clientConfig.getAuthXManager());
}

@Experimental
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache,
GenericObjectPoolConfig<Connection> poolConfig) {
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
Cache clientSideCache, GenericObjectPoolConfig<Connection> poolConfig) {
this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache), poolConfig);
attachAuthenticationListener(clientConfig.getAuthXManager());
}

public ConnectionPool(PooledObjectFactory<Connection> factory,
Expand All @@ -43,4 +53,29 @@ public Connection getResource() {
conn.setHandlingPool(this);
return conn;
}

@Override
public void close() {
try {
if (authXManager != null) {
authXManager.stop();
}
} finally {
super.close();
}
}

private void attachAuthenticationListener(AuthXManager authXManager) {
this.authXManager = authXManager;
if (authXManager != null) {
authXManager.addPostAuthenticationHook(token -> {
try {
// this is to trigger validations on each connection via ConnectionFactory
evict();
} catch (Exception e) {
throw new JedisException("Failed to evict connections from pool", e);
}
});
}
}
}
Loading

0 comments on commit d4a569c

Please sign in to comment.