diff --git a/web/src/main/java/org/apache/shiro/web/filter/authz/SslFilter.java b/web/src/main/java/org/apache/shiro/web/filter/authz/SslFilter.java index a5e9dde3a8..73048b0b43 100644 --- a/web/src/main/java/org/apache/shiro/web/filter/authz/SslFilter.java +++ b/web/src/main/java/org/apache/shiro/web/filter/authz/SslFilter.java @@ -112,11 +112,11 @@ protected boolean isAccessAllowed(ServletRequest request, ServletResponse respon */ @Override protected void postHandle(ServletRequest request, ServletResponse response) { - if (hsts.enabled) { + if (hsts.isEnabled()) { StringBuilder directives = new StringBuilder(64) .append("max-age=").append(hsts.getMaxAge()); - if (hsts.includeSubDomains) { + if (hsts.isIncludeSubDomains()) { directives.append("; includeSubDomains"); } @@ -130,17 +130,18 @@ protected void postHandle(ServletRequest request, ServletResponse response) { */ public class HSTS { + public static final String HTTP_HEADER = "Strict-Transport-Security"; + public static final boolean DEFAULT_ENABLED = false; public static final int DEFAULT_MAX_AGE = 31536000; // approx. one year in seconds public static final boolean DEFAULT_INCLUDE_SUB_DOMAINS = false; - public static final String HTTP_HEADER = "Strict-Transport-Security"; - private boolean enabled; private int maxAge; private boolean includeSubDomains; public HSTS() { + this.enabled = DEFAULT_ENABLED; this.maxAge = DEFAULT_MAX_AGE; this.includeSubDomains = DEFAULT_INCLUDE_SUB_DOMAINS; } diff --git a/web/src/test/java/org/apache/shiro/web/filter/authz/SslFilterTest.java b/web/src/test/java/org/apache/shiro/web/filter/authz/SslFilterTest.java index 413632937f..2e1fe2f574 100644 --- a/web/src/test/java/org/apache/shiro/web/filter/authz/SslFilterTest.java +++ b/web/src/test/java/org/apache/shiro/web/filter/authz/SslFilterTest.java @@ -18,49 +18,85 @@ */ package org.apache.shiro.web.filter.authz; +import java.util.HashMap; +import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.junit.Test; import static org.apache.shiro.web.filter.authz.SslFilter.HSTS.*; +import org.easymock.Capture; +import org.easymock.CaptureType; import static org.easymock.EasyMock.*; +import org.easymock.IAnswer; import static org.junit.Assert.*; +import org.junit.Before; public class SslFilterTest { + + private HttpServletRequest request; + private HttpServletResponse response; + private SslFilter sslFilter; + + @Before + public void before() { + request = createNiceMock(HttpServletRequest.class); + response = createNiceMock(HttpServletResponse.class); + sslFilter = new SslFilter(); + + final Map headers = new HashMap(); + + final Capture capturedName = newCapture(); + final Capture capturedValue = newCapture(); + + // mock HttpServletResponse.getHeader + expect(response.getHeader(capture(capturedName))).andAnswer(new IAnswer() { + @Override + public String answer() throws Throwable { + String name = capturedName.getValue(); + return headers.get(name); + } + + }); + + // mock HttpServletResponse.addHeader + response.addHeader(capture(capturedName), capture(capturedValue)); + expectLastCall().andAnswer(new IAnswer() { + @Override + public Void answer() throws Throwable { + String name = capturedName.getValue(); + String value = capturedValue.getValue(); + headers.put(name, value); + return (null); + } + }); + + replay(response); + } @Test public void testDisabledByDefault() { - HttpServletRequest request = createNiceMock(HttpServletRequest.class); - HttpServletResponse response = createNiceMock(HttpServletResponse.class); - - SslFilter sslFilter = new SslFilter(); - sslFilter.postHandle(request, response); assertNull(response.getHeader(HTTP_HEADER)); } @Test public void testDefaultValues() { - HttpServletRequest request = createNiceMock(HttpServletRequest.class); - HttpServletResponse response = createNiceMock(HttpServletResponse.class); - -// String expected = new StringBuilder() -// .append(HTTP_HEADER) -// .append(": ") -// .append("max-age=") -// .append(DEFAULT_MAX_AGE) -// .toString(); -// expect(response.addHeader(expected, expected)) -// .andReturn(expected) -// .anyTimes(); - replay(response); -// - SslFilter sslFilter = new SslFilter(); sslFilter.getHsts().setEnabled(true); - sslFilter.postHandle(request, response); - - //assertEquals(expected, response.getHeader(HTTP_HEADER)); + assertEquals("max-age=" + DEFAULT_MAX_AGE, response.getHeader(HTTP_HEADER)); } + + @Test + public void testSetProperties() { + sslFilter.getHsts().setEnabled(true); + sslFilter.getHsts().setMaxAge(7776000); + sslFilter.getHsts().setIncludeSubDomains(true); + sslFilter.postHandle(request, response); + + String expected = "max-age=" + 7776000 + "; includeSubDomains"; + assertEquals(expected, response.getHeader(HTTP_HEADER)); + } + }