From d8c349aa6d42482791110482079368593f6d7a25 Mon Sep 17 00:00:00 2001 From: Amelia Downs Date: Thu, 19 Dec 2024 16:46:13 -0500 Subject: [PATCH] update the tcp_routes table to have host_tls_port default to 0 instead of NULL (#67) Co-authored-by: Geoff Franks --- .../V8_host_tls_port_tcp_default_zero.go | 33 ++ .../V8_host_tls_port_tcp_default_zero_test.go | 317 ++++++++++++ migration/migration.go | 3 + migration/migration_test.go | 2 +- migration/v7/model.go | 9 + migration/v7/models_suite_test.go | 13 + migration/v7/models_test.go | 470 ++++++++++++++++++ migration/v7/route.go | 91 ++++ migration/v7/router_groups.go | 284 +++++++++++ migration/v7/tcp_route.go | 121 +++++ migration/v7/tcp_route_test.go | 110 ++++ 11 files changed, 1452 insertions(+), 1 deletion(-) create mode 100644 migration/V8_host_tls_port_tcp_default_zero.go create mode 100644 migration/V8_host_tls_port_tcp_default_zero_test.go create mode 100644 migration/v7/model.go create mode 100644 migration/v7/models_suite_test.go create mode 100644 migration/v7/models_test.go create mode 100644 migration/v7/route.go create mode 100644 migration/v7/router_groups.go create mode 100644 migration/v7/tcp_route.go create mode 100644 migration/v7/tcp_route_test.go diff --git a/migration/V8_host_tls_port_tcp_default_zero.go b/migration/V8_host_tls_port_tcp_default_zero.go new file mode 100644 index 00000000..f23d7153 --- /dev/null +++ b/migration/V8_host_tls_port_tcp_default_zero.go @@ -0,0 +1,33 @@ +package migration + +import ( + "code.cloudfoundry.org/routing-api/db" + "code.cloudfoundry.org/routing-api/models" +) + +type V8HostTLSPortTCPDefaultZero struct{} + +func NewV8HostTLSPortTCPDefaultZero() *V8HostTLSPortTCPDefaultZero { + return &V8HostTLSPortTCPDefaultZero{} +} + +func (v *V8HostTLSPortTCPDefaultZero) Version() int { + return 8 +} + +func (v *V8HostTLSPortTCPDefaultZero) Run(sqlDB *db.SqlDB) error { + _, err := sqlDB.Client.Model(&models.TcpRouteMapping{}).RemoveIndex("idx_tcp_route") + if err != nil { + return err + } + + if sqlDB.Client.Dialect().GetName() == "postgres" { + sqlDB.Client.Exec("ALTER TABLE tcp_routes ALTER COLUMN host_tls_port SET DEFAULT 0") + } else { + sqlDB.Client.Exec("ALTER TABLE tcp_routes MODIFY COLUMN host_tls_port int DEFAULT 0") + } + + sqlDB.Client.Exec("UPDATE tcp_routes SET host_tls_port = 0 WHERE host_tls_port IS NULL") + + return nil +} diff --git a/migration/V8_host_tls_port_tcp_default_zero_test.go b/migration/V8_host_tls_port_tcp_default_zero_test.go new file mode 100644 index 00000000..9b9ed7a6 --- /dev/null +++ b/migration/V8_host_tls_port_tcp_default_zero_test.go @@ -0,0 +1,317 @@ +package migration_test + +import ( + "time" + + "code.cloudfoundry.org/routing-api/cmd/routing-api/testrunner" + "code.cloudfoundry.org/routing-api/db" + "code.cloudfoundry.org/routing-api/migration" + v7 "code.cloudfoundry.org/routing-api/migration/v7" + "code.cloudfoundry.org/routing-api/models" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("V8HostTLSPortTCPDefaultZero", func() { + var ( + sqlDB *db.SqlDB + dbAllocator testrunner.DbAllocator + ) + + BeforeEach(func() { + dbAllocator = testrunner.NewDbAllocator() + sqlCfg, err := dbAllocator.Create() + Expect(err).NotTo(HaveOccurred()) + + sqlDB, err = db.NewSqlDB(sqlCfg) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + err := dbAllocator.Delete() + Expect(err).ToNot(HaveOccurred()) + }) + + Describe("Version", func() { + It("returns 8 for the version", func() { + v8Migration := migration.NewV8HostTLSPortTCPDefaultZero() + Expect(v8Migration.Version()).To(Equal(8)) + }) + }) + + Describe("Run", func() { + Context("when a db already exists with values and has not been manually updated", func() { + BeforeEach(func() { + err := sqlDB.Client.AutoMigrate(&v7.RouterGroupDB{}, &v7.TcpRouteMapping{}, &v7.Route{}) + Expect(err).ToNot(HaveOccurred()) + + sniHostname1 := "sniHostname1" + tcpRoute1 := v7.TcpRouteMapping{ // This one has no HostTLSPort, before the migration this will default to NULL + Model: v7.Model{Guid: "guid-0"}, + ExpiresAt: time.Now().Add(1 * time.Hour), + TcpMappingEntity: v7.TcpMappingEntity{ + RouterGroupGuid: "test0-preexisting-omitted-host-tls-port", + HostPort: 80, + HostIP: "1.1.1.1", + ExternalPort: 80, + SniHostname: &sniHostname1, + }, + } + + tcpRoute2 := v7.TcpRouteMapping{ // This one has HostTLSPort set explicitly to 8443 + Model: v7.Model{Guid: "guid-2"}, + ExpiresAt: time.Now().Add(1 * time.Hour), + TcpMappingEntity: v7.TcpMappingEntity{ + RouterGroupGuid: "test0-preexisting-host-tls-port-8443", + HostPort: 80, + HostTLSPort: 8443, + HostIP: "2.2.2.2", + ExternalPort: 80, + SniHostname: &sniHostname1, + }, + } + + _, err = sqlDB.Client.Create(&tcpRoute1) + Expect(err).NotTo(HaveOccurred()) + _, err = sqlDB.Client.Create(&tcpRoute2) + Expect(err).NotTo(HaveOccurred()) + + By("validating that there are 2 tcp routes") + tcpRoutes, err := sqlDB.ReadTcpRouteMappings() + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(2)) + + By("validating that 1 has host_tls_port set to NULL") + tcpRoutesWithNULL, err := readFilteredTcpRouteMappingsWhereHostTcpPortIsNull(sqlDB) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutesWithNULL)).To(Equal(1)) + Expect(tcpRoutesWithNULL[0].HostIP).To(Equal("1.1.1.1")) + + By("validating that 1 has host_tls_port set to a non-NULL value") + tcpRoutesWithoutNULL, err := readFilteredTcpRouteMappingsWhereHostTcpPortIsNotNull(sqlDB) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutesWithoutNULL)).To(Equal(1)) + Expect(tcpRoutesWithoutNULL[0].HostIP).To(Equal("2.2.2.2")) + }) + + It("updates existing records with a NULL value to have a value of 0", func() { + By("running the migration") + v8Migration := migration.NewV8HostTLSPortTCPDefaultZero() + err := v8Migration.Run(sqlDB) + Expect(err).ToNot(HaveOccurred()) + + By("validating that there are still 2 tcp routes") + tcpRoutes, err := sqlDB.ReadTcpRouteMappings() + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(2)) + + By("validating that there are now zero tcp routes with host_tls_port set to NULL") + tcpRoutesWithNULL, err := readFilteredTcpRouteMappingsWhereHostTcpPortIsNull(sqlDB) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutesWithNULL)).To(Equal(0)) + + By("validating that there are now two tcp routes with host_tls_port set to a non-NULL value") + tcpRoutesWithoutNULL, err := readFilteredTcpRouteMappingsWhereHostTcpPortIsNotNull(sqlDB) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutesWithoutNULL)).To(Equal(2)) + + By("validating that the host_tls_port for tcpRoute2 did not change") + tcpRoutes, err = sqlDB.ReadFilteredTcpRouteMappings("host_tls_port", []string{"8443"}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(1)) + Expect(tcpRoutes[0].HostIP).To(Equal("2.2.2.2")) + + By("validating that the host_tls_port for tcpRoute1 is 0 in the db") + tcpRoutes, err = sqlDB.ReadFilteredTcpRouteMappings("host_tls_port", []string{"0"}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(1)) + Expect(tcpRoutes[0].HostIP).To(Equal("1.1.1.1")) + + By("creating a new route post migration without host_tls_port set") + tcpRoute3 := v7.TcpRouteMapping{ // This one has no HostTLSPort, before the migration this will default to NULL + Model: v7.Model{Guid: "guid-meow"}, + ExpiresAt: time.Now().Add(1 * time.Hour), + TcpMappingEntity: v7.TcpMappingEntity{ + RouterGroupGuid: "meow-testing-post-migration-when-there-is-no-host-tls-port", + HostPort: 80, + HostIP: "3.3.3.3", + ExternalPort: 80, + }, + } + _, err = sqlDB.Client.Create(&tcpRoute3) + Expect(err).NotTo(HaveOccurred()) + + By("validating that all new tcproutes will default to 0") + tcpRoutes, err = sqlDB.ReadFilteredTcpRouteMappings("host_tls_port", []string{"0"}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(2)) + Expect([]string{tcpRoutes[0].HostIP, tcpRoutes[1].HostIP}).To(ContainElements("1.1.1.1", "3.3.3.3")) + }) + + Context("when run against a database that was fixed by hand", func() { + It("doesnt fail during the migration", func() { + + By("manually updating the default") + if sqlDB.Client.Dialect().GetName() == "postgres" { + sqlDB.Client.Exec("ALTER TABLE tcp_routes ALTER COLUMN host_tls_port SET DEFAULT 0") + } else { + sqlDB.Client.Exec("ALTER TABLE tcp_routes MODIFY COLUMN host_tls_port int DEFAULT 0") + } + sqlDB.Client.Exec("UPDATE tcp_routes SET host_tls_port = 0 WHERE host_tls_port IS NULL") + + By("validating that there are still 2 tcp routes") + tcpRoutes, err := sqlDB.ReadTcpRouteMappings() + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(2)) + + By("validating that there are now zero tcp routes with host_tls_port set to NULL") + tcpRoutesWithNULL, err := readFilteredTcpRouteMappingsWhereHostTcpPortIsNull(sqlDB) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutesWithNULL)).To(Equal(0)) + + By("validating that there are now two tcp routes with host_tls_port set to a non-NULL value") + tcpRoutesWithoutNULL, err := readFilteredTcpRouteMappingsWhereHostTcpPortIsNotNull(sqlDB) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutesWithoutNULL)).To(Equal(2)) + + By("creating a new route post manual fix without host_tls_port set") + tcpRoute3 := v7.TcpRouteMapping{ // This one has no HostTLSPort, before the migration this will default to NULL + Model: v7.Model{Guid: "guid-meow"}, + ExpiresAt: time.Now().Add(1 * time.Hour), + TcpMappingEntity: v7.TcpMappingEntity{ + RouterGroupGuid: "meow-testing-post-migration-when-there-is-no-host-tls-port", + HostPort: 80, + HostIP: "3.3.3.3", + ExternalPort: 80, + }, + } + _, err = sqlDB.Client.Create(&tcpRoute3) + Expect(err).NotTo(HaveOccurred()) + + By("validating that new tcproutes will default to 0") + tcpRoutes, err = sqlDB.ReadFilteredTcpRouteMappings("host_tls_port", []string{"0"}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(2)) + Expect([]string{tcpRoutes[0].HostIP, tcpRoutes[1].HostIP}).To(ContainElements("1.1.1.1", "3.3.3.3")) + + By("running the migration") + v8Migration := migration.NewV8HostTLSPortTCPDefaultZero() + err = v8Migration.Run(sqlDB) + Expect(err).ToNot(HaveOccurred()) + + By("validating that there are still 3 tcp routes") + tcpRoutes, err = sqlDB.ReadTcpRouteMappings() + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(3)) + + By("validating that there are now zero tcp routes with host_tls_port set to NULL") + tcpRoutesWithNULL, err = readFilteredTcpRouteMappingsWhereHostTcpPortIsNull(sqlDB) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutesWithNULL)).To(Equal(0)) + + By("validating that there are now two tcp routes with host_tls_port set to a non-NULL value") + tcpRoutesWithoutNULL, err = readFilteredTcpRouteMappingsWhereHostTcpPortIsNotNull(sqlDB) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutesWithoutNULL)).To(Equal(3)) + + By("validating that the host_tls_port for tcpRoute2 did not change") + tcpRoutes, err = sqlDB.ReadFilteredTcpRouteMappings("host_tls_port", []string{"8443"}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(1)) + Expect(tcpRoutes[0].HostIP).To(Equal("2.2.2.2")) + + By("creating a new route post migration without host_tls_port set") + tcpRoute4 := v7.TcpRouteMapping{ // This one has no HostTLSPort, before the migration this will default to NULL + Model: v7.Model{Guid: "guid-meow-4"}, + ExpiresAt: time.Now().Add(1 * time.Hour), + TcpMappingEntity: v7.TcpMappingEntity{ + RouterGroupGuid: "meow-testing-post-migration-when-there-is-no-host-tls-port-4", + HostPort: 44, + HostIP: "4.4.4.4", + ExternalPort: 44, + }, + } + _, err = sqlDB.Client.Create(&tcpRoute4) + Expect(err).NotTo(HaveOccurred()) + + By("validating that all tcproutes will still default to 0") + tcpRoutes, err = sqlDB.ReadFilteredTcpRouteMappings("host_tls_port", []string{"0"}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(3)) + Expect([]string{tcpRoutes[0].HostIP, tcpRoutes[1].HostIP, tcpRoutes[2].HostIP}).To(ContainElements("1.1.1.1", "3.3.3.3", "4.4.4.4")) + // 1.1.1.1 was made before fixing by hand + // 3.3.3.3 was made after the manual fix + // 4.4.4.4 was made after the migration + }) + }) + }) + + Context("when the tables are newly created (by V0 init migration)", func() { + BeforeEach(func() { + v0Migration := migration.NewV0InitMigration() + err := v0Migration.Run(sqlDB) + Expect(err).ToNot(HaveOccurred()) + + By("running the migration") + v8Migration := migration.NewV8HostTLSPortTCPDefaultZero() + err = v8Migration.Run(sqlDB) + Expect(err).ToNot(HaveOccurred()) + }) + + It("always has default 0 for host_tls_port from the beginning", func() { + By("creating a new route post migration without host_tls_port set") + tcpRoute := v7.TcpRouteMapping{ // This one has no HostTLSPort, before the migration this will default to NULL + Model: v7.Model{Guid: "guid-meow"}, + ExpiresAt: time.Now().Add(1 * time.Hour), + TcpMappingEntity: v7.TcpMappingEntity{ + RouterGroupGuid: "meow-testing-post-migration-when-there-is-no-host-tls-port", + HostPort: 80, + HostIP: "1.1.1.1", + ExternalPort: 80, + }, + } + _, err := sqlDB.Client.Create(&tcpRoute) + Expect(err).NotTo(HaveOccurred()) + + By("validating that all new tcproutes will default to 0") + tcpRoutes, err := sqlDB.ReadFilteredTcpRouteMappings("host_tls_port", []string{"0"}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutes)).To(Equal(1)) + Expect(tcpRoutes[0].HostIP).To(Equal("1.1.1.1")) + + By("validating that there are zero tcp routes with host_tls_port set to NULL") + tcpRoutesWithNULL, err := readFilteredTcpRouteMappingsWhereHostTcpPortIsNull(sqlDB) + Expect(err).ToNot(HaveOccurred()) + Expect(len(tcpRoutesWithNULL)).To(Equal(0)) + }) + }) + + Context("when run against a database that was already migrated", func() { + BeforeEach(func() { + err := sqlDB.Client.AutoMigrate(&models.RouterGroupDB{}, &models.TcpRouteMapping{}, &models.Route{}) + Expect(err).ToNot(HaveOccurred()) + }) + }) + }) +}) + +func readFilteredTcpRouteMappingsWhereHostTcpPortIsNull(s *db.SqlDB) ([]models.TcpRouteMapping, error) { + var tcpRoutes []models.TcpRouteMapping + now := time.Now() + err := s.Client.Where("host_tls_port IS NULL").Where("expires_at > ?", now).Find(&tcpRoutes) + if err != nil { + return nil, err + } + return tcpRoutes, nil +} + +func readFilteredTcpRouteMappingsWhereHostTcpPortIsNotNull(s *db.SqlDB) ([]models.TcpRouteMapping, error) { + var tcpRoutes []models.TcpRouteMapping + now := time.Now() + err := s.Client.Where("host_tls_port IS NOT NULL").Where("expires_at > ?", now).Find(&tcpRoutes) + if err != nil { + return nil, err + } + return tcpRoutes, nil +} diff --git a/migration/migration.go b/migration/migration.go index 6ae17f08..9559fbb7 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -88,6 +88,9 @@ func InitializeMigrations() []Migration { migration = NewV7TCPTLSRoutes() migrations = append(migrations, migration) + migration = NewV8HostTLSPortTCPDefaultZero() + migrations = append(migrations, migration) + return migrations } diff --git a/migration/migration_test.go b/migration/migration_test.go index 70d8bd96..b1c0d188 100644 --- a/migration/migration_test.go +++ b/migration/migration_test.go @@ -43,7 +43,7 @@ var _ = Describe("Migration", func() { done := make(chan struct{}) defer close(done) migrations := migration.InitializeMigrations() - Expect(migrations).To(HaveLen(7)) + Expect(migrations).To(HaveLen(8)) Expect(migrations[0]).To(BeAssignableToTypeOf(new(migration.V0InitMigration))) Expect(migrations[1]).To(BeAssignableToTypeOf(new(migration.V2UpdateRgMigration))) diff --git a/migration/v7/model.go b/migration/v7/model.go new file mode 100644 index 00000000..365801c7 --- /dev/null +++ b/migration/v7/model.go @@ -0,0 +1,9 @@ +package models + +import "time" + +type Model struct { + Guid string `gorm:"primary_key" json:"-"` + CreatedAt time.Time `json:"-"` + UpdatedAt time.Time `json:"-"` +} diff --git a/migration/v7/models_suite_test.go b/migration/v7/models_suite_test.go new file mode 100644 index 00000000..71cd775e --- /dev/null +++ b/migration/v7/models_suite_test.go @@ -0,0 +1,13 @@ +package models_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "testing" +) + +func TestModels(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Models Suite") +} diff --git a/migration/v7/models_test.go b/migration/v7/models_test.go new file mode 100644 index 00000000..4d930b2d --- /dev/null +++ b/migration/v7/models_test.go @@ -0,0 +1,470 @@ +package models_test + +import ( + "encoding/json" + + . "code.cloudfoundry.org/routing-api/models" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Models", func() { + Describe("ModificationTag", func() { + var tag ModificationTag + + BeforeEach(func() { + tag = ModificationTag{Guid: "guid1", Index: 5} + }) + + Describe("Increment", func() { + BeforeEach(func() { + tag.Increment() + }) + + It("Increments the index", func() { + Expect(tag.Index).To(Equal(uint32(6))) + }) + }) + + Describe("SucceededBy", func() { + var tag2 ModificationTag + + Context("when the guid is the different", func() { + BeforeEach(func() { + tag2 = ModificationTag{Guid: "guid5", Index: 0} + }) + It("new tag should succeed", func() { + Expect(tag.SucceededBy(&tag2)).To(BeTrue()) + }) + }) + + Context("when the guid is the same", func() { + + Context("when the index is the same as the original tag", func() { + BeforeEach(func() { + tag2 = ModificationTag{Guid: "guid1", Index: 5} + }) + + It("new tag should not succeed", func() { + Expect(tag.SucceededBy(&tag2)).To(BeFalse()) + }) + + }) + + Context("when the index is less than original tag Index", func() { + + BeforeEach(func() { + tag2 = ModificationTag{Guid: "guid1", Index: 4} + }) + + It("new tag should not succeed", func() { + Expect(tag.SucceededBy(&tag2)).To(BeFalse()) + }) + }) + + Context("when the index is greater than original tag Index", func() { + BeforeEach(func() { + tag2 = ModificationTag{Guid: "guid1", Index: 6} + }) + + It("new tag should succeed", func() { + Expect(tag.SucceededBy(&tag2)).To(BeTrue()) + }) + + }) + + }) + + }) + }) + + Describe("RouterGroup", func() { + var rg RouterGroup + + Describe("Validate", func() { + It("does not allow ReservablePorts for http type", func() { + rg = RouterGroup{ + Name: "router-group-1", + Type: "http", + ReservablePorts: "1025-2025", + } + err := rg.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("Reservable ports are not supported for router groups of type http")) + By("not having ReservablePorts") + rg = RouterGroup{ + Name: "router-group-1", + Type: "http", + } + err = rg.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("ReservablePorts are optional for non-http, non-tcp type", func() { + rg = RouterGroup{ + Name: "router-group-1", + Type: "foo", + ReservablePorts: "1025-2025", + } + err := rg.Validate() + Expect(err).ToNot(HaveOccurred()) + + rg = RouterGroup{ + Name: "router-group-1", + Type: "foo", + ReservablePorts: "", + } + err = rg.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("succeeds for valid router group", func() { + rg = RouterGroup{ + Name: "router-group-1", + Type: "tcp", + ReservablePorts: "1025-2025", + } + err := rg.Validate() + Expect(err).NotTo(HaveOccurred()) + }) + + It("fails for missing type", func() { + rg = RouterGroup{ + Name: "router-group-1", + ReservablePorts: "10-20", + } + err := rg.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("Missing type in router group")) + }) + It("fails for missing name", func() { + rg = RouterGroup{ + Type: "tcp", + ReservablePorts: "10-20", + } + err := rg.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("Missing name in router group")) + }) + + It("fails for missing ReservablePorts", func() { + rg = RouterGroup{ + Type: "tcp", + Name: "router-group-1", + } + err := rg.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("Missing reservable_ports in router group: router-group-1")) + }) + + Context("when there are reserved system component ports", func() { + BeforeEach(func() { + ReservedSystemComponentPorts = []uint16{5555, 6666, 7777} + }) + + Context("when failOnRouterPortConflicts is true", func() { + BeforeEach(func() { + FailOnRouterPortConflicts = true + }) + + It("succeeds when the ports don't overlap", func() { + rg = RouterGroup{ + Name: "router-group-1", + Type: "tcp", + ReservablePorts: "1025-2025", + } + err := rg.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("fails when the ports overlap", func() { + rg = RouterGroup{ + Name: "router-group-1", + Type: "tcp", + ReservablePorts: "5000-6000", + } + err := rg.Validate() + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("Invalid ports. Reservable ports must not include the following reserved system component ports: [5555 6666 7777].")) + }) + + Context("when failOnRouterPortConflicts is false", func() { + + BeforeEach(func() { + FailOnRouterPortConflicts = false + }) + + It("succeeds when the ports don't overlap", func() { + rg = RouterGroup{ + Name: "router-group-1", + Type: "tcp", + ReservablePorts: "1025-2025", + } + err := rg.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + + It("succeeds when the ports overlap", func() { + rg = RouterGroup{ + Name: "router-group-1", + Type: "tcp", + ReservablePorts: "5000-6000", + } + err := rg.Validate() + Expect(err).ToNot(HaveOccurred()) + }) + }) + }) + }) + }) + }) + + Describe("ReservablePorts", func() { + var ports ReservablePorts + + Describe("Validate", func() { + It("succeeds for valid reservable ports", func() { + ports = "6001,6005,6010-6020,6021-6030" + err := ports.Validate() + Expect(err).NotTo(HaveOccurred()) + }) + + It("fails for overlapping ranges", func() { + ports = "6010-6020,6020-6030" + err := ports.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("Overlapping values: [6010-6020] and [6020-6030]")) + }) + + It("fails for overlapping values", func() { + ports = "6001,6001,6002,6003,6003,6004" + err := ports.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("Overlapping values: 6001 and 6001")) + }) + + It("fails for invalid reservable ports", func() { + ports = "foo!" + err := ports.Validate() + Expect(err).To(HaveOccurred()) + }) + }) + + Describe("Parse", func() { + It("validates a single unsigned integer", func() { + ports = "9999" + r, err := ports.Parse() + Expect(err).NotTo(HaveOccurred()) + + Expect(len(r)).To(Equal(1)) + start, end := r[0].Endpoints() + Expect(start).To(Equal(uint16(9999))) + Expect(end).To(Equal(uint16(9999))) + }) + + It("validates multiple integers", func() { + ports = "9999,1111,2222" + r, err := ports.Parse() + Expect(err).NotTo(HaveOccurred()) + Expect(len(r)).To(Equal(3)) + + expected := []uint16{9999, 1111, 2222} + for i := 0; i < len(r); i++ { + start, end := r[i].Endpoints() + Expect(start).To(Equal(expected[i])) + Expect(end).To(Equal(expected[i])) + } + }) + + It("validates a range", func() { + ports = "10241-10249" + r, err := ports.Parse() + Expect(err).NotTo(HaveOccurred()) + + Expect(len(r)).To(Equal(1)) + start, end := r[0].Endpoints() + Expect(start).To(Equal(uint16(10241))) + Expect(end).To(Equal(uint16(10249))) + }) + + It("validates a list of ranges and integers", func() { + ports = "6001-6010,6020-6022,6045,6050-6060" + r, err := ports.Parse() + Expect(err).NotTo(HaveOccurred()) + + Expect(len(r)).To(Equal(4)) + expected := []uint16{6001, 6010, 6020, 6022, 6045, 6045, 6050, 6060} + for i := 0; i < len(r); i++ { + start, end := r[i].Endpoints() + Expect(start).To(Equal(expected[2*i])) + Expect(end).To(Equal(expected[2*i+1])) + } + }) + + It("errors on range with 3 dashes", func() { + ports = "10-999-1000" + _, err := ports.Parse() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("range (10-999-1000) has too many '-' separators")) + }) + + It("errors on a negative integer", func() { + ports = "-9999" + _, err := ports.Parse() + Expect(err).To(HaveOccurred()) + }) + + It("errors on a incomplete range", func() { + ports = "1030-" + _, err := ports.Parse() + Expect(err).To(HaveOccurred()) + }) + + It("errors on non-numeric input", func() { + ports = "adsfasdf" + _, err := ports.Parse() + Expect(err).To(HaveOccurred()) + }) + + It("errors when range starts with lower number", func() { + ports = "10000-9999" + _, err := ports.Parse() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("range (10000-9999) must be in ascending numeric order")) + }) + }) + }) + + Describe("Range", func() { + Describe("Overlaps", func() { + testRange, _ := NewRange(6010, 6020) + + It("validates non-overlapping ranges", func() { + r, _ := NewRange(6021, 6030) + Expect(testRange.Overlaps(r)).To(BeFalse()) + }) + + It("finds overlapping ranges of single values", func() { + r1, _ := NewRange(6010, 6010) + r2, _ := NewRange(6010, 6010) + Expect(r1.Overlaps(r2)).To(BeTrue()) + }) + + It("finds overlapping ranges of single value and range", func() { + r2, _ := NewRange(6015, 6015) + Expect(testRange.Overlaps(r2)).To(BeTrue()) + }) + + It("finds overlapping ranges of single value upper bound and range", func() { + r2, _ := NewRange(6020, 6020) + Expect(testRange.Overlaps(r2)).To(BeTrue()) + }) + + It("validates single value one above upper bound range", func() { + r2, _ := NewRange(6021, 6021) + Expect(testRange.Overlaps(r2)).To(BeFalse()) + }) + + It("finds overlapping ranges when start overlaps", func() { + r, _ := NewRange(6015, 6030) + Expect(testRange.Overlaps(r)).To(BeTrue()) + }) + + It("finds overlapping ranges when end overlaps", func() { + r, _ := NewRange(6005, 6015) + Expect(testRange.Overlaps(r)).To(BeTrue()) + }) + + It("finds overlapping ranges when the range is a superset", func() { + r, _ := NewRange(6009, 6021) + Expect(testRange.Overlaps(r)).To(BeTrue()) + }) + }) + }) + + Describe("Route", func() { + var ( + route Route + ) + + BeforeEach(func() { + tag, err := NewModificationTag() + Expect(err).ToNot(HaveOccurred()) + route = NewRoute("/foo/bar", 35, "2.2.2.2", "", "banana", 66) + route.ModificationTag = tag + }) + + Describe("SetDefaults", func() { + JustBeforeEach(func() { + route.SetDefaults(120) + }) + + Context("when ttl is nil", func() { + BeforeEach(func() { + route.TTL = nil + }) + + It("sets the default ttl", func() { + Expect(*route.TTL).To(Equal(120)) + }) + }) + + Context("when ttl is not nil", func() { + It("does not change ttl", func() { + Expect(*route.TTL).To(Equal(66)) + }) + }) + }) + }) + + Describe("TcpRouteMapping", func() { + var ( + route TcpRouteMapping + ) + + BeforeEach(func() { + tag, err := NewModificationTag() + Expect(err).ToNot(HaveOccurred()) + route = NewTcpRouteMapping("router-group-1", 60000, "2.2.2.2", 64000, 64001, "instance-id", pointertoString("sni-hostname"), 66, tag) + }) + + Describe("SetDefaults", func() { + JustBeforeEach(func() { + route.SetDefaults(120) + }) + + Context("when ttl is nil", func() { + BeforeEach(func() { + route.TTL = nil + }) + + It("sets default ttl", func() { + Expect(*route.TTL).To(Equal(120)) + }) + }) + + Context("when ttl is not nil", func() { + It("doesn't change ttl", func() { + Expect(*route.TTL).To(Equal(66)) + }) + }) + }) + + Context("multiple annotations", func() { + It("return router group object", func() { + jsonStr := + ` +{ "guid": "some-guid", + "name": "name" +} +` + rg := RouterGroup{} + err := json.Unmarshal([]byte(jsonStr), &rg) + Expect(err).ToNot(HaveOccurred()) + Expect(rg.Guid).To(Equal("some-guid")) + Expect(rg.Name).To(Equal("name")) + }) + }) + }) +}) diff --git a/migration/v7/route.go b/migration/v7/route.go new file mode 100644 index 00000000..c6ff3795 --- /dev/null +++ b/migration/v7/route.go @@ -0,0 +1,91 @@ +package models + +import ( + "time" + + uuid "github.com/nu7hatch/gouuid" +) + +type Route struct { + Model + ExpiresAt time.Time `json:"-"` + RouteEntity +} + +type RouteEntity struct { + Route string `gorm:"not null; unique_index:idx_route" json:"route"` + Port uint16 `gorm:"not null; unique_index:idx_route" json:"port"` + IP string `gorm:"not null; unique_index:idx_route" json:"ip"` + TTL *int `json:"ttl"` + LogGuid string `json:"log_guid"` + RouteServiceUrl string `gorm:"not null; unique_index:idx_route" json:"route_service_url,omitempty"` + ModificationTag `json:"modification_tag"` +} + +func NewRouteWithModel(route Route) (Route, error) { + guid, err := uuid.NewV4() + if err != nil { + return Route{}, err + } + + return Route{ + ExpiresAt: time.Now().Add(time.Duration(*route.TTL) * time.Second), + Model: Model{Guid: guid.String()}, + RouteEntity: route.RouteEntity, + }, nil +} +func NewRoute(url string, port uint16, ip, logGuid, routeServiceUrl string, ttl int) Route { + route := RouteEntity{ + Route: url, + Port: port, + IP: ip, + TTL: &ttl, + LogGuid: logGuid, + RouteServiceUrl: routeServiceUrl, + } + return Route{ + RouteEntity: route, + } +} + +func NewModificationTag() (ModificationTag, error) { + uuid, err := uuid.NewV4() + if err != nil { + return ModificationTag{}, err + } + + return ModificationTag{ + Guid: uuid.String(), + Index: 0, + }, nil +} + +func (t *ModificationTag) Increment() { + t.Index++ +} + +func (m *ModificationTag) SucceededBy(other *ModificationTag) bool { + if m == nil || m.Guid == "" || other.Guid == "" { + return true + } + + return m.Guid != other.Guid || m.Index < other.Index +} + +func (r Route) GetTTL() int { + if r.TTL == nil { + return 0 + } + return *r.TTL +} + +func (r *Route) SetDefaults(defaultTTL int) { + if r.TTL == nil { + r.TTL = &defaultTTL + } +} + +type ModificationTag struct { + Guid string `gorm:"column:modification_guid" json:"guid"` + Index uint32 `gorm:"column:modification_index" json:"index"` +} diff --git a/migration/v7/router_groups.go b/migration/v7/router_groups.go new file mode 100644 index 00000000..4bd9b037 --- /dev/null +++ b/migration/v7/router_groups.go @@ -0,0 +1,284 @@ +package models + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +var InvalidPortError = errors.New("Port must be between 1024 and 65535") + +type RouterGroupType string + +var ReservedSystemComponentPorts = []uint16{} +var FailOnRouterPortConflicts = false + +const ( + RouterGroup_TCP RouterGroupType = "tcp" + RouterGroup_HTTP RouterGroupType = "http" +) + +type RouterGroupsDB []RouterGroupDB + +type RouterGroupDB struct { + Model + Name string + Type string + ReservablePorts string +} + +type RouterGroup struct { + Model + Guid string `json:"guid"` + Name string `json:"name"` + Type RouterGroupType `json:"type"` + ReservablePorts ReservablePorts `json:"reservable_ports" yaml:"reservable_ports"` +} + +func NewRouterGroupDB(routerGroup RouterGroup) RouterGroupDB { + if routerGroup.Model.Guid == "" { + routerGroup.Model = Model{ + Guid: routerGroup.Guid, + } + } + return RouterGroupDB{ + Model: routerGroup.Model, + Name: routerGroup.Name, + Type: string(routerGroup.Type), + ReservablePorts: string(routerGroup.ReservablePorts), + } +} + +func (RouterGroupDB) TableName() string { + return "router_groups" +} + +func (rg *RouterGroupDB) ToRouterGroup() RouterGroup { + return RouterGroup{ + Model: rg.Model, + Guid: rg.Guid, + Name: rg.Name, + Type: RouterGroupType(rg.Type), + ReservablePorts: ReservablePorts(rg.ReservablePorts), + } +} + +func (rgs RouterGroupsDB) ToRouterGroups() RouterGroups { + routerGroups := RouterGroups{} + for _, routerGroupDB := range rgs { + routerGroups = append(routerGroups, routerGroupDB.ToRouterGroup()) + } + return routerGroups +} + +type RouterGroups []RouterGroup + +func (g RouterGroups) Validate() error { + for _, r := range g { + if err := r.Validate(); err != nil { + return err + } + } + return nil +} + +func (g RouterGroup) Validate() error { + if g.Name == "" { + return errors.New("Missing name in router group") + } + + if g.Type == "" { + return errors.New("Missing type in router group") + } + + if g.ReservablePorts == "" { + if g.Type == RouterGroup_TCP { + return fmt.Errorf("Missing reservable_ports in router group: %s", g.Name) + } + + return nil + } + + if g.Type == RouterGroup_HTTP { + return errors.New("Reservable ports are not supported for router groups of type http") + } + + return g.ReservablePorts.Validate() + +} + +type ReservablePorts string + +func (p *ReservablePorts) UnmarshalYAML(unmarshal func(interface{}) error) error { + var input interface{} + + err := unmarshal(&input) + if err != nil { + return err // untested + } + + switch t := input.(type) { + case int: + *p = ReservablePorts(strconv.Itoa(t)) + case string: + *p = ReservablePorts(input.(string)) + case []interface{}: + var s []string + + for _, v := range t { + val, ok := v.(int) + if !ok { + return errors.New("invalid type for reservable port") + } + + s = append(s, strconv.Itoa(val)) + } + + *p = ReservablePorts(strings.Join(s, ",")) + default: + return errors.New("reservable port unmarshal failed") // untested + } + + return nil +} + +func (p ReservablePorts) Validate() error { + portRanges, err := p.Parse() + if err != nil { + return err + } + + // check for overlapping ranges + for i, r1 := range portRanges { + for j, r2 := range portRanges { + if i == j { + continue + } + if r1.Overlaps(r2) { + errMsg := fmt.Sprintf("Overlapping values: %s and %s", r1.String(), r2.String()) + return errors.New(errMsg) + } + } + } + // check if ports overlap with reservedSystemComponentPorts + if FailOnRouterPortConflicts { + for _, r1 := range portRanges { + for _, reservedPort := range ReservedSystemComponentPorts { + + if reservedPort >= r1.start && reservedPort <= r1.end { + errMsg := fmt.Sprintf("Invalid ports. Reservable ports must not include the following reserved system component ports: %v.", ReservedSystemComponentPorts) + return errors.New(errMsg) + } + } + } + } + return nil +} + +func (p ReservablePorts) Parse() (Ranges, error) { + rangesArray := strings.Split(string(p), ",") + var ranges Ranges + + for _, p := range rangesArray { + r, err := parseRange(p) + if err != nil { + return Ranges{}, err + } else { + ranges = append(ranges, r) + } + } + + return ranges, nil +} + +type Range struct { + start uint16 // inclusive + end uint16 // inclusive +} +type Ranges []Range + +func portIsInRange(port uint16) bool { + return port >= 1024 +} + +func NewRange(start, end uint16) (Range, error) { + if portIsInRange(start) && portIsInRange(end) { + return Range{ + start: start, + end: end, + }, nil + } + return Range{}, InvalidPortError +} + +func (r Range) Overlaps(other Range) bool { + maxUpper := r.max(other) + minLower := r.min(other) + // check bounds for both, then see if size of both fit + // For example: 10-20 and 15-30 + // |----10-20----| + // |-------15-30------| + // |==========================| + // minLower: 10 maxUpper: 30 + // (30 - 10) <= (20 - 10) + (30 - 15) + // 20 <= 25? + return uint64(maxUpper-minLower) <= uint64(r.end-r.start)+uint64(other.end-other.start) +} + +func (r Range) String() string { + if r.start == r.end { + return fmt.Sprintf("%d", r.start) + } + return fmt.Sprintf("[%d-%d]", r.start, r.end) +} + +func (r Range) max(other Range) uint16 { + if r.end > other.end { + return r.end + } + return other.end +} + +func (r Range) min(other Range) uint16 { + if r.start < other.start { + return r.start + } + return other.start +} + +func (r Range) Endpoints() (uint16, uint16) { + return r.start, r.end +} + +func parseRange(r string) (Range, error) { + endpoints := strings.Split(r, "-") + + len := len(endpoints) + switch len { + case 1: + n, err := strconv.ParseUint(endpoints[0], 10, 16) + if err != nil { + return Range{}, InvalidPortError + } + return NewRange(uint16(n), uint16(n)) + case 2: + start, err := strconv.ParseUint(endpoints[0], 10, 16) + if err != nil { + return Range{}, fmt.Errorf("range (%s) requires a starting port", r) + } + + end, err := strconv.ParseUint(endpoints[1], 10, 16) + if err != nil { + return Range{}, fmt.Errorf("range (%s) requires an ending port", r) + } + + if start > end { + return Range{}, fmt.Errorf("range (%s) must be in ascending numeric order", r) + } + + return NewRange(uint16(start), uint16(end)) + default: + return Range{}, fmt.Errorf("range (%s) has too many '-' separators", r) + } +} diff --git a/migration/v7/tcp_route.go b/migration/v7/tcp_route.go new file mode 100644 index 00000000..fece5161 --- /dev/null +++ b/migration/v7/tcp_route.go @@ -0,0 +1,121 @@ +package models + +import ( + "fmt" + "time" + + uuid "github.com/nu7hatch/gouuid" +) + +type TcpRouteMapping struct { + Model + ExpiresAt time.Time `json:"-"` + TcpMappingEntity +} + +// IMPORTANT!! when adding a new field here that is part of the unique index for +// +// a tcp route, make sure to update not only the logic for Matches(), +// but also the SqlDb.FindExistingTcpRouteMapping() function's custom +// WHERE filter to include the new field +type TcpMappingEntity struct { + RouterGroupGuid string `gorm:"not null; unique_index:idx_tcp_route" json:"router_group_guid"` + HostPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int" json:"backend_port"` + HostTLSPort int `gorm:"default:null; unique_index:idx_tcp_route; type:int" json:"backend_tls_port"` + HostIP string `gorm:"not null; unique_index:idx_tcp_route" json:"backend_ip"` + SniHostname *string `gorm:"default:null; unique_index:idx_tcp_route" json:"backend_sni_hostname,omitempty"` + // We don't add uniqueness on InstanceId so that if a route is attempted to be created with the same detals but + // different InstanceId, we fail uniqueness and prevent stale/duplicate routes. If this fails a route, the + // TTL on the old record should expire + allow the new route to be created eventually. + InstanceId string `gorm:"null; default:null;" json:"instance_id"` + ExternalPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type: int" json:"port"` + ModificationTag `json:"modification_tag"` + TTL *int `json:"ttl,omitempty"` + IsolationSegment string `json:"isolation_segment"` +} + +func (TcpRouteMapping) TableName() string { + return "tcp_routes" +} + +func NewTcpRouteMappingWithModel(tcpMapping TcpRouteMapping) (TcpRouteMapping, error) { + guid, err := uuid.NewV4() + if err != nil { + return TcpRouteMapping{}, err + } + + m := Model{Guid: guid.String()} + return TcpRouteMapping{ + ExpiresAt: time.Now().Add(time.Duration(*tcpMapping.TTL) * time.Second), + Model: m, + TcpMappingEntity: tcpMapping.TcpMappingEntity, + }, nil +} + +func NewTcpRouteMapping( + routerGroupGuid string, + externalPort uint16, + hostIP string, + hostPort uint16, + hostTlsPort int, + instanceId string, + sniHostname *string, + ttl int, + modTag ModificationTag, +) TcpRouteMapping { + mapping := TcpRouteMapping{ + TcpMappingEntity: TcpMappingEntity{ + RouterGroupGuid: routerGroupGuid, + ExternalPort: externalPort, + SniHostname: sniHostname, + InstanceId: instanceId, + HostPort: hostPort, + HostTLSPort: hostTlsPort, + HostIP: hostIP, + TTL: &ttl, + ModificationTag: modTag, + }, + } + return mapping +} + +func (m TcpRouteMapping) String() string { + return fmt.Sprintf("%s:%d<->%s:%d", m.RouterGroupGuid, m.ExternalPort, m.HostIP, m.HostPort) +} + +func (m TcpRouteMapping) Matches(other TcpRouteMapping) bool { + sameRouterGroupGuid := m.RouterGroupGuid == other.RouterGroupGuid + sameExternalPort := m.ExternalPort == other.ExternalPort + sameHostIP := m.HostIP == other.HostIP + sameHostPort := m.HostPort == other.HostPort + sameInstanceId := m.InstanceId == other.InstanceId + sameHostTLSPort := m.HostTLSPort == other.HostTLSPort + + nilTTL := m.TTL == nil && other.TTL == nil + sameTTLPointer := m.TTL == other.TTL + sameTTLValue := m.TTL != nil && other.TTL != nil && *m.TTL == *other.TTL + sameTTL := nilTTL || sameTTLPointer || sameTTLValue + + nilSniHostname := m.SniHostname == nil && other.SniHostname == nil + sameSniHostnamePointer := m.SniHostname == other.SniHostname + sameSniHostnameValue := m.SniHostname != nil && other.SniHostname != nil && *m.SniHostname == *other.SniHostname + sameSniHostname := nilSniHostname || sameSniHostnamePointer || sameSniHostnameValue + + return sameRouterGroupGuid && + sameExternalPort && + sameHostIP && + sameHostPort && + sameInstanceId && + sameTTL && + sameHostTLSPort && + sameSniHostname +} + +func (t *TcpRouteMapping) SetDefaults(maxTTL int) { + // default ttl if not present + // TTL is a pointer to a uint16 so that we can + // detect if it's present or not (i.e. nil or 0) + if t.TTL == nil { + t.TTL = &maxTTL + } +} diff --git a/migration/v7/tcp_route_test.go b/migration/v7/tcp_route_test.go new file mode 100644 index 00000000..d6d03e61 --- /dev/null +++ b/migration/v7/tcp_route_test.go @@ -0,0 +1,110 @@ +package models_test + +import ( + "encoding/json" + + "code.cloudfoundry.org/routing-api/models" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func pointertoString(s string) *string { return &s } + +var _ = Describe("TCP Route", func() { + Describe("TcpMappingEntity", func() { + var tcpRouteMapping models.TcpRouteMapping + var sniHostNamePtr *string + + JustBeforeEach(func() { + tcpRouteMapping = models.NewTcpRouteMapping("a-guid", 1234, "hostIp", 5678, 8765, "", sniHostNamePtr, 5, models.ModificationTag{}) + }) + Describe("SNI Hostname", func() { + Context("when the SNI hostname is nil", func() { + BeforeEach(func() { + sniHostNamePtr = nil + }) + It("comes through as nil", func() { + Expect(tcpRouteMapping.SniHostname).To(BeNil()) + }) + It("is omitted from JSON marshaling", func() { + j, err := json.Marshal(tcpRouteMapping) + Expect(err).NotTo(HaveOccurred()) + Expect(string(j)).NotTo(ContainSubstring("backend_sni_hostname")) + }) + }) + + Context("when a valid SNI hostname is provided", func() { + BeforeEach(func() { + sniHostNamePtr = pointertoString("sniHostname") + }) + + It("Accepts the value", func() { + Expect(*tcpRouteMapping.SniHostname).To(Equal("sniHostname")) + }) + It("is provided in the marshaled JSON", func() { + j, err := json.Marshal(tcpRouteMapping) + Expect(err).NotTo(HaveOccurred()) + Expect(string(j)).To(ContainSubstring("backend_sni_hostname")) + }) + }) + Context("when the SNI hostname is empty", func() { + BeforeEach(func() { + sniHostNamePtr = pointertoString("") + }) + It("is provided in the marshaled JSON", func() { + j, err := json.Marshal(tcpRouteMapping) + Expect(err).NotTo(HaveOccurred()) + Expect(string(j)).To(ContainSubstring("backend_sni_hostname")) + }) + }) + }) + Describe("Matches()", func() { + var tcpRouteMapping2 models.TcpRouteMapping + var sniHostNamePtr2 *string + + BeforeEach(func() { + sniHostNamePtr = pointertoString("sniHostName") + }) + + JustBeforeEach(func() { + tcpRouteMapping2 = models.NewTcpRouteMapping("a-guid", 1234, "hostIp", 5678, 8765, "", sniHostNamePtr2, 5, models.ModificationTag{}) + }) + + Context("when two routes have the same SNIHostName value", func() { + BeforeEach(func() { + sniHostNamePtr2 = sniHostNamePtr + }) + It("matches", func() { + Expect(tcpRouteMapping.Matches(tcpRouteMapping2)).To(BeTrue()) + }) + }) + Context("when two routes have equal values", func() { + BeforeEach(func() { + sniHostNamePtr2 = pointertoString("sniHostName") + }) + It("matches", func() { + Expect(tcpRouteMapping.Matches(tcpRouteMapping2)).To(BeTrue()) + }) + }) + + Context("when two routes have values that are not equal", func() { + BeforeEach(func() { + sniHostNamePtr2 = pointertoString("sniHostName2") + }) + It("doesn't match", func() { + Expect(tcpRouteMapping.Matches(tcpRouteMapping2)).To(BeFalse()) + }) + }) + Context("when one of the routes has a nil SNIHostName", func() { + BeforeEach(func() { + sniHostNamePtr2 = nil + }) + It("doesn't match", func() { + Expect(tcpRouteMapping.Matches(tcpRouteMapping2)).To(BeFalse()) + Expect(tcpRouteMapping2.Matches(tcpRouteMapping)).To(BeFalse()) + }) + }) + }) + }) +})