diff --git a/src/main/java/com/ibm/as400/access/AS400JDBCDriver.java b/src/main/java/com/ibm/as400/access/AS400JDBCDriver.java index 454673cda..f76f4c652 100644 --- a/src/main/java/com/ibm/as400/access/AS400JDBCDriver.java +++ b/src/main/java/com/ibm/as400/access/AS400JDBCDriver.java @@ -1191,9 +1191,9 @@ else if (clearPassword == null) as400 = AS400.newInstance(secure, serverName, userName); else as400 = AS400.newInstance(secure, serverName, userName, clearPassword, additionalAuthenticationFactor); - Object sslSocketFactoryObject = jdProperties.getOriginalInfo().get(PROPERTY_SSL_SOCKET_FACTORY); - if ((sslSocketFactoryObject != null) && (sslSocketFactoryObject instanceof SSLSocketFactory)) { - as400.setSSLSocketFactory((SSLSocketFactory) sslSocketFactoryObject); + SSLSocketFactory sslSocketFactoryObject = jdProperties.getCustomSSLSocketFactory(); + if (null != sslSocketFactoryObject) { + as400.setSSLSocketFactory(sslSocketFactoryObject); } } catch (AS400SecurityException e) diff --git a/src/main/java/com/ibm/as400/access/JDProperties.java b/src/main/java/com/ibm/as400/access/JDProperties.java index 15cd719bf..6c8f711fd 100644 --- a/src/main/java/com/ibm/as400/access/JDProperties.java +++ b/src/main/java/com/ibm/as400/access/JDProperties.java @@ -14,12 +14,25 @@ package com.ibm.as400.access; import java.io.Serializable; +import java.net.InetAddress; +import java.net.Socket; +import java.net.UnknownHostException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.io.FileInputStream; import java.io.IOException; // @W2a import java.sql.DriverPropertyInfo; import java.util.Enumeration; import java.util.Properties; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; + /** @@ -181,9 +194,12 @@ public class JDProperties implements Serializable, Cloneable //@PDC 550 static final int ADDITIONAL_AUTHENTICATION_FACTOR=101; static final int STAY_ALIVE = 102; + static final int TLS_TRUSTSTORE_FILE = 103; + static final int TLS_TRUSTSTORE_FILE_PASS = 104; + // @W2 always add to the end of the array! - private static final int NUMBER_OF_ATTRIBUTES_ = 103; + private static final int NUMBER_OF_ATTRIBUTES_ = 105; // Property names. @@ -254,6 +270,8 @@ public class JDProperties implements Serializable, Cloneable //@PDC 550 private static final String TIME_FORMAT_ = "time format"; private static final String TIMESTAMP_FORMAT_ = "timestamp format"; private static final String TIME_SEPARATOR_ = "time separator"; + private static final String TLS_TRUSTSTORE_FILE_ = "tls truststore"; + private static final String TLS_TRUSTSTORE_FILE_PASS_ = "tls truststore password"; private static final String TRACE_ = "trace"; private static final String TRACE_SERVER_ = "server trace"; // @j1a private static final String TRACE_TOOLBOX_ = "toolbox trace"; // @K1A @@ -1652,8 +1670,22 @@ public class JDProperties implements Serializable, Cloneable //@PDC 550 dpi_[i].required = false; dpi_[i].choices = new String[0]; defaults_[i] = "0"; + + i = TLS_TRUSTSTORE_FILE; + dpi_[i] = new DriverPropertyInfo (TLS_TRUSTSTORE_FILE_, ""); + dpi_[i].description = "TLS_TRUSTSTORE_FILE"; + dpi_[i].required = false; + dpi_[i].choices = new String[0]; + defaults_[i] = EMPTY_; + i = TLS_TRUSTSTORE_FILE_PASS; + dpi_[i] = new DriverPropertyInfo (TLS_TRUSTSTORE_FILE_PASS_, ""); + dpi_[i].description = "TLS_TRUSTSTORE_FILE_PASS"; + dpi_[i].required = false; + dpi_[i].choices = new String[0]; + defaults_[i] = EMPTY_; + } @@ -2019,6 +2051,112 @@ String getString (int index) return value.trim(); } + /** + * Gets a custom SSL Socket Factory, or returns null if no custom SSL Socket factory is specified + * (in which case, the system default factory will be used). + * + * The custom SSL Socket Factory is determined as follows: + * + */ + SSLSocketFactory getCustomSSLSocketFactory() { + Properties originalProps = this.getOriginalInfo(); + Object sslSocketFactoryObject = null == originalProps ? null: originalProps.get(AS400JDBCDriver.PROPERTY_SSL_SOCKET_FACTORY); + if ((sslSocketFactoryObject != null) && (sslSocketFactoryObject instanceof SSLSocketFactory)) { + return (SSLSocketFactory) sslSocketFactoryObject; + } + final String truststoreFile = getString(TLS_TRUSTSTORE_FILE); + final String truststorePass = getString(TLS_TRUSTSTORE_FILE_PASS); + if (null != truststoreFile && null != truststorePass && !truststoreFile.isEmpty() + && !truststorePass.isEmpty()) { + return new SSLSocketFactory() { + private SSLSocketFactory sslSocketFactory_ = null; + + private synchronized SSLSocketFactory getSSLSocketFactory() throws IOException { + if (null != sslSocketFactory_) { + return sslSocketFactory_; + } + if ("*ANY".equalsIgnoreCase(truststoreFile) && "*ANY".equalsIgnoreCase(truststorePass)) { + try { + SSLContext ctx = SSLContext.getInstance("TLS"); + //@formatter:off + ctx.init(null, new TrustManager[] { + new X509TrustManager() { + @Override public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { } + @Override public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { } + @Override public X509Certificate[] getAcceptedIssuers() { return new X509Certificate[0]; } + } + }, null); + //@formatter:on + return sslSocketFactory_ = ctx.getSocketFactory(); + } catch (Exception e) { + throw e instanceof IOException ? (IOException) e : new IOException(e); + } + } + try (FileInputStream trustFile = new FileInputStream(truststoreFile)) { + KeyStore myTrustStore = KeyStore.getInstance("JKS"); + myTrustStore.load(trustFile, truststorePass.toCharArray()); + TrustManagerFactory trustManagerFactory = TrustManagerFactory + .getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(myTrustStore); + SSLContext ctx = SSLContext.getInstance("TLS"); + ctx.init(null, trustManagerFactory.getTrustManagers(), null); + return sslSocketFactory_ = ctx.getSocketFactory(); + } catch (Exception e) { + throw e instanceof IOException ? (IOException) e : new IOException(e); + } + } + + //@formatter:off + @Override + public String[] getDefaultCipherSuites() { + try { return getSSLSocketFactory().getDefaultCipherSuites();} catch (Exception e) { } + return ((SSLSocketFactory) SSLSocketFactory.getDefault()).getDefaultCipherSuites(); + } + + @Override + public String[] getSupportedCipherSuites() { + try { return getSSLSocketFactory().getSupportedCipherSuites(); } catch (Exception e) { } + return ((SSLSocketFactory) SSLSocketFactory.getDefault()).getSupportedCipherSuites(); + } + + @Override + public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException { + return getSSLSocketFactory().createSocket(s, host, port, autoClose); + } + + @Override + public Socket createSocket(String host, int port) throws IOException, UnknownHostException { + return getSSLSocketFactory().createSocket(host, port); + } + + @Override + public Socket createSocket(String host, int port, InetAddress localHost, int localPort) throws IOException, UnknownHostException { + return getSSLSocketFactory().createSocket(host, port, localHost, localPort); + } + + @Override + public Socket createSocket(InetAddress host, int port) throws IOException { + return getSSLSocketFactory().createSocket(host, port); + } + + @Override + public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort)throws IOException { + return getSSLSocketFactory().createSocket(address, port, localAddress, localPort); + } + //@formatter:on + + }; + } + return null; + } + /** * Get the clear password. The caller is responsible for clearing the array * after it is done with the password