diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index ec8affbf175..ec0ec020f23 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -196,10 +196,7 @@ func registerFaultHandlers( agentState *v4.TMDSAgentState, metricsFactory metrics.EntryFactory, ) { - handler := fault.FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } + handler := fault.New(agentState, metricsFactory) if muxRouter == nil { return diff --git a/agent/handlers/v4/tmdsstate.go b/agent/handlers/v4/tmdsstate.go index 4daefdbb766..d033512164b 100644 --- a/agent/handlers/v4/tmdsstate.go +++ b/agent/handlers/v4/tmdsstate.go @@ -161,7 +161,12 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool) Path: task.GetNetworkNamespace(), NetworkInterfaces: []*tmdsv4.NetworkInterface{ { - DeviceName: "", + // TODO: fetch the correct device name. + // We are exposing this information via AgentState to facilitate the fault injection + // handler to start/stop/check network faults. + // Use 'eth0'(a fake value) for existing fault injection related unit tests for now and + // it will be updated in the future. + DeviceName: "eth0", }, }, }, diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index d550dafbd20..438d0cdbc74 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net/http" + "sync" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -42,18 +43,35 @@ const ( ) type FaultHandler struct { - // TODO: Mutex will be used in a future PR - // mu sync.Mutex + // mutexMap is used to avoid multiple clients to manipulate same resource at same + // time. The 'key' is the the network namespace path and 'value' is the RWMutex. + // Using concurrent map here because the handler is shared by all requests. + mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory } +func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { + return &FaultHandler{ + AgentState: agentState, + MetricsFactory: mf, + mutexMap: sync.Map{}, + } +} + // NetworkFaultPath will take in a fault type and return the TMDS endpoint path func NetworkFaultPath(fault string) string { return fmt.Sprintf("/api/%s/fault/v1/%s", utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) } +// loadLock returns the lock associated with given key. +func (h *FaultHandler) loadLock(key string) *sync.RWMutex { + mu := new(sync.RWMutex) + actualMu, _ := h.mutexMap.LoadOrStore(key, mu) + return actualMu.(*sync.RWMutex) +} + // StartNetworkBlackholePort will return the request handler function for starting a network blackhole port fault func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -72,12 +90,17 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -118,12 +141,17 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -164,14 +192,18 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } - // Check status of current fault injection + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") logger.Info("Successfully checked status for fault", logger.Fields{ @@ -206,12 +238,17 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -248,12 +285,17 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -290,12 +332,17 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -331,12 +378,17 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -373,12 +425,17 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -416,11 +473,17 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. // Obtain the task metadata via the endpoint container ID // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -613,9 +676,19 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return errors.New("empty network namespaces within task network config") } + // Task network namespace path is required to inject faults in the associated task. + if taskNetworkConfig.NetworkNamespaces[0].Path == "" { + return errors.New("no path in the network namespace within task network config") + } + if len(taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces) == 0 || taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces == nil { return errors.New("empty network interfaces within task network config") } + // Device name is required to inject network faults to given ENI in the task. + if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName == "" { + return errors.New("no ENI device name in the network namespace within task network config") + } + return nil } diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go index 124c2f961c7..2e1b9b8e468 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go @@ -56,8 +56,6 @@ type NetworkNamespace struct { type NetworkInterface struct { // DeviceName is the device name on the host. DeviceName string - // ENIID is the id of eni. - ENIID string } // Instance's clock drift status diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index d550dafbd20..438d0cdbc74 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net/http" + "sync" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -42,18 +43,35 @@ const ( ) type FaultHandler struct { - // TODO: Mutex will be used in a future PR - // mu sync.Mutex + // mutexMap is used to avoid multiple clients to manipulate same resource at same + // time. The 'key' is the the network namespace path and 'value' is the RWMutex. + // Using concurrent map here because the handler is shared by all requests. + mutexMap sync.Map AgentState state.AgentState MetricsFactory metrics.EntryFactory } +func New(agentState state.AgentState, mf metrics.EntryFactory) *FaultHandler { + return &FaultHandler{ + AgentState: agentState, + MetricsFactory: mf, + mutexMap: sync.Map{}, + } +} + // NetworkFaultPath will take in a fault type and return the TMDS endpoint path func NetworkFaultPath(fault string) string { return fmt.Sprintf("/api/%s/fault/v1/%s", utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx), fault) } +// loadLock returns the lock associated with given key. +func (h *FaultHandler) loadLock(key string) *sync.RWMutex { + mu := new(sync.RWMutex) + actualMu, _ := h.mutexMap.LoadOrStore(key, mu) + return actualMu.(*sync.RWMutex) +} + // StartNetworkBlackholePort will return the request handler function for starting a network blackhole port fault func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { @@ -72,12 +90,17 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -118,12 +141,17 @@ func (h *FaultHandler) StopNetworkBlackHolePort() func(http.ResponseWriter, *htt }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -164,14 +192,18 @@ func (h *FaultHandler) CheckNetworkBlackHolePort() func(http.ResponseWriter, *ht }) // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } - // Check status of current fault injection + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") logger.Info("Successfully checked status for fault", logger.Fields{ @@ -206,12 +238,17 @@ func (h *FaultHandler) StartNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -248,12 +285,17 @@ func (h *FaultHandler) StopNetworkLatency() func(http.ResponseWriter, *http.Requ } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -290,12 +332,17 @@ func (h *FaultHandler) CheckNetworkLatency() func(http.ResponseWriter, *http.Req } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -331,12 +378,17 @@ func (h *FaultHandler) StartNetworkPacketLoss() func(http.ResponseWriter, *http. } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the start fault injection functionality if not running @@ -373,12 +425,17 @@ func (h *FaultHandler) StopNetworkPacketLoss() func(http.ResponseWriter, *http.R } // Obtain the task metadata via the endpoint container ID - // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.Lock() + defer rwMu.Unlock() + // TODO: Check status of current fault injection // TODO: Invoke the stop fault injection functionality if running @@ -416,11 +473,17 @@ func (h *FaultHandler) CheckNetworkPacketLoss() func(http.ResponseWriter, *http. // Obtain the task metadata via the endpoint container ID // TODO: Will be using the returned task metadata in a future PR - _, err = validateTaskMetadata(w, h.AgentState, requestType, r) + taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r) if err != nil { return } + // To avoid multiple requests to manipulate same network resource + networkNSPath := taskMetadata.TaskNetworkConfig.NetworkNamespaces[0].Path + rwMu := h.loadLock(networkNSPath) + rwMu.RLock() + defer rwMu.RUnlock() + // TODO: Check status of current fault injection // TODO: Return the correct status state responseBody := types.NewNetworkFaultInjectionSuccessResponse("running") @@ -613,9 +676,19 @@ func validateTaskNetworkConfig(taskNetworkConfig *state.TaskNetworkConfig) error return errors.New("empty network namespaces within task network config") } + // Task network namespace path is required to inject faults in the associated task. + if taskNetworkConfig.NetworkNamespaces[0].Path == "" { + return errors.New("no path in the network namespace within task network config") + } + if len(taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces) == 0 || taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces == nil { return errors.New("empty network interfaces within task network config") } + // Device name is required to inject network faults to given ENI in the task. + if taskNetworkConfig.NetworkNamespaces[0].NetworkInterfaces[0].DeviceName == "" { + return errors.New("no ENI device name in the network namespace within task network config") + } + return nil } diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index 7b01b8acd91..c810a48b94c 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -53,6 +53,12 @@ const ( ) var ( + noDeviceNameInNetworkInterfaces = []*state.NetworkInterface{ + { + DeviceName: "", + }, + } + happyNetworkInterfaces = []*state.NetworkInterface{ { DeviceName: deviceName, @@ -66,6 +72,13 @@ var ( }, } + noPathInNetworkNamespaces = []*state.NetworkNamespace{ + { + Path: "", + NetworkInterfaces: happyNetworkInterfaces, + }, + } + happyTaskNetworkConfig = state.TaskNetworkConfig{ NetworkMode: awsvpcNetworkMode, NetworkNamespaces: happyNetworkNamespaces, @@ -114,7 +127,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -128,7 +143,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -201,7 +218,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")). + Times(1) }, }, { @@ -211,7 +230,8 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( - "Unable to generate metadata for task")) + "Unable to generate metadata for task")). + Times(1) }, }, { @@ -220,7 +240,9 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, errors.New("unknown error")). + Times(1) }, }, { @@ -232,7 +254,7 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: false, - }, nil) + }, nil).Times(1) }, }, { @@ -248,20 +270,77 @@ func generateNetworkBlackHolePortTestCases(name string, expectedHappyResponseBod NetworkMode: invalidNetworkMode, NetworkNamespaces: happyNetworkNamespaces, }, - }, nil) + }, nil).Times(1) }, }, { - name: fmt.Sprintf("%s empty task network config", name), - expectedStatusCode: 500, - requestBody: happyBlackHolePortReqBody, - expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + name: fmt.Sprintf("%s empty task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, FaultInjectionEnabled: true, TaskNetworkConfig: nil, - }, nil) + }, nil).Times(1) + }, + }, + { + name: fmt.Sprintf("%s no task network namespace", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: nil, + }, + }, nil).Times(1) + }, + }, + { + name: fmt.Sprintf("%s no path in task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: noPathInNetworkNamespaces, + }, + }, nil).Times(1) + }, + }, + { + name: fmt.Sprintf("%s no device name in task network config", name), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse( + fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + FaultInjectionEnabled: true, + TaskNetworkConfig: &state.TaskNetworkConfig{ + NetworkMode: awsvpcNetworkMode, + NetworkNamespaces: []*state.NetworkNamespace{ + &state.NetworkNamespace{ + Path: "/path", + NetworkInterfaces: noDeviceNameInNetworkInterfaces, + }, + }, + }, + }, nil).Times(1) }, }, } @@ -285,12 +364,7 @@ func TestStartNetworkBlackHolePort(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.BlackHolePortFaultType), handler.StartNetworkBlackholePort(), @@ -338,12 +412,7 @@ func TestStopNetworkBlackHolePort(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.BlackHolePortFaultType), handler.StopNetworkBlackHolePort(), @@ -388,11 +457,7 @@ func TestCheckNetworkBlackHolePort(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } + handler := New(agentState, metricsFactory) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) @@ -442,7 +507,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -456,7 +523,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse(expectedHappyResponseBody), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(happyTaskResponse, nil) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(happyTaskResponse, nil). + Times(1) }, }, { @@ -469,7 +538,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.DelayMilliseconds of type uint64"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -482,7 +553,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.JitterMilliseconds of type uint64"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -495,7 +568,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("json: cannot unmarshal string into Go struct field NetworkLatencyRequest.Sources of type []*string"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -508,7 +583,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -520,7 +597,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter Sources is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -532,7 +611,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter JitterMilliseconds is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -544,7 +625,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("required parameter DelayMilliseconds is missing"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -557,7 +640,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 10.1.2.3.4 for parameter Sources"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -570,7 +655,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n }, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value 52.95.154.0/33 for parameter Sources"), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, nil).Times(0) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, nil). + Times(0) }, }, { @@ -579,7 +666,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to lookup container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, state.NewErrorLookupFailure("task lookup failed")). + Times(1) }, }, { @@ -588,8 +677,10 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("unable to obtain container metadata for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( - "Unable to generate metadata for task")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, state.NewErrorMetadataFetchFailure( + "Unable to generate metadata for task")). + Times(1) }, }, { @@ -598,7 +689,9 @@ func generateNetworkLatencyTestCases(name, expectedHappyResponseBody string) []n requestBody: happyNetworkLatencyReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(fmt.Sprintf("failed to get task metadata due to internal server error for container: %s", endpointId)), setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { - agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, errors.New("unknown error")) + agentState.EXPECT().GetTaskMetadata(endpointId). + Return(state.TaskResponse{}, errors.New("unknown error")). + Times(1) }, }, { @@ -663,12 +756,7 @@ func TestStartNetworkLatency(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.LatencyFaultType), handler.StartNetworkLatency(), @@ -716,12 +804,7 @@ func TestStopNetworkLatency(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.LatencyFaultType), handler.StopNetworkLatency(), @@ -766,12 +849,7 @@ func TestCheckNetworkLatency(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) } @@ -1043,12 +1121,7 @@ func TestStartNetworkPacketLoss(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.PacketLossFaultType), handler.StartNetworkPacketLoss(), @@ -1096,12 +1169,7 @@ func TestStopNetworkPacketLoss(t *testing.T) { } router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) router.HandleFunc( NetworkFaultPath(types.PacketLossFaultType), handler.StopNetworkPacketLoss(), @@ -1146,12 +1214,7 @@ func TestCheckNetworkPacketLoss(t *testing.T) { metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) router := mux.NewRouter() - - handler := FaultHandler{ - AgentState: agentState, - MetricsFactory: metricsFactory, - } - + handler := New(agentState, metricsFactory) if tc.setAgentStateExpectations != nil { tc.setAgentStateExpectations(agentState) } diff --git a/ecs-agent/tmds/handlers/v4/handlers_test.go b/ecs-agent/tmds/handlers/v4/handlers_test.go index 55d7a5012ec..388170f7244 100644 --- a/ecs-agent/tmds/handlers/v4/handlers_test.go +++ b/ecs-agent/tmds/handlers/v4/handlers_test.go @@ -167,7 +167,6 @@ func taskResponse() *state.TaskResponse { NetworkInterfaces: []*state.NetworkInterface{ &state.NetworkInterface{ DeviceName: "eth1", - ENIID: "eni-013ff4ad5747a0f6a", }, }, }, diff --git a/ecs-agent/tmds/handlers/v4/state/response.go b/ecs-agent/tmds/handlers/v4/state/response.go index 124c2f961c7..2e1b9b8e468 100644 --- a/ecs-agent/tmds/handlers/v4/state/response.go +++ b/ecs-agent/tmds/handlers/v4/state/response.go @@ -56,8 +56,6 @@ type NetworkNamespace struct { type NetworkInterface struct { // DeviceName is the device name on the host. DeviceName string - // ENIID is the id of eni. - ENIID string } // Instance's clock drift status