Skip to content

Commit

Permalink
fixed subsribe logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Dai.Otsuka authored and Admiral-Piett committed Sep 20, 2024
1 parent 4d51d27 commit 5555f5e
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 7 deletions.
11 changes: 5 additions & 6 deletions app/gosns/subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,19 @@ func SubscribeV1(req *http.Request) (int, interfaces.AbstractResponseBody) {

subscription := &app.Subscription{EndPoint: requestBody.Endpoint, Protocol: requestBody.Protocol, TopicArn: requestBody.TopicArn, Raw: requestBody.Attributes.RawMessageDelivery, FilterPolicy: &requestBody.Attributes.FilterPolicy}

subArn := uuid.NewString()
subscription.SubscriptionArn = fmt.Sprintf("%s:%s", requestBody.TopicArn, uuid.NewString())

//Create the response
requestId := uuid.NewString()
respStruct := models.SubscribeResponse{Xmlns: models.BASE_XMLNS, Result: models.SubscribeResult{SubscriptionArn: subArn}, Metadata: app.ResponseMetadata{RequestId: requestId}}
respStruct := models.SubscribeResponse{Xmlns: models.BASE_XMLNS, Result: models.SubscribeResult{SubscriptionArn: subscription.SubscriptionArn}, Metadata: app.ResponseMetadata{RequestId: requestId}}
if app.SyncTopics.Topics[topicName] != nil {
app.SyncTopics.Lock()
isDuplicate := false
// Duplicate check
for _, subscription := range app.SyncTopics.Topics[topicName].Subscriptions {
if subscription.EndPoint == requestBody.Endpoint && subscription.TopicArn == requestBody.TopicArn {
for _, sub := range app.SyncTopics.Topics[topicName].Subscriptions {
if sub.EndPoint == requestBody.Endpoint && sub.TopicArn == requestBody.TopicArn {
isDuplicate = true
subArn = subscription.SubscriptionArn
sub.SubscriptionArn = subscription.SubscriptionArn
}
}
if !isDuplicate {
Expand All @@ -66,7 +65,7 @@ func SubscribeV1(req *http.Request) (int, interfaces.AbstractResponseBody) {
token := uuid.NewString()

TOPIC_DATA[requestBody.TopicArn] = &pendingConfirm{
subArn: subArn,
subArn: subscription.SubscriptionArn,
token: token,
}

Expand Down
57 changes: 56 additions & 1 deletion smoke_tests/sns_subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package smoke_tests

import (
"context"
"encoding/xml"
"fmt"
"net/http"
"testing"

"github.com/aws/aws-sdk-go-v2/config"

"github.com/Admiral-Piett/goaws/app/conf"
"github.com/Admiral-Piett/goaws/app/models"
"github.com/Admiral-Piett/goaws/app/test"

"github.com/gavv/httpexpect/v2"
Expand Down Expand Up @@ -58,6 +60,55 @@ func Test_Subscribe_json(t *testing.T) {
assert.Equal(t, "sqs", subscriptions[0].Protocol)
assert.False(t, subscriptions[0].Raw)
assert.Contains(t, subscriptions[0].SubscriptionArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2"))
assert.Equal(t, response.SubscriptionArn, &subscriptions[0].SubscriptionArn)
}

func Test_Subscribe_json_with_duplicate_subscription(t *testing.T) {
server := generateServer()
defaultEnv := app.CurrentEnvironment
conf.LoadYamlConfig("../app/conf/mock-data/mock-config.yaml", "BaseUnitTests")
defer func() {
server.Close()
test.ResetResources()
app.CurrentEnvironment = defaultEnv
}()

sdkConfig, _ := config.LoadDefaultConfig(context.TODO())
sdkConfig.BaseEndpoint = aws.String(server.URL)
snsClient := sns.NewFromConfig(sdkConfig)

snsClient.Subscribe(context.TODO(), &sns.SubscribeInput{
Protocol: aws.String("sqs"),
TopicArn: aws.String(fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")),
Attributes: map[string]string{},
Endpoint: aws.String(fmt.Sprintf("%s:%s", af.BASE_SQS_ARN, "unit-queue2")),
ReturnSubscriptionArn: true,
})

response, err := snsClient.Subscribe(context.TODO(), &sns.SubscribeInput{
Protocol: aws.String("sqs"),
TopicArn: aws.String(fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2")),
Attributes: map[string]string{},
Endpoint: aws.String(fmt.Sprintf("%s:%s", af.BASE_SQS_ARN, "unit-queue2")),
ReturnSubscriptionArn: true,
})

assert.Nil(t, err)
assert.NotNil(t, response)

app.SyncTopics.Lock()
defer app.SyncTopics.Unlock()

subscriptions := app.SyncTopics.Topics["unit-topic2"].Subscriptions
assert.Len(t, subscriptions, 1)

expectedFilterPolicy := app.FilterPolicy(nil)
assert.Equal(t, fmt.Sprintf("%s:%s", af.BASE_SQS_ARN, "unit-queue2"), subscriptions[0].EndPoint)
assert.Equal(t, &expectedFilterPolicy, subscriptions[0].FilterPolicy)
assert.Equal(t, "sqs", subscriptions[0].Protocol)
assert.False(t, subscriptions[0].Raw)
assert.Contains(t, subscriptions[0].SubscriptionArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2"))
assert.Equal(t, response.SubscriptionArn, &subscriptions[0].SubscriptionArn)
assert.Equal(t, subscriptions[0].TopicArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2"))
}

Expand Down Expand Up @@ -101,6 +152,7 @@ func Test_Subscribe_json_with_additional_fields(t *testing.T) {
assert.Equal(t, "sqs", subscriptions[0].Protocol)
assert.True(t, subscriptions[0].Raw)
assert.Contains(t, subscriptions[0].SubscriptionArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2"))
assert.Equal(t, response.SubscriptionArn, &subscriptions[0].SubscriptionArn)
assert.Equal(t, subscriptions[0].TopicArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2"))
}

Expand Down Expand Up @@ -128,7 +180,7 @@ func Test_Subscribe_xml(t *testing.T) {
Protocol: "sqs",
}

e.POST("/").
r := e.POST("/").
WithForm(requestBody).
WithFormField("Attributes.entry.1.key", "RawMessageDelivery").
WithFormField("Attributes.entry.1.value", "true").
Expand All @@ -138,6 +190,8 @@ func Test_Subscribe_xml(t *testing.T) {
Status(http.StatusOK).
Body().Raw()

response := models.SubscribeResponse{}
xml.Unmarshal([]byte(r), &response)
subscriptions := app.SyncTopics.Topics["unit-topic2"].Subscriptions
assert.Len(t, subscriptions, 1)

Expand All @@ -147,5 +201,6 @@ func Test_Subscribe_xml(t *testing.T) {
assert.Equal(t, "sqs", subscriptions[0].Protocol)
assert.True(t, subscriptions[0].Raw)
assert.Contains(t, subscriptions[0].SubscriptionArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2"))
assert.Equal(t, response.Result.SubscriptionArn, subscriptions[0].SubscriptionArn)
assert.Equal(t, subscriptions[0].TopicArn, fmt.Sprintf("%s:%s", af.BASE_SNS_ARN, "unit-topic2"))
}

0 comments on commit 5555f5e

Please sign in to comment.