Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Verify sns message signature #198

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.sns.AmazonSNS;
import com.amazonaws.services.sns.AmazonSNSClient;
import com.amazonaws.services.sns.message.SnsMessageManager;
import io.awspring.cloud.context.annotation.ConditionalOnMissingAmazonClient;
import io.awspring.cloud.core.config.AmazonWebserviceClientFactoryBean;
import io.awspring.cloud.core.region.RegionProvider;
Expand All @@ -49,6 +50,7 @@
* @author Alain Sahli
* @author Eddú Meléndez
* @author Maciej Walkowiak
* @author Manuel Wessner
*/
@Configuration(proxyBeanMethods = false)
@ConditionalOnClass(AmazonSNS.class)
Expand Down Expand Up @@ -81,6 +83,12 @@ public AmazonWebserviceClientFactoryBean<AmazonSNSClient> amazonSNS(SnsPropertie
return clientFactoryBean;
}

@ConditionalOnMissingAmazonClient(SnsMessageManager.class)
@Bean
public SnsMessageManager snsMessageManager() {
return new SnsMessageManager(this.regionProvider.getRegion().getName());
}

@Configuration(proxyBeanMethods = false)
@ConditionalOnClass(WebMvcConfigurer.class)
static class SnsWebConfiguration {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.sns.message.SnsMessageManager;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSAsync;
import com.amazonaws.services.sqs.AmazonSQSAsyncClient;
Expand Down Expand Up @@ -57,6 +58,7 @@
*
* @author Maciej Walkowiak
* @author Eddú Meléndez
* @author Manuel Wessner
*/
@ConditionalOnClass(SimpleMessageListenerContainer.class)
@ConditionalOnMissingBean(SimpleMessageListenerContainer.class)
Expand Down Expand Up @@ -116,11 +118,14 @@ static class SqsConfiguration {

private final ObjectMapper objectMapper;

private final SnsMessageManager snsMessageManager;

SqsConfiguration(ObjectProvider<SimpleMessageListenerContainerFactory> simpleMessageListenerContainerFactory,
ObjectProvider<QueueMessageHandlerFactory> queueMessageHandlerFactory, BeanFactory beanFactory,
ObjectProvider<ResourceIdResolver> resourceIdResolver,
ObjectProvider<MappingJackson2MessageConverter> mappingJackson2MessageConverter,
ObjectProvider<ObjectMapper> objectMapper, SqsProperties sqsProperties) {
ObjectProvider<ObjectMapper> objectMapper, SqsProperties sqsProperties,
SnsMessageManager snsMessageManager) {
this.simpleMessageListenerContainerFactory = simpleMessageListenerContainerFactory
.getIfAvailable(() -> createSimpleMessageListenerContainerFactory(sqsProperties));
this.queueMessageHandlerFactory = queueMessageHandlerFactory
Expand All @@ -129,6 +134,7 @@ static class SqsConfiguration {
this.resourceIdResolver = resourceIdResolver.getIfAvailable();
this.mappingJackson2MessageConverter = mappingJackson2MessageConverter.getIfAvailable();
this.objectMapper = objectMapper.getIfAvailable();
this.snsMessageManager = snsMessageManager;
}

private static QueueMessageHandlerFactory createQueueMessageHandlerFactory(SqsProperties sqsProperties) {
Expand Down Expand Up @@ -198,6 +204,7 @@ private QueueMessageHandler getMessageHandler(AmazonSQSAsync amazonSqs) {

this.queueMessageHandlerFactory.setBeanFactory(this.beanFactory);
this.queueMessageHandlerFactory.setObjectMapper(this.objectMapper);
this.queueMessageHandlerFactory.setSnsMessageManager(this.snsMessageManager);

return this.queueMessageHandlerFactory.createQueueMessageHandler();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.sns.message.SnsMessageManager;
import com.amazonaws.services.sqs.AmazonSQSAsync;
import com.amazonaws.services.sqs.AmazonSQSAsyncClient;
import com.amazonaws.services.sqs.buffered.AmazonSQSBufferedAsyncClient;
Expand Down Expand Up @@ -448,7 +449,7 @@ static class ConfigurationWithCustomContainerFactory {
static final long BACK_OFF_TIME = 5000;

static {
QueueMessageHandler queueMessageHandler = new QueueMessageHandler();
QueueMessageHandler queueMessageHandler = new QueueMessageHandler(new SnsMessageManager("eu-central-1"));
queueMessageHandler.setApplicationContext(new StaticApplicationContext());
MESSAGE_HANDLER = queueMessageHandler;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Arrays;
import java.util.List;

import com.amazonaws.services.sns.message.SnsMessageManager;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSAsync;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -62,6 +63,8 @@ public class QueueMessageHandlerFactory {

private ObjectMapper objectMapper;

private SnsMessageManager snsMessageManager;

public void setArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
this.argumentResolvers = argumentResolvers;
}
Expand Down Expand Up @@ -147,12 +150,20 @@ public void setObjectMapper(ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
}

/**
* Configures an {@link SnsMessageManager} that is used by default.
* @param snsMessageManager - sns message manager, shouldn't be null
*/
public void setSnsMessageManager(SnsMessageManager snsMessageManager) {
this.snsMessageManager = snsMessageManager;
}

public QueueMessageHandler createQueueMessageHandler() {
QueueMessageHandler queueMessageHandler = new QueueMessageHandler(
CollectionUtils.isEmpty(this.messageConverters)
? Arrays.asList(getDefaultMappingJackson2MessageConverter(this.objectMapper))
: this.messageConverters,
this.sqsMessageDeletionPolicy);
this.sqsMessageDeletionPolicy, this.snsMessageManager);

if (!CollectionUtils.isEmpty(this.argumentResolvers)) {
queueMessageHandler.getCustomArgumentResolvers().addAll(this.argumentResolvers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Set;
import java.util.stream.Collectors;

import com.amazonaws.services.sns.message.SnsMessageManager;
import io.awspring.cloud.messaging.listener.annotation.SqsListener;
import io.awspring.cloud.messaging.listener.support.AcknowledgmentHandlerMethodArgumentResolver;
import io.awspring.cloud.messaging.listener.support.VisibilityHandlerMethodArgumentResolver;
Expand Down Expand Up @@ -67,6 +68,7 @@
* @author Maciej Walkowiak
* @author Wojciech Mąka
* @author Matej Nedic
* @author Manuel Wessner
* @since 1.0
*/
public class QueueMessageHandler extends AbstractMethodMessageHandler<QueueMessageHandler.MappingInformation> {
Expand All @@ -77,21 +79,19 @@ public class QueueMessageHandler extends AbstractMethodMessageHandler<QueueMessa

private final SqsMessageDeletionPolicy sqsMessageDeletionPolicy;

private final SnsMessageManager snsMessageManager;

private final List<MessageConverter> messageConverters;

public QueueMessageHandler(List<MessageConverter> messageConverters,
SqsMessageDeletionPolicy sqsMessageDeletionPolicy) {
SqsMessageDeletionPolicy sqsMessageDeletionPolicy, SnsMessageManager snsMessageManager) {
this.messageConverters = messageConverters;
this.sqsMessageDeletionPolicy = sqsMessageDeletionPolicy;
this.snsMessageManager = snsMessageManager;
}

public QueueMessageHandler(List<MessageConverter> messageConverters) {
this(messageConverters, SqsMessageDeletionPolicy.NO_REDRIVE);
}

public QueueMessageHandler() {
this.messageConverters = Collections.emptyList();
this.sqsMessageDeletionPolicy = SqsMessageDeletionPolicy.NO_REDRIVE;
public QueueMessageHandler(SnsMessageManager snsMessageManager) {
this(Collections.emptyList(), SqsMessageDeletionPolicy.NO_REDRIVE, snsMessageManager);
}

private static String[] wrapInStringArray(Object valueToWrap) {
Expand All @@ -105,12 +105,12 @@ protected List<? extends HandlerMethodArgumentResolver> initArgumentResolvers()
resolvers.add(new HeaderMethodArgumentResolver(null, null));
resolvers.add(new SqsHeadersMethodArgumentResolver());

resolvers.add(new NotificationSubjectArgumentResolver());
resolvers.add(new NotificationSubjectArgumentResolver(snsMessageManager));
resolvers.add(new AcknowledgmentHandlerMethodArgumentResolver(ACKNOWLEDGMENT));
resolvers.add(new VisibilityHandlerMethodArgumentResolver(VISIBILITY));

CompositeMessageConverter compositeMessageConverter = createPayloadArgumentCompositeConverter();
resolvers.add(new NotificationMessageArgumentResolver(compositeMessageConverter));
resolvers.add(new NotificationMessageArgumentResolver(compositeMessageConverter, snsMessageManager));
resolvers.add(new MessageMethodArgumentResolver(this.messageConverters.isEmpty() ? new StringMessageConverter()
: new CompositeMessageConverter(this.messageConverters)));
resolvers.add(new SqsMessageMethodArgumentResolver());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.awspring.cloud.messaging.support;

import com.amazonaws.services.sns.message.SnsMessageManager;
import io.awspring.cloud.messaging.config.annotation.NotificationMessage;
import io.awspring.cloud.messaging.support.converter.NotificationRequestConverter;

Expand All @@ -31,8 +32,8 @@ public class NotificationMessageArgumentResolver implements HandlerMethodArgumen

private final MessageConverter converter;

public NotificationMessageArgumentResolver(MessageConverter converter) {
this.converter = new NotificationRequestConverter(converter);
public NotificationMessageArgumentResolver(MessageConverter converter, SnsMessageManager snsMessageManager) {
this.converter = new NotificationRequestConverter(converter, snsMessageManager);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.awspring.cloud.messaging.support;

import com.amazonaws.services.sns.message.SnsMessageManager;
import io.awspring.cloud.messaging.config.annotation.NotificationSubject;
import io.awspring.cloud.messaging.support.converter.NotificationRequestConverter;

Expand All @@ -33,8 +34,8 @@ public class NotificationSubjectArgumentResolver implements HandlerMethodArgumen

private final MessageConverter converter;

public NotificationSubjectArgumentResolver() {
this.converter = new NotificationRequestConverter(new StringMessageConverter());
public NotificationSubjectArgumentResolver(SnsMessageManager snsMessageManager) {
this.converter = new NotificationRequestConverter(new StringMessageConverter(), snsMessageManager);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

package io.awspring.cloud.messaging.support.converter;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.UUID;

import com.amazonaws.services.sns.message.SnsMessageManager;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.awspring.cloud.messaging.core.MessageAttributeDataTypes;
Expand All @@ -38,6 +42,7 @@
/**
* @author Agim Emruli
* @author Alain Sahli
* @author Manuel Wessner
* @since 1.0
*/
public class NotificationRequestConverter implements MessageConverter {
Expand All @@ -46,8 +51,11 @@ public class NotificationRequestConverter implements MessageConverter {

private final MessageConverter payloadConverter;

public NotificationRequestConverter(MessageConverter payloadConverter) {
private final SnsMessageManager snsMessageManager;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will cause an issue if .aws/config has different region than properties for sns specific.
this.regionProvider = properties.getRegion() == null ? regionProvider.getIfAvailable() : new StaticRegionProvider(properties.getRegion());
Is used while configurating SNSclient, (meaning region in properties takes precedence over config file). Which means if we have in a file eu-central-1 and in properties if we define cloud.aws.sns.region: us-east-1 verification will always fail. (Client will use us-east-1 and SnsMessageManager will use eu-central-1)
My recommendation would be create SnsMessageManager and pass it through constructor (will require refactoring in some other classes as well since it will have to be passed through few constructors).
Or second option pass region and use properties region to configure SnsMessageManager.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maciejwalkowiak First of nice find.

I know its been a while, but I dont understand the recommended approach fully 😅. SnsMessageManager has the default constructor (which uses the wrong region) and the one where I pass the region. So I should use that one.

How and where can I inject the correct region? I suppose I shouldnt duplicate the code in the SnsAutoConfiguration, right? Or should I simply inject SNSProperties and use this region (but which region provider do I use if the property is null, the default one)?

Copy link
Member

@MatejNedic MatejNedic Jan 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, will try to explain the approach in more detail.
Personally, I would create SnsMessageManager bean in SnsAutoConfiguration and pass it to NotificationRequestConverter.
For a region use regionProvider located in SnsAutoConfiguration.

@maciejwalkowiak @eddumelendez WDYT? Could you assist further?

Copy link
Contributor Author

@WtfJoke WtfJoke Jan 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your explaination helped, thank you. I pushed the changes.

I always struggle a bit with formatting issues 😅 I've only managed to do this on my linux machine (and only when using the fully qualified plugin name eg /mvnw io.spring.javaformat:spring-javaformat-maven-plugin:0.0.31:apply
Havent yet tested the changes, but I think you still can take a look at it.

Just noticed the test failures, will take a look. Sorry 😅

Copy link
Contributor Author

@WtfJoke WtfJoke Jan 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strangely enough, all tests are green on my machine :S
EDIT: Found the cause, its the default region provided in my .aws/config


public NotificationRequestConverter(MessageConverter payloadConverter, SnsMessageManager snsMessageManager) {
this.payloadConverter = payloadConverter;
this.snsMessageManager = snsMessageManager;
}

private static Map<String, Object> getMessageAttributesAsMessageHeaders(JsonNode message) {
Expand Down Expand Up @@ -84,26 +92,28 @@ public Object fromMessage(Message<?> message, Class<?> targetClass) {
Assert.notNull(message, "message must not be null");
Assert.notNull(targetClass, "target class must not be null");

String payload = message.getPayload().toString();
JsonNode jsonNode;
try {
jsonNode = this.jsonMapper.readTree(message.getPayload().toString());
jsonNode = this.jsonMapper.readTree(payload);
}
catch (Exception e) {
throw new MessageConversionException("Could not read JSON", e);
}
if (!jsonNode.has("Type")) {
throw new MessageConversionException(
"Payload: '" + message.getPayload() + "' does not contain a Type attribute", null);
throw new MessageConversionException("Payload: '" + payload + "' does not contain a Type attribute", null);
}

if (!"Notification".equals(jsonNode.get("Type").asText())) {
throw new MessageConversionException("Payload: '" + message.getPayload() + "' is not a valid notification",
null);
throw new MessageConversionException("Payload: '" + payload + "' is not a valid notification", null);
}

if (!jsonNode.has("Message")) {
throw new MessageConversionException("Payload: '" + message.getPayload() + "' does not contain a message",
null);
throw new MessageConversionException("Payload: '" + payload + "' does not contain a message", null);
}

if (jsonNode.has("SignatureVersion")) {
verifySignature(payload);
}

String messagePayload = jsonNode.get("Message").asText();
Expand All @@ -119,6 +129,16 @@ public Message<?> toMessage(Object payload, MessageHeaders headers) {
"This converter only supports reading a SNS notification and not writing them");
}

private void verifySignature(String payload) {
try (InputStream messageStream = new ByteArrayInputStream(payload.getBytes())) {
// Unmarshalling the message is not needed, but also done here
snsMessageManager.parseMessage(messageStream);
}
catch (IOException e) {
throw new MessageConversionException("Issue while verifying signature of Payload: '" + payload + "'", e);
}
}

/**
* Notification request wrapper.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.sns.message.SnsMessageManager;
import com.amazonaws.services.sqs.AmazonSQSAsync;
import com.amazonaws.services.sqs.AmazonSQSAsyncClient;
import com.amazonaws.services.sqs.buffered.AmazonSQSBufferedAsyncClient;
Expand Down Expand Up @@ -372,7 +373,7 @@ static class ConfigurationWithCustomContainerFactory {
static final long BACK_OFF_TIME = 5000;

static {
QueueMessageHandler queueMessageHandler = new QueueMessageHandler();
QueueMessageHandler queueMessageHandler = new QueueMessageHandler(new SnsMessageManager("eu-central-1"));
queueMessageHandler.setApplicationContext(new StaticApplicationContext());
MESSAGE_HANDLER = queueMessageHandler;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import com.amazonaws.services.sns.message.SnsMessageManager;
import com.amazonaws.services.sqs.AmazonSQSAsync;
import com.amazonaws.services.sqs.buffered.AmazonSQSBufferedAsyncClient;
import com.amazonaws.services.sqs.model.GetQueueAttributesRequest;
Expand Down Expand Up @@ -71,7 +72,7 @@ void testAfterPropertiesSetIsLoggingWarnMessageIfFifoUsedWithAmazonSQSBufferedAs
Logger loggerMock = container.getLogger();
AmazonSQSAsync sqsMock = mock(AmazonSQSBufferedAsyncClient.class, withSettings().stubOnly());

QueueMessageHandler messageHandler = new QueueMessageHandler();
QueueMessageHandler messageHandler = new QueueMessageHandler(new SnsMessageManager("eu-central-1"));
container.setAmazonSqs(sqsMock);
container.setMessageHandler(mock(QueueMessageHandler.class));
container.setMessageHandler(messageHandler);
Expand Down Expand Up @@ -228,7 +229,7 @@ void receiveMessageRequests_withOneElement_created() throws Exception {
AbstractMessageListenerContainer container = new StubAbstractMessageListenerContainer();

AmazonSQSAsync mock = mock(AmazonSQSAsync.class, withSettings().stubOnly());
QueueMessageHandler messageHandler = new QueueMessageHandler();
QueueMessageHandler messageHandler = new QueueMessageHandler(new SnsMessageManager("eu-central-1"));
container.setAmazonSqs(mock);
container.setMessageHandler(mock(QueueMessageHandler.class));
container.setMessageHandler(messageHandler);
Expand Down Expand Up @@ -266,7 +267,7 @@ void receiveMessageRequests_withMultipleElements_created() throws Exception {
AmazonSQSAsync mock = mock(AmazonSQSAsync.class, withSettings().stubOnly());
container.setAmazonSqs(mock);
StaticApplicationContext applicationContext = new StaticApplicationContext();
QueueMessageHandler messageHandler = new QueueMessageHandler();
QueueMessageHandler messageHandler = new QueueMessageHandler(new SnsMessageManager("eu-central-1"));
messageHandler.setApplicationContext(applicationContext);
container.setMessageHandler(messageHandler);
applicationContext.registerSingleton("messageListener", MessageListener.class);
Expand Down Expand Up @@ -436,7 +437,7 @@ void receiveMessageRequests_withDestinationResolverThrowingException_shouldLogWa
AmazonSQSAsync mock = mock(AmazonSQSAsync.class, withSettings().stubOnly());
container.setAmazonSqs(mock);
StaticApplicationContext applicationContext = new StaticApplicationContext();
QueueMessageHandler messageHandler = new QueueMessageHandler();
QueueMessageHandler messageHandler = new QueueMessageHandler(new SnsMessageManager("eu-central-1"));
messageHandler.setApplicationContext(applicationContext);
container.setMessageHandler(messageHandler);
applicationContext.registerSingleton("messageListener", MessageListener.class);
Expand Down
Loading