Skip to content

Commit

Permalink
Fix API (deliveryserviceserver and deliveryservices/dsName/servers) s…
Browse files Browse the repository at this point in the history
…hould not assign Server from different CDN to Delivery Service (#4754)

* Fix API (deliveryserviceserver and deliveryservices/dsName/servers) should not assign Server from different CDN to Delivery Service

* Formatting

* Code review
  • Loading branch information
srijeet0406 authored Jun 11, 2020
1 parent 1ed80c6 commit 74bc7f2
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 23 deletions.
24 changes: 14 additions & 10 deletions traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,15 +644,18 @@ func GetServerNameFromID(tx *sql.Tx, id int) (string, bool, error) {
return name, true, nil
}

type ServerHostNameAndType struct {
type ServerHostNameCDNIDAndType struct {
HostName string
CDNID int
Type string
}

func GetServerHostNamesAndTypesFromIDs(tx *sql.Tx, ids []int) ([]ServerHostNameAndType, error) {
// GetServerHostNamesAndTypesFromIDs returns the server's hostname, cdn ID and associated type name
func GetServerHostNamesAndTypesFromIDs(tx *sql.Tx, ids []int) ([]ServerHostNameCDNIDAndType, error) {
qry := `
SELECT
s.host_name,
s.cdn_id,
t.name
FROM
server s JOIN type t ON s.type = t.id
Expand All @@ -665,22 +668,23 @@ WHERE
}
defer log.Close(rows, "error closing rows")

servers := []ServerHostNameAndType{}
servers := []ServerHostNameCDNIDAndType{}
for rows.Next() {
s := ServerHostNameAndType{}
if err := rows.Scan(&s.HostName, &s.Type); err != nil {
s := ServerHostNameCDNIDAndType{}
if err := rows.Scan(&s.HostName, &s.CDNID, &s.Type); err != nil {
return nil, errors.New("scanning server host name and type: " + err.Error())
}
servers = append(servers, s)
}
return servers, nil
}

// GetServerTypesFromHostNames returns the host names and types of the given server host names or an error if any occur.
func GetServerTypesFromHostNames(tx *sql.Tx, hostNames []string) ([]ServerHostNameAndType, error) {
// GetServerTypesCdnIdFromHostNames returns the host names, server cdn and types of the given server host names or an error if any occur.
func GetServerTypesCdnIdFromHostNames(tx *sql.Tx, hostNames []string) ([]ServerHostNameCDNIDAndType, error) {
qry := `
SELECT
s.host_name,
s.cdn_id,
t.name
FROM
server s JOIN type t ON s.type = t.id
Expand All @@ -693,10 +697,10 @@ WHERE
}
defer log.Close(rows, "error closing rows")

servers := []ServerHostNameAndType{}
servers := []ServerHostNameCDNIDAndType{}
for rows.Next() {
s := ServerHostNameAndType{}
if err := rows.Scan(&s.HostName, &s.Type); err != nil {
s := ServerHostNameCDNIDAndType{}
if err := rows.Scan(&s.HostName, &s.CDNID, &s.Type); err != nil {
return nil, errors.New("scanning server host name and type: " + err.Error())
}
servers = append(servers, s)
Expand Down
33 changes: 20 additions & 13 deletions traffic_ops/traffic_ops_golang/deliveryservice/servers/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,19 +319,18 @@ func GetReplaceHandler(w http.ResponseWriter, r *http.Request) {
api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
return
}
serverNamesAndTypes, err := dbhelpers.GetServerHostNamesAndTypesFromIDs(inf.Tx.Tx, servers)
serverNamesCdnIdAndTypes, err := dbhelpers.GetServerHostNamesAndTypesFromIDs(inf.Tx.Tx, servers)
if err != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, err, nil)
return
}

userErr = ValidateDSSAssignments(ds, serverNamesAndTypes)
userErr = ValidateDSSAssignments(ds, serverNamesCdnIdAndTypes)
if userErr != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, userErr, nil)
return
}

usrErr, sysErr, status := ValidateServerCapabilities(ds.ID, serverNamesAndTypes, inf.Tx.Tx)
usrErr, sysErr, status := ValidateServerCapabilities(ds.ID, serverNamesCdnIdAndTypes, inf.Tx.Tx)
if usrErr != nil || sysErr != nil {
api.HandleErr(w, r, inf.Tx.Tx, status, usrErr, sysErr)
return
Expand Down Expand Up @@ -401,19 +400,19 @@ func GetCreateHandler(w http.ResponseWriter, r *http.Request) {
payload.XmlId = dsName
serverNames := payload.ServerNames

serverNamesAndTypes, err := dbhelpers.GetServerTypesFromHostNames(inf.Tx.Tx, serverNames)
serverNamesCdnIdAndTypes, err := dbhelpers.GetServerTypesCdnIdFromHostNames(inf.Tx.Tx, serverNames)
if err != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, err, nil)
return
}

userErr = ValidateDSSAssignments(ds, serverNamesAndTypes)
userErr = ValidateDSSAssignments(ds, serverNamesCdnIdAndTypes)
if userErr != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, userErr, nil)
return
}

