diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java index fefa4a5e..3040e107 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java @@ -100,8 +100,12 @@ import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList; import javax.annotation.concurrent.GuardedBy; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import java.io.File; +import java.io.IOException; import java.io.InterruptedIOException; +import java.io.UncheckedIOException; import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; import java.lang.management.MemoryUsage; @@ -109,6 +113,7 @@ import java.lang.reflect.Method; import java.net.URI; import java.nio.channels.ClosedByInterruptException; +import java.security.NoSuchAlgorithmException; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.*; @@ -431,7 +436,7 @@ public abstract class ModelMesh extends ThriftService private PayloadProcessor initPayloadProcessor() { String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null); logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions); - if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) { + if (payloadProcessorsDefinitions != null && !payloadProcessorsDefinitions.isEmpty()) { List payloadProcessors = new ArrayList<>(); for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) { try { @@ -442,6 +447,14 @@ private PayloadProcessor initPayloadProcessor() { String method = uri.getFragment(); if ("http".equals(processorName)) { processor = new RemotePayloadProcessor(uri); + } else if ("https".equals(processorName)) { + SSLContext sslContext; + try { + sslContext = SSLContext.getDefault(); + } catch (NoSuchAlgorithmException missingAlgorithmException) { + throw new UncheckedIOException(new IOException(missingAlgorithmException)); + } + processor = new RemotePayloadProcessor(uri, sslContext, sslContext.getDefaultSSLParameters()); } else if ("logger".equals(processorName)) { processor = new LoggingPayloadProcessor(); } diff --git a/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java index 23c2fba1..c8a958a6 100644 --- a/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java +++ b/src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java @@ -31,6 +31,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; + /** * A {@link PayloadProcessor} that sends payloads to a remote service via HTTP POST. */ @@ -45,8 +48,19 @@ public class RemotePayloadProcessor implements PayloadProcessor { private final HttpClient client; public RemotePayloadProcessor(URI uri) { + this(uri, null, null); + } + + public RemotePayloadProcessor(URI uri, SSLContext sslContext, SSLParameters sslParameters) { this.uri = uri; - this.client = HttpClient.newHttpClient(); + if (sslContext != null && sslParameters != null) { + this.client = HttpClient.newBuilder() + .sslContext(sslContext) + .sslParameters(sslParameters) + .build(); + } else { + this.client = HttpClient.newHttpClient(); + } } @Override diff --git a/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java b/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java index ec08ea0a..cfa7e29e 100644 --- a/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java +++ b/src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java @@ -16,6 +16,7 @@ package com.ibm.watson.modelmesh.payload; +import java.io.IOException; import java.net.URI; import io.grpc.Metadata; @@ -29,17 +30,36 @@ class RemotePayloadProcessorTest { @Test - void testDestinationUnreachable() { - RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(URI.create("http://this-does-not-exist:123")); - String id = "123"; - String modelId = "456"; - String method = "predict"; - Status kind = Status.INVALID_ARGUMENT; - Metadata metadata = new Metadata(); - metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); - metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes()); - ByteBuf data = Unpooled.buffer(4); - Payload payload = new Payload(id, modelId, method, metadata, data, kind); - assertFalse(remotePayloadProcessor.process(payload)); + void testDestinationUnreachable() throws IOException { + URI uri = URI.create("http://this-does-not-exist:123"); + try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri)) { + String id = "123"; + String modelId = "456"; + String method = "predict"; + Status kind = Status.INVALID_ARGUMENT; + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); + metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes()); + ByteBuf data = Unpooled.buffer(4); + Payload payload = new Payload(id, modelId, method, metadata, data, kind); + assertFalse(remotePayloadProcessor.process(payload)); + } + } + + @Test + void testDestinationUnreachableHTTPS() throws IOException { + URI uri = URI.create("https://this-does-not-exist:123"); + try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri)) { + String id = "123"; + String modelId = "456"; + String method = "predict"; + Status kind = Status.INVALID_ARGUMENT; + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); + metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes()); + ByteBuf data = Unpooled.buffer(4); + Payload payload = new Payload(id, modelId, method, metadata, data, kind); + assertFalse(remotePayloadProcessor.process(payload)); + } } } \ No newline at end of file