diff --git a/app/gosns/subscribe.go b/app/gosns/subscribe.go index ad8cea10..e35ab5ba 100644 --- a/app/gosns/subscribe.go +++ b/app/gosns/subscribe.go @@ -39,20 +39,19 @@ func SubscribeV1(req *http.Request) (int, interfaces.AbstractResponseBody) { subscription := &app.Subscription{EndPoint: requestBody.Endpoint, Protocol: requestBody.Protocol, TopicArn: requestBody.TopicArn, Raw: requestBody.Attributes.RawMessageDelivery, FilterPolicy: &requestBody.Attributes.FilterPolicy} - subArn := uuid.NewString() subscription.SubscriptionArn = fmt.Sprintf("%s:%s", requestBody.TopicArn, uuid.NewString()) //Create the response requestId := uuid.NewString() - respStruct := models.SubscribeResponse{Xmlns: models.BASE_XMLNS, Result: models.SubscribeResult{SubscriptionArn: subArn}, Metadata: app.ResponseMetadata{RequestId: requestId}} + respStruct := models.SubscribeResponse{Xmlns: models.BASE_XMLNS, Result: models.SubscribeResult{SubscriptionArn: subscription.SubscriptionArn}, Metadata: app.ResponseMetadata{RequestId: requestId}} if app.SyncTopics.Topics[topicName] != nil { app.SyncTopics.Lock() isDuplicate := false // Duplicate check - for _, subscription := range app.SyncTopics.Topics[topicName].Subscriptions { - if subscription.EndPoint == requestBody.Endpoint && subscription.TopicArn == requestBody.TopicArn { + for _, sub := range app.SyncTopics.Topics[topicName].Subscriptions { + if sub.EndPoint == requestBody.Endpoint && sub.TopicArn == requestBody.TopicArn { isDuplicate = true - subArn = subscription.SubscriptionArn + sub.SubscriptionArn = subscription.SubscriptionArn } } if !isDuplicate { @@ -66,7 +65,7 @@ func SubscribeV1(req *http.Request) (int, interfaces.AbstractResponseBody) { token := uuid.NewString() TOPIC_DATA[requestBody.TopicArn] = &pendingConfirm{ - subArn: subArn, + subArn: subscription.SubscriptionArn, token: token, } diff --git a/smoke_tests/sns_subscribe_test.go b/smoke_tests/sns_subscribe_test.go index 16fd8dec..9a857d6e 100644 --- a/smoke_tests/sns_subscribe_test.go +++ b/smoke_tests/sns_subscribe_test.go @@ -2,6 +2,7 @@ package smoke_tests import ( "context" + "encoding/xml" "fmt" "net/http" "testing" @@ -9,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/Admiral-Piett/goaws/app/conf" + "github.com/Admiral-Piett/goaws/app/models" "github.com/Admiral-Piett/goaws/app/test" "github.com/gavv/httpexpect/v2" @@ -58,6 +60,55 @@ func Test_Subscribe_json(t *testing.T) { assert.Equal(t, "sqs", subscriptions[0].Protocol) assert.False(t, subscriptions[0].Raw) assert.Contains(t, subscriptions[0].SubscriptionArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")) + assert.Equal(t, response.SubscriptionArn, &subscriptions[0].SubscriptionArn) +} + +func Test_Subscribe_json_with_duplicate_subscription(t *testing.T) { + server := generateServer() + defaultEnv := app.CurrentEnvironment + conf.LoadYamlConfig("../app/conf/mock-data/mock-config.yaml", "BaseUnitTests") + defer func() { + server.Close() + test.ResetResources() + app.CurrentEnvironment = defaultEnv + }() + + sdkConfig, _ := config.LoadDefaultConfig(context.TODO()) + sdkConfig.BaseEndpoint = aws.String(server.URL) + snsClient := sns.NewFromConfig(sdkConfig) + + snsClient.Subscribe(context.TODO(), &sns.SubscribeInput{ + Protocol: aws.String("sqs"), + TopicArn: aws.String(fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")), + Attributes: map[string]string{}, + Endpoint: aws.String(fmt.Sprintf("%s:%s", af.BASE_SQS_ARN, "unit-queue2")), + ReturnSubscriptionArn: true, + }) + + response, err := snsClient.Subscribe(context.TODO(), &sns.SubscribeInput{ + Protocol: aws.String("sqs"), + TopicArn: aws.String(fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")), + Attributes: map[string]string{}, + Endpoint: aws.String(fmt.Sprintf("%s:%s", af.BASE_SQS_ARN, "unit-queue2")), + ReturnSubscriptionArn: true, + }) + + assert.Nil(t, err) + assert.NotNil(t, response) + + app.SyncTopics.Lock() + defer app.SyncTopics.Unlock() + + subscriptions := app.SyncTopics.Topics["unit-topic2"].Subscriptions + assert.Len(t, subscriptions, 1) + + expectedFilterPolicy := app.FilterPolicy(nil) + assert.Equal(t, fmt.Sprintf("%s:%s", af.BASE_SQS_ARN, "unit-queue2"), subscriptions[0].EndPoint) + assert.Equal(t, &expectedFilterPolicy, subscriptions[0].FilterPolicy) + assert.Equal(t, "sqs", subscriptions[0].Protocol) + assert.False(t, subscriptions[0].Raw) + assert.Contains(t, subscriptions[0].SubscriptionArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")) + assert.Equal(t, response.SubscriptionArn, &subscriptions[0].SubscriptionArn) assert.Equal(t, subscriptions[0].TopicArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")) } @@ -101,6 +152,7 @@ func Test_Subscribe_json_with_additional_fields(t *testing.T) { assert.Equal(t, "sqs", subscriptions[0].Protocol) assert.True(t, subscriptions[0].Raw) assert.Contains(t, subscriptions[0].SubscriptionArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")) + assert.Equal(t, response.SubscriptionArn, &subscriptions[0].SubscriptionArn) assert.Equal(t, subscriptions[0].TopicArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")) } @@ -128,7 +180,7 @@ func Test_Subscribe_xml(t *testing.T) { Protocol: "sqs", } - e.POST("/"). + r := e.POST("/"). WithForm(requestBody). WithFormField("Attributes.entry.1.key", "RawMessageDelivery"). WithFormField("Attributes.entry.1.value", "true"). @@ -138,6 +190,8 @@ func Test_Subscribe_xml(t *testing.T) { Status(http.StatusOK). Body().Raw() + response := models.SubscribeResponse{} + xml.Unmarshal([]byte(r), &response) subscriptions := app.SyncTopics.Topics["unit-topic2"].Subscriptions assert.Len(t, subscriptions, 1) @@ -147,5 +201,6 @@ func Test_Subscribe_xml(t *testing.T) { assert.Equal(t, "sqs", subscriptions[0].Protocol) assert.True(t, subscriptions[0].Raw) assert.Contains(t, subscriptions[0].SubscriptionArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")) + assert.Equal(t, response.Result.SubscriptionArn, subscriptions[0].SubscriptionArn) assert.Equal(t, subscriptions[0].TopicArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")) }