usrErr, sysErr, status := ValidateServerCapabilities(ds.ID, serverNamesAndTypes, inf.Tx.Tx)
usrErr, sysErr, status := ValidateServerCapabilities(ds.ID, serverNamesCdnIdAndTypes, inf.Tx.Tx)
if usrErr != nil || sysErr != nil {
api.HandleErr(w, r, inf.Tx.Tx, status, usrErr, sysErr)
return
Expand Down Expand Up @@ -445,8 +444,13 @@ func GetCreateHandler(w http.ResponseWriter, r *http.Request) {
}

// ValidateDSSAssignments returns an error if the given servers cannot be assigned to the given delivery service.
func ValidateDSSAssignments(ds DSInfo, servers []dbhelpers.ServerHostNameAndType) error {
func ValidateDSSAssignments(ds DSInfo, servers []dbhelpers.ServerHostNameCDNIDAndType) error {
if ds.Topology == nil {
for _, s := range servers {
if ds.CDNID != nil && s.CDNID != *ds.CDNID {
return errors.New("server and delivery service CDNs do not match")
}
}
return nil
}
for _, s := range servers {
Expand All @@ -458,7 +462,7 @@ func ValidateDSSAssignments(ds DSInfo, servers []dbhelpers.ServerHostNameAndType
}

// ValidateServerCapabilities checks that the delivery service's requirements are met by each server to be assigned.
func ValidateServerCapabilities(dsID int, serverNamesAndTypes []dbhelpers.ServerHostNameAndType, tx *sql.Tx) (error, error, int) {
func ValidateServerCapabilities(dsID int, serverNamesAndTypes []dbhelpers.ServerHostNameCDNIDAndType, tx *sql.Tx) (error, error, int) {
nonOriginServerNames := []string{}
for _, s := range serverNamesAndTypes {
if strings.HasPrefix(s.Type, tc.EdgeTypePrefix) {
Expand Down Expand Up @@ -682,6 +686,7 @@ type DSInfo struct {
CacheURL *string
MaxOriginConnections *int
Topology *string
CDNID *int
}

// GetDSInfo loads the DeliveryService fields needed by Delivery Service Servers from the database, from the ID. Returns the data, whether the delivery service was found, and any error.
Expand All @@ -696,15 +701,16 @@ SELECT
ds.signing_algorithm,
ds.cacheurl,
ds.max_origin_connections,
ds.topology
ds.topology,
ds.cdn_id
FROM
deliveryservice ds
JOIN type tp ON ds.type = tp.id
WHERE
ds.id = $1
`
di := DSInfo{ID: id}
if err := tx.QueryRow(qry, id).Scan(&di.Name, &di.Type, &di.EdgeHeaderRewrite, &di.MidHeaderRewrite, &di.RegexRemap, &di.SigningAlgorithm, &di.CacheURL, &di.MaxOriginConnections, &di.Topology); err != nil {
if err := tx.QueryRow(qry, id).Scan(&di.Name, &di.Type, &di.EdgeHeaderRewrite, &di.MidHeaderRewrite, &di.RegexRemap, &di.SigningAlgorithm, &di.CacheURL, &di.MaxOriginConnections, &di.Topology, &di.CDNID); err != nil {
if err == sql.ErrNoRows {
return DSInfo{}, false, nil
}
Expand All @@ -726,15 +732,16 @@ SELECT
ds.signing_algorithm,
ds.cacheurl,
ds.max_origin_connections,
ds.topology
ds.topology,
ds.cdn_id
FROM
deliveryservice ds
JOIN type tp ON ds.type = tp.id
WHERE
ds.xml_id = $1
`
di := DSInfo{Name: dsName}
if err := tx.QueryRow(qry, dsName).Scan(&di.ID, &di.Type, &di.EdgeHeaderRewrite, &di.MidHeaderRewrite, &di.RegexRemap, &di.SigningAlgorithm, &di.CacheURL, &di.MaxOriginConnections, &di.Topology); err != nil {
if err := tx.QueryRow(qry, dsName).Scan(&di.ID, &di.Type, &di.EdgeHeaderRewrite, &di.MidHeaderRewrite, &di.RegexRemap, &di.SigningAlgorithm, &di.CacheURL, &di.MaxOriginConnections, &di.Topology, &di.CDNID); err != nil {
if err == sql.ErrNoRows {
return DSInfo{}, false, nil
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package servers

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

import (
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers"
"testing"
)

func TestValidateDSSAssignments(t *testing.T) {
expected := `server and delivery service CDNs do not match`
cdnID := 1
ds := DSInfo{
ID: 0,
CDNID: &cdnID,
}
var servers []dbhelpers.ServerHostNameCDNIDAndType
server := dbhelpers.ServerHostNameCDNIDAndType{
HostName: "serverHost",
CDNID: 0,
Type: "",
}
servers = append(servers, server)
userErr := ValidateDSSAssignments(ds, servers)
if userErr == nil {
t.Fatalf("Expected user error with mismatching ds and server CDN IDs, got no error instead")
}
if userErr.Error() != expected {
t.Errorf("Expected error details %v, got %v", expected, userErr.Error())
}
servers[0].CDNID = 1
userErr = ValidateDSSAssignments(ds, servers)
if userErr != nil {
t.Fatalf("Expected no user error, got %v", userErr.Error())
}
}

0 comments on commit 74bc7f2

Please sign in to comment.