Skip to content

Commit

Permalink
Add tests on TLS support utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
dj-mal committed Nov 7, 2017
1 parent 557545a commit 99f2be0
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.auth0.android.util;

import android.app.Activity;

import com.squareup.okhttp.ConnectionSpec;
import com.squareup.okhttp.OkHttpClient;
import com.squareup.okhttp.TlsVersion;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.robolectric.Robolectric;
import org.robolectric.RobolectricTestRunner;
import org.robolectric.annotation.Config;

import java.util.List;

import javax.net.ssl.SSLSocketFactory;

import static junit.framework.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;

@RunWith(RobolectricTestRunner.class)
@Config(constants = com.auth0.android.auth0.BuildConfig.class, sdk = 21, manifest = Config.NONE)
public class OkHttpTls12CompatTest {

Activity activity;
@Mock OkHttpClient client;

@Before
public void setUp(){
MockitoAnnotations.initMocks(this);
activity = Robolectric.setupActivity(Activity.class);
}

@Test
@Config(sdk=22)
public void shouldNotConfigTlsPostApi21() {
OkHttpTls12Compat.enableSupportOnPreLollipop(client);
verify(client, never()).setSslSocketFactory((SSLSocketFactory) any());
}

@Test
public void shouldConfigTlsOnOrPreApi21() {
OkHttpTls12Compat.enableSupportOnPreLollipop(client);

ArgumentCaptor<SSLSocketFactory> factoryCaptor = ArgumentCaptor.forClass(SSLSocketFactory.class);
verify(client).setSslSocketFactory(factoryCaptor.capture());
assertTrue(factoryCaptor.getValue() instanceof Tls12SocketFactory);

ArgumentCaptor<List> specCaptor = ArgumentCaptor.forClass(List.class);
verify(client).setConnectionSpecs(specCaptor.capture());
boolean hasTls12 = false;
for (Object item : specCaptor.getValue()) {
assertTrue(item instanceof ConnectionSpec);
ConnectionSpec spec = (ConnectionSpec) item;
if (!spec.isTls()) {
continue;
}
List<TlsVersion> versions = spec.tlsVersions();
for (TlsVersion version : versions) {
if ("TLSv1.2".equals(version.javaName())) {
hasTls12 = true;
break;
}
}
}
assertTrue(hasTls12);
}
}
152 changes: 152 additions & 0 deletions auth0/src/test/java/com/auth0/android/util/Tls12SocketFactoryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package com.auth0.android.util;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.util.Arrays;

import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertTrue;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class Tls12SocketFactoryTest {

private static final String TLS_1_2 = "TLSv1.2";
private static final String MOCK_HOST = "www.example.com";
private static final int MOCK_PORT = 8080;
private static final int MOCK_LOCAL_PORT = 8081;
private static final boolean MOCK_AUTO_CLOSE = true;

@Mock SSLSocket socket;
@Mock SSLSocketFactory delegate;
Tls12SocketFactory factory;

@Rule
public ExpectedException expectedException = ExpectedException.none();

@Before
public void setUp(){
MockitoAnnotations.initMocks(this);
factory = new Tls12SocketFactory(delegate);
}

@Test
public void shouldGetDefaultCipherSuites() {
String[] suites = new String[]{"Test"};
when(delegate.getDefaultCipherSuites()).thenReturn(suites);

String[] result = factory.getDefaultCipherSuites();

verify(delegate).getDefaultCipherSuites();
assertTrue(Arrays.equals(result, suites));
}

@Test
public void shouldGetSupportedCipherSuites() {
String[] suites = new String[]{"Test"};
when(delegate.getSupportedCipherSuites()).thenReturn(suites);

String[] result = factory.getSupportedCipherSuites();

verify(delegate).getSupportedCipherSuites();
assertTrue(Arrays.equals(result, suites));
}

@Test
public void shouldCreateSocket_socket_host_port_autoClose() throws IOException {
when(delegate.createSocket((Socket) anyObject(), anyString(), anyInt(), anyBoolean()))
.thenReturn(socket);

Socket result = factory.createSocket(socket, MOCK_HOST, MOCK_PORT, MOCK_AUTO_CLOSE);

assertEquals(result, socket);
verify(delegate).createSocket(eq(socket), eq(MOCK_HOST), eq(MOCK_PORT), eq(MOCK_AUTO_CLOSE));
verifyPatched(result);
}

@Test
public void shouldCreateSocket_host_port() throws IOException {
when(delegate.createSocket(anyString(), anyInt()))
.thenReturn(socket);

Socket result = factory.createSocket(MOCK_HOST, MOCK_PORT);

assertEquals(result, socket);
verify(delegate).createSocket(eq(MOCK_HOST), eq(MOCK_PORT));
verifyPatched(result);
}

@Test
public void shouldCreateSocket_host_port_localHost_localPort() throws IOException {
InetAddress localHost = mock(InetAddress.class);
when(delegate.createSocket(anyString(), anyInt(), (InetAddress) anyObject(), anyInt()))
.thenReturn(socket);

Socket result = factory.createSocket(MOCK_HOST, MOCK_PORT, localHost, MOCK_LOCAL_PORT);

assertEquals(result, socket);
verify(delegate).createSocket(eq(MOCK_HOST), eq(MOCK_PORT), eq(localHost), eq(MOCK_LOCAL_PORT));
verifyPatched(result);
}

@Test
public void shouldCreateSocket_hostAddress_port() throws IOException {
InetAddress host = mock(InetAddress.class);
when(delegate.createSocket((InetAddress) anyObject(), anyInt()))
.thenReturn(socket);

Socket result = factory.createSocket(host, MOCK_PORT);

assertEquals(result, socket);
verify(delegate).createSocket(eq(host), eq(MOCK_PORT));
verifyPatched(result);
}

@Test
public void shouldCreateSocket_address_port_localAddress_localPort() throws IOException {
InetAddress address = mock(InetAddress.class);
InetAddress localAddress = mock(InetAddress.class);
when(delegate.createSocket((InetAddress) anyObject(), anyInt(), (InetAddress) anyObject(), anyInt()))
.thenReturn(socket);

Socket result = factory.createSocket(address, MOCK_PORT, localAddress, MOCK_LOCAL_PORT);

assertEquals(result, socket);
verify(delegate).createSocket(eq(address), eq(MOCK_PORT), eq(localAddress), eq(MOCK_LOCAL_PORT));
verifyPatched(result);
}


private static void verifyPatched(Socket socket) {
ArgumentCaptor<String[]> captor = ArgumentCaptor.forClass(String[].class);
assertTrue(socket instanceof SSLSocket);
verify((SSLSocket)socket).setEnabledProtocols(captor.capture());
String[] protocols = captor.getValue();
boolean patched = false;
for (String string : protocols) {
if (TLS_1_2.equals(string)) {
patched = true;
break;
}
}
assertTrue(patched);
}
}

0 comments on commit 99f2be0

Please sign in to comment.