diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ac4049f56..e052f082fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the access logs in Traffic Ops to now show the route ID with every API endpoint call. The Route ID is appended to the end of the access log line. - With the addition of multiple server interfaces, interface data is constructed from IP Address/Gateway/Netmask (and their IPv6 counterparts) and Interface Name and Interface MTU fields on services. These **MUST** have proper, valid data before attempting to upgrade or the upgrade **WILL** fail. In particular IP fields need to be valid IP addresses/netmasks, and MTU must only be positive integers of at least 1280. - The `/servers` and `/servers/{{ID}}}` API endpoints have been updated to use and reflect multi-interface servers. +- CDN Snapshots now use a server's "service addresses" to provide its IP addresses ### Deprecated - Deprecated the non-nullable `DeliveryService` Go struct and other structs that use it. `DeliveryServiceNullable` structs should be used instead. diff --git a/traffic_ops/traffic_ops_golang/ats/db.go b/traffic_ops/traffic_ops_golang/ats/db.go index 95062ef632..6e22b9d9b9 100644 --- a/traffic_ops/traffic_ops_golang/ats/db.go +++ b/traffic_ops/traffic_ops_golang/ats/db.go @@ -612,12 +612,21 @@ func GetServerInfo(tx *sql.Tx, qry string, qryParams []interface{}) (*atscfg.Ser return nil, false, errors.New("querying server info: " + err.Error()) } - infs, err := dbhelpers.GetServerInterfaces(s.ID, tx) + infs, err := dbhelpers.GetServersInterfaces([]int{s.ID}, tx) if err != nil { return nil, false, fmt.Errorf("querying server info interfaces: %v", err) } - legacyInfo, err := tc.InterfaceInfoToLegacyInterfaces(infs) + interfaces, ok := infs[s.ID] + if !ok || len(interfaces) < 1 { + return nil, false, fmt.Errorf("server #%d has no interfaces", s.ID) + } + + ifaces := make([]tc.ServerInterfaceInfo, 0, len(interfaces)) + for _, inf := range interfaces { + ifaces = append(ifaces, inf) + } + legacyInfo, err := tc.InterfaceInfoToLegacyInterfaces(ifaces) if err != nil { return nil, false, fmt.Errorf("converting server info interfaces to legacy: %v", err) } diff --git a/traffic_ops/traffic_ops_golang/crconfig/servers.go b/traffic_ops/traffic_ops_golang/crconfig/servers.go index 11e1959865..9c7c6447b2 100644 --- a/traffic_ops/traffic_ops_golang/crconfig/servers.go +++ b/traffic_ops/traffic_ops_golang/crconfig/servers.go @@ -23,12 +23,13 @@ import ( "database/sql" "errors" "fmt" - "github.com/apache/trafficcontrol/lib/go-util" "strconv" "strings" "github.com/apache/trafficcontrol/lib/go-log" "github.com/apache/trafficcontrol/lib/go-tc" + + "github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers" ) const RouterTypeName = "CCR" @@ -99,12 +100,15 @@ type ServerUnion struct { SecureAPIPort *string } +type ServerAndHost struct { + Server ServerUnion + Host string +} + const DefaultWeightMultiplier = 1000.0 const DefaultWeight = 0.999 func getAllServers(cdn string, tx *sql.Tx) (map[string]ServerUnion, error) { - servers := map[string]ServerUnion{} - serverParams, err := getServerParams(cdn, tx) if err != nil { return nil, errors.New("Error getting server params: " + err.Error()) @@ -112,87 +116,76 @@ func getAllServers(cdn string, tx *sql.Tx) (map[string]ServerUnion, error) { // TODO select deliveryservices as array? q := ` -select s.host_name, - cg.name as cachegroup, - concat(s.host_name, '.', s.domain_name) as fqdn, - s.xmpp_id as hashid, - s.https_port, - s.interface_name, - s.ip_address_is_service, - s.ip6_address_is_service, - s.ip_address, - s.ip6_address, - s.tcp_port, - p.name as profile_name, - cast(p.routing_disabled as int), - st.name as status, - t.name as type -from server as s -inner join cachegroup as cg ON cg.id = s.cachegroup -inner join type as t on t.id = s.type -inner join profile as p ON p.id = s.profile -inner join status as st ON st.id = s.status -where cdn_id = (select id from cdn where name = $1) -and (st.name = 'REPORTED' or st.name = 'ONLINE' or st.name = 'ADMIN_DOWN') -` + SELECT + s.id, + s.host_name, + cg.name as cachegroup, + concat(s.host_name, '.', s.domain_name) AS fqdn, + s.xmpp_id AS hashid, + s.https_port, + s.tcp_port, + p.name AS profile_name, + cast(p.routing_disabled AS int), + st.name AS status, + t.name AS type + FROM server AS s + INNER JOIN cachegroup AS cg ON cg.id = s.cachegroup + INNER JOIN type AS t on t.id = s.type + INNER JOIN profile AS p ON p.id = s.profile + INNER JOIN status AS st ON st.id = s.status + WHERE cdn_id = (SELECT id FROM cdn WHERE name = $1) + AND (st.name = 'REPORTED' OR st.name = 'ONLINE' OR st.name = 'ADMIN_DOWN') + ` rows, err := tx.Query(q, cdn) if err != nil { return nil, errors.New("Error querying servers: " + err.Error()) } defer rows.Close() + servers := map[int]ServerAndHost{} + ids := []int{} for rows.Next() { - port := sql.NullInt64{} - ip6 := sql.NullString{} - hashId := sql.NullString{} - httpsPort := sql.NullInt64{} + var port sql.NullInt64 + var hashId sql.NullString + var httpsPort sql.NullInt64 - ipIsService := false - ip6IsService := false + var s ServerAndHost - s := ServerUnion{} - - host := "" - status := "" - if err := rows.Scan(&host, &s.CacheGroup, &s.Fqdn, &hashId, &httpsPort, &s.InterfaceName, &ipIsService, &ip6IsService, &s.Ip, &ip6, &port, &s.Profile, &s.RoutingDisabled, &status, &s.ServerType); err != nil { + var status string + var id int + if err := rows.Scan(&id, &s.Host, &s.Server.CacheGroup, &s.Server.Fqdn, &hashId, &httpsPort, &port, &s.Server.Profile, &s.Server.RoutingDisabled, &status, &s.Server.ServerType); err != nil { return nil, errors.New("Error scanning server: " + err.Error()) } - if !ipIsService { - s.Ip = util.StrPtr("") - } - if !ip6IsService { - s.Ip6 = util.StrPtr("") - } else { - s.Ip6 = &ip6.String // Don't check valid, assign empty string if null - } - s.LocationId = s.CacheGroup + ids = append(ids, id) + + s.Server.LocationId = s.Server.CacheGroup serverStatus := tc.CRConfigServerStatus(status) - s.ServerStatus = &serverStatus + s.Server.ServerStatus = &serverStatus if port.Valid { i := int(port.Int64) - s.Port = &i + s.Server.Port = &i } if hashId.String != "" { - s.HashId = &hashId.String + s.Server.HashId = &hashId.String } else { - s.HashId = &host + s.Server.HashId = &s.Host } if httpsPort.Valid { i := int(httpsPort.Int64) - s.HttpsPort = &i + s.Server.HttpsPort = &i } - params, hasParams := serverParams[host] + params, hasParams := serverParams[s.Host] if hasParams && params.APIPort != nil { - s.APIPort = params.APIPort + s.Server.APIPort = params.APIPort } if hasParams && params.SecureAPIPort != nil { - s.SecureAPIPort = params.SecureAPIPort + s.Server.SecureAPIPort = params.SecureAPIPort } weightMultiplier := DefaultWeightMultiplier @@ -204,15 +197,61 @@ and (st.name = 'REPORTED' or st.name = 'ONLINE' or st.name = 'ADMIN_DOWN') weight = *params.Weight } hashCount := int(weight * weightMultiplier) - s.HashCount = &hashCount + s.Server.HashCount = &hashCount - servers[host] = s + servers[id] = s } if err := rows.Err(); err != nil { return nil, errors.New("Error iterating router param rows: " + err.Error()) } - return servers, nil + interfaces, err := dbhelpers.GetServersInterfaces(ids, tx) + if err != nil { + return nil, fmt.Errorf("getting interfaces for servers: %v", err) + } + + hostToServerMap := make(map[string]ServerUnion, len(servers)) + for id, server := range servers { + ifaces, ok := interfaces[id] + if !ok { + log.Warnf("server '%s' (#%d) has no interfaces", server.Host, id) + server.Server.InterfaceName = new(string) + server.Server.Ip = new(string) + server.Server.Ip6 = new(string) + hostToServerMap[server.Host] = server.Server + continue + } + + infs := make([]tc.ServerInterfaceInfo, 0, len(ifaces)) + for _, inf := range ifaces { + infs = append(infs, inf) + } + + legacyNet, err := tc.InterfaceInfoToLegacyInterfaces(infs) + if err != nil { + return nil, fmt.Errorf("Error converting interfaces to legacy data for server '%s' (#%d): %v", server.Host, id, err) + } + + server.Server.Ip = legacyNet.IPAddress + server.Server.Ip6 = legacyNet.IP6Address + + if server.Server.Ip == nil { + server.Server.Ip = new(string) + } + if server.Server.Ip6 == nil { + server.Server.Ip6 = new(string) + } + + server.Server.InterfaceName = legacyNet.InterfaceName + if server.Server.InterfaceName == nil { + server.Server.InterfaceName = new(string) + log.Warnf("Server %s (#%d) had no service-address-containing interfaces", server.Host, id) + } + + hostToServerMap[server.Host] = server.Server + } + + return hostToServerMap, nil } func getServerDSNames(cdn string, tx *sql.Tx) (map[tc.CacheName][]tc.DeliveryServiceName, error) { diff --git a/traffic_ops/traffic_ops_golang/crconfig/servers_test.go b/traffic_ops/traffic_ops_golang/crconfig/servers_test.go index 04e8e970dd..a83b4053ac 100644 --- a/traffic_ops/traffic_ops_golang/crconfig/servers_test.go +++ b/traffic_ops/traffic_ops_golang/crconfig/servers_test.go @@ -21,7 +21,9 @@ package crconfig import ( "context" + "fmt" "math/rand" + "net" "reflect" "testing" "time" @@ -58,18 +60,62 @@ func randFloat64() *float64 { return &f } -func randServer() tc.CRConfigTrafficOpsServer { +func randomIPv4() *string { + first := rand.Int31n(256) + second := rand.Int31n(256) + third := rand.Int31n(256) + fourth := rand.Int31n(256) + str := fmt.Sprintf("%d.%d.%d.%d", first, second, third, fourth) + return &str +} + +func randomIPv6() *string { + ip := net.IP([]byte{ + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + uint8(rand.Int31n(256)), + }).String() + return &ip +} + +func randServer(ipService bool, ip6Service bool) tc.CRConfigTrafficOpsServer { status := tc.CRConfigServerStatus(*randStr()) cachegroup := randStr() + ip := new(string) + ip6 := new(string) + inf := new(string) + + if ipService { + ip = randomIPv4() + inf = randStr() + } + if ip6Service { + ip6 = randomIPv6() + inf = randStr() + } + return tc.CRConfigTrafficOpsServer{ CacheGroup: cachegroup, Fqdn: randStr(), HashCount: randInt(), HashId: randStr(), HttpsPort: randInt(), - InterfaceName: randStr(), - Ip: randStr(), - Ip6: randStr(), + InterfaceName: inf, + Ip: ip, + Ip6: ip6, LocationId: cachegroup, Port: randInt(), Profile: randStr(), @@ -155,7 +201,7 @@ func ExpectedGetAllServers(params map[string]ServerParams, ipIsService bool, ip6 s := ServerUnion{ APIPort: param.APIPort, SecureAPIPort: param.SecureAPIPort, - CRConfigTrafficOpsServer: randServer(), + CRConfigTrafficOpsServer: randServer(ipIsService, ip6IsService), } i := int(*param.Weight * *param.WeightMultiplier) s.HashCount = &i @@ -171,11 +217,187 @@ func ExpectedGetAllServers(params map[string]ServerParams, ipIsService bool, ip6 } func MockGetAllServers(mock sqlmock.Sqlmock, expected map[string]ServerUnion, cdn string, ipIsService bool, ip6IsService bool) { - rows := sqlmock.NewRows([]string{"host_name", "cachegroup", "fqdn", "hashid", "https_port", "interface_name", "ip_address_is_service", "ip6_address_is_service", "ip_address", "ip6_address", "tcp_port", "profile_name", "routing_disabled", "status", "type"}) + serverRows := sqlmock.NewRows([]string{"id", "host_name", "cachegroup", "fqdn", "hashid", "https_port", "tcp_port", "profile_name", "routing_disabled", "status", "type"}) + interfaceRows := sqlmock.NewRows([]string{"max_bandwidth", "monitor", "mtu", "name", "server"}) + ipRows := sqlmock.NewRows([]string{"address", "gateway", "service_address", "interface", "server"}) + i := 1 for name, s := range expected { - rows = rows.AddRow(name, *s.CacheGroup, *s.Fqdn, *s.HashId, *s.HttpsPort, *s.InterfaceName, ipIsService, ip6IsService, *s.Ip, *s.Ip6, *s.Port, *s.Profile, s.RoutingDisabled, *s.ServerStatus, *s.ServerType) + serverRows = serverRows.AddRow(i, name, *s.CacheGroup, *s.Fqdn, *s.HashId, *s.HttpsPort, *s.Port, *s.Profile, s.RoutingDisabled, *s.ServerStatus, *s.ServerType) + if s.InterfaceName == nil { + i++ + continue + } + interfaceRows = interfaceRows.AddRow(nil, true, nil, *s.InterfaceName, i) + + if s.Ip != nil { + ipRows = ipRows.AddRow(*s.Ip, nil, ipIsService, *s.InterfaceName, i) + } + if s.Ip6 != nil { + ipRows = ipRows.AddRow(*s.Ip6, nil, ip6IsService, *s.InterfaceName, i) + } + i++ + } + mock.ExpectQuery("SELECT").WithArgs(cdn).WillReturnRows(serverRows) + mock.ExpectQuery("SELECT").WillReturnRows(interfaceRows) + mock.ExpectQuery("SELECT").WillReturnRows(ipRows) +} + +func compare(expected map[string]ServerUnion, actual map[string]ServerUnion, t *testing.T) { + for name, server := range expected { + actualServer, ok := actual[name] + if !ok { + t.Errorf("getAllServers expected: %v, actual: missing", name) + continue + } + + if actualServer.APIPort == nil && server.APIPort != nil { + t.Errorf("expected server '%s' to have APIPort '%s', actual: ", name, *server.APIPort) + } else if server.APIPort == nil && actualServer.APIPort != nil { + t.Errorf("expected server '%s' to have nil APIPort, actual: '%s'", name, *actualServer.APIPort) + } else if (server.APIPort != nil || actualServer.APIPort != nil) && *server.APIPort != *actualServer.APIPort { + t.Errorf("expected server '%s' to have APIPort '%s', actual: '%s'", name, *server.APIPort, *actualServer.APIPort) + } + + if actualServer.SecureAPIPort == nil && server.SecureAPIPort != nil { + t.Errorf("expected server '%s' to have SecureAPIPort '%s', actual: ", name, *server.SecureAPIPort) + } else if server.SecureAPIPort == nil && actualServer.SecureAPIPort != nil { + t.Errorf("expected server '%s' to have nil SecureAPIPort, actual: '%s'", name, *actualServer.SecureAPIPort) + } else if (server.SecureAPIPort != nil || actualServer.SecureAPIPort != nil) && *server.SecureAPIPort != *actualServer.SecureAPIPort { + t.Errorf("expected server '%s' to have SecureAPIPort '%s', actual: '%s'", name, *server.SecureAPIPort, *actualServer.SecureAPIPort) + } + + if actualServer.CacheGroup == nil && server.CacheGroup != nil { + t.Errorf("expected server '%s' to have CacheGroup '%s', actual: ", name, *server.CacheGroup) + } else if server.CacheGroup == nil && actualServer.CacheGroup != nil { + t.Errorf("expected server '%s' to have nil CacheGroup, actual: '%s'", name, *actualServer.CacheGroup) + } else if (server.CacheGroup != nil || actualServer.CacheGroup != nil) && *server.CacheGroup != *actualServer.CacheGroup { + t.Errorf("expected server '%s' to have CacheGroup '%s', actual: '%s'", name, *server.CacheGroup, *actualServer.CacheGroup) + } + + if actualServer.Fqdn == nil && server.Fqdn != nil { + t.Errorf("expected server '%s' to have Fqdn '%s', actual: ", name, *server.Fqdn) + } else if server.Fqdn == nil && actualServer.Fqdn != nil { + t.Errorf("expected server '%s' to have nil Fqdn, actual: '%s'", name, *actualServer.Fqdn) + } else if (server.Fqdn != nil || actualServer.Fqdn != nil) && *server.Fqdn != *actualServer.Fqdn { + t.Errorf("expected server '%s' to have Fqdn '%s', actual: '%s'", name, *server.Fqdn, *actualServer.Fqdn) + } + + if actualServer.HashCount == nil && server.HashCount != nil { + t.Errorf("expected server '%s' to have HashCount '%v', actual: ", name, *server.HashCount) + } else if server.HashCount == nil && actualServer.HashCount != nil { + t.Errorf("expected server '%s' to have nil HashCount, actual: '%v'", name, *actualServer.HashCount) + } else if (server.HashCount != nil || actualServer.HashCount != nil) && *server.HashCount != *actualServer.HashCount { + t.Errorf("expected server '%s' to have HashCount '%v', actual: '%v'", name, *server.HashCount, *actualServer.HashCount) + } + + if actualServer.HashId == nil && server.HashId != nil { + t.Errorf("expected server '%s' to have HashId '%v', actual: ", name, *server.HashId) + } else if server.HashId == nil && actualServer.HashId != nil { + t.Errorf("expected server '%s' to have nil HashId, actual: '%v'", name, *actualServer.HashId) + } else if (server.HashId != nil || actualServer.HashId != nil) && *server.HashId != *actualServer.HashId { + t.Errorf("expected server '%s' to have HashId '%v', actual: '%v'", name, *server.HashId, *actualServer.HashId) + } + + if actualServer.HttpsPort == nil && server.HttpsPort != nil { + t.Errorf("expected server '%s' to have HttpsPort '%v', actual: ", name, *server.HttpsPort) + } else if server.HttpsPort == nil && actualServer.HttpsPort != nil { + t.Errorf("expected server '%s' to have nil HttpsPort, actual: '%v'", name, *actualServer.HttpsPort) + } else if (server.HttpsPort != nil || actualServer.HttpsPort != nil) && *server.HttpsPort != *actualServer.HttpsPort { + t.Errorf("expected server '%s' to have HttpsPort '%v', actual: '%v'", name, *server.HttpsPort, *actualServer.HttpsPort) + } + + if actualServer.InterfaceName == nil && server.InterfaceName != nil { + t.Errorf("expected server '%s' to have InterfaceName '%v', actual: ", name, *server.InterfaceName) + } else if server.InterfaceName == nil && actualServer.InterfaceName != nil { + t.Errorf("expected server '%s' to have nil InterfaceName, actual: '%v'", name, *actualServer.InterfaceName) + } else if (server.InterfaceName != nil || actualServer.InterfaceName != nil) && *server.InterfaceName != *actualServer.InterfaceName { + t.Errorf("expected server '%s' to have InterfaceName '%v', actual: '%v'", name, *server.InterfaceName, *actualServer.InterfaceName) + } + + if actualServer.Ip == nil && server.Ip != nil { + t.Errorf("expected server '%s' to have Ip '%v', actual: ", name, *server.Ip) + } else if server.Ip == nil && actualServer.Ip != nil { + t.Errorf("expected server '%s' to have nil Ip, actual: '%v'", name, *actualServer.Ip) + } else if (server.Ip != nil || actualServer.Ip != nil) && *server.Ip != *actualServer.Ip { + t.Errorf("expected server '%s' to have Ip '%v', actual: '%v'", name, *server.Ip, *actualServer.Ip) + } + + if actualServer.Ip6 == nil && server.Ip6 != nil { + t.Errorf("expected server '%s' to have Ip6 '%v', actual: ", name, *server.Ip6) + } else if server.Ip6 == nil && actualServer.Ip6 != nil { + t.Errorf("expected server '%s' to have nil Ip6, actual: '%v'", name, *actualServer.Ip6) + } else if (server.Ip6 != nil || actualServer.Ip6 != nil) && *server.Ip6 != *actualServer.Ip6 { + t.Errorf("expected server '%s' to have Ip6 '%v', actual: '%v'", name, *server.Ip6, *actualServer.Ip6) + } + + if actualServer.LocationId == nil && server.LocationId != nil { + t.Errorf("expected server '%s' to have LocationId '%v', actual: ", name, *server.LocationId) + } else if server.LocationId == nil && actualServer.LocationId != nil { + t.Errorf("expected server '%s' to have nil LocationId, actual: '%v'", name, *actualServer.LocationId) + } else if (server.LocationId != nil || actualServer.LocationId != nil) && *server.LocationId != *actualServer.LocationId { + t.Errorf("expected server '%s' to have LocationId '%v', actual: '%v'", name, *server.LocationId, *actualServer.LocationId) + } + + if actualServer.Port == nil && server.Port != nil { + t.Errorf("expected server '%s' to have Port '%v', actual: ", name, *server.Port) + } else if server.Port == nil && actualServer.Port != nil { + t.Errorf("expected server '%s' to have nil Port, actual: '%v'", name, *actualServer.Port) + } else if (server.Port != nil || actualServer.Port != nil) && *server.Port != *actualServer.Port { + t.Errorf("expected server '%s' to have Port '%v', actual: '%v'", name, *server.Port, *actualServer.Port) + } + + if actualServer.Profile == nil && server.Profile != nil { + t.Errorf("expected server '%s' to have Profile '%v', actual: ", name, *server.Profile) + } else if server.Profile == nil && actualServer.Profile != nil { + t.Errorf("expected server '%s' to have nil Profile, actual: '%v'", name, *actualServer.Profile) + } else if (server.Profile != nil || actualServer.Profile != nil) && *server.Profile != *actualServer.Profile { + t.Errorf("expected server '%s' to have Profile '%v', actual: '%v'", name, *server.Profile, *actualServer.Profile) + } + + if actualServer.ServerStatus == nil && server.ServerStatus != nil { + t.Errorf("expected server '%s' to have ServerStatus '%v', actual: ", name, *server.ServerStatus) + } else if server.ServerStatus == nil && actualServer.ServerStatus != nil { + t.Errorf("expected server '%s' to have nil ServerStatus, actual: '%v'", name, *actualServer.ServerStatus) + } else if (server.ServerStatus != nil || actualServer.ServerStatus != nil) && *server.ServerStatus != *actualServer.ServerStatus { + t.Errorf("expected server '%s' to have ServerStatus '%v', actual: '%v'", name, *server.ServerStatus, *actualServer.ServerStatus) + } + + if actualServer.ServerType == nil && server.ServerType != nil { + t.Errorf("expected server '%s' to have ServerType '%v', actual: ", name, *server.ServerType) + } else if server.ServerType == nil && actualServer.ServerType != nil { + t.Errorf("expected server '%s' to have nil ServerType, actual: '%v'", name, *actualServer.ServerType) + } else if (server.ServerType != nil || actualServer.ServerType != nil) && *server.ServerType != *actualServer.ServerType { + t.Errorf("expected server '%s' to have ServerType '%v', actual: '%v'", name, *server.ServerType, *actualServer.ServerType) + } + + if actualServer.RoutingDisabled != server.RoutingDisabled { + t.Errorf("expected server '%s' to have RoutingDisabled '%d', actual: '%d'", name, server.RoutingDisabled, actualServer.RoutingDisabled) + } + + if len(actualServer.DeliveryServices) != len(server.DeliveryServices) { + t.Errorf("expected server '%s' to have %d DeliveryServices, actual: %d", name, len(server.DeliveryServices), len(actualServer.DeliveryServices)) + continue + } + + for dsName, dses := range server.DeliveryServices { + actualDSes, ok := actualServer.DeliveryServices[dsName] + if !ok { + t.Errorf("expected Delivery Service '%s' to be in server '%s', but it wasn't", dsName, name) + continue + } + + if len(dses) != len(actualDSes) { + t.Errorf("expected Delivery Service '%s' in server '%s' to have %d entries, actual: %d", dsName, name, len(dses), len(actualDSes)) + continue + } + + for i, ds := range dses { + if ds != actualDSes[i] { + t.Errorf("expected the %dth entry in Delivery Service '%s' in server '%s' to be '%s', actual: '%s'", i, dsName, name, ds, actualDSes[i]) + } + } + } } - mock.ExpectQuery("select").WithArgs(cdn).WillReturnRows(rows) } func TestGetAllServers(t *testing.T) { @@ -211,17 +433,7 @@ func TestGetAllServers(t *testing.T) { if len(actual) != len(expected) { t.Errorf("getAllServers len expected: %v, actual: %v", len(expected), len(actual)) } - - for name, server := range expected { - actualServer, ok := actual[name] - if !ok { - t.Errorf("getAllServers expected: %v, actual: missing", name) - continue - } - if !reflect.DeepEqual(server, actualServer) { - t.Errorf("getAllServers server %v expected: %v, actual: %v", name, server, actualServer) - } - } + compare(expected, actual, t) } func TestGetAllServersNonService(t *testing.T) { @@ -258,16 +470,7 @@ func TestGetAllServersNonService(t *testing.T) { t.Errorf("getAllServers len expected: %v, actual: %v", len(expected), len(actual)) } - for name, server := range expected { - actualServer, ok := actual[name] - if !ok { - t.Errorf("getAllServers expected: %v, actual: missing", name) - continue - } - if !reflect.DeepEqual(server, actualServer) { - t.Errorf("getAllServers server %v expected: %v, actual: %v", name, server, actualServer) - } - } + compare(expected, actual, t) } func ExpectedGetServerDSNames() map[tc.CacheName][]tc.DeliveryServiceName { diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go index cdb46f0048..f1a0f55f03 100644 --- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go +++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go @@ -519,47 +519,73 @@ WHERE s.id = $1 return row, true, nil } -// GetServerInterfaces, given the ID of a server, returns all of its network -// interfaces, or an error if one occurs during retrieval. -func GetServerInterfaces(id int, tx *sql.Tx) ([]tc.ServerInterfaceInfo, error) { +// GetServerInterfaces, given the IDs of one or more servers, returns all of their network +// interfaces mapped by their ids, or an error if one occurs during retrieval. +func GetServersInterfaces(ids []int, tx *sql.Tx) (map[int]map[string]tc.ServerInterfaceInfo, error) { q := ` - SELECT ( - json_build_object ( - 'ipAddresses', - ARRAY ( - SELECT ( - json_build_object ( - 'address', ip_address.address, - 'gateway', ip_address.gateway, - 'service_address', ip_address.service_address - ) - ) - FROM ip_address - WHERE ip_address.interface = interface.name - AND ip_address.server = $1 - ), - 'max_bandwidth', interface.max_bandwidth, - 'monitor', interface.monitor, - 'mtu', interface.mtu, - 'name', interface.name - ) - ) + SELECT max_bandwidth, + monitor, + mtu, + name, + server FROM interface - WHERE interface.server = $1 + WHERE interface.server = ANY ($1) ` - rows, err := tx.Query(q, id) + ifaceRows, err := tx.Query(q, pq.Array(ids)) if err != nil { return nil, err } - defer rows.Close() + defer ifaceRows.Close() - infs := []tc.ServerInterfaceInfo{} - for rows.Next() { + infs := map[int]map[string]tc.ServerInterfaceInfo{} + for ifaceRows.Next() { var inf tc.ServerInterfaceInfo - if err = rows.Scan(&inf); err != nil { + var server int + if err := ifaceRows.Scan(&inf.MaxBandwidth, &inf.Monitor, &inf.MTU, &inf.Name, &server); err != nil { + return nil, err + } + + if _, ok := infs[server]; !ok { + infs[server] = make(map[string]tc.ServerInterfaceInfo) + } + + infs[server][inf.Name] = inf + } + + q = ` + SELECT address, + gateway, + service_address, + interface, + server + FROM ip_address + WHERE ip_address.server = ANY ($1) + ` + ipRows, err := tx.Query(q, pq.Array(ids)) + if err != nil { + return nil, err + } + defer ipRows.Close() + + for ipRows.Next() { + var ip tc.ServerIPAddress + var inf string + var server int + if err = ipRows.Scan(&ip.Address, &ip.Gateway, &ip.ServiceAddress, &inf, &server); err != nil { return nil, err } - infs = append(infs, inf) + + ifaces, ok := infs[server] + if !ok { + return nil, fmt.Errorf("retrieved ip_address with server not previously found: %d", server) + } + + iface, ok := ifaces[inf] + if !ok { + return nil, fmt.Errorf("retrieved ip_address with interface not previously found: %s", inf) + } + iface.IPAddresses = append(iface.IPAddresses, ip) + infs[server][inf] = iface } return infs, nil