diff --git a/agent/app/agent.go b/agent/app/agent.go index 42d3f210e88..502c8b15275 100644 --- a/agent/app/agent.go +++ b/agent/app/agent.go @@ -15,6 +15,7 @@ package app import ( "context" + "encoding/json" "errors" "fmt" "time" @@ -626,13 +627,32 @@ func (agent *ecsAgent) startSpotInstanceDrainingPoller(client api.ECSClient) { } } -// spotInstanceDrainingPoller returns true if spot instance termination time has been +// spotInstanceDrainingPoller returns true if spot instance interruption has been // set AND the container instance state is successfully updated to DRAINING. func (agent *ecsAgent) spotInstanceDrainingPoller(client api.ECSClient) bool { - // this endpoint 404s unless a termination time has been set, so expect failure in most cases. - termtime, err := agent.ec2MetadataClient.SpotTerminationTime() - if err == nil && len(termtime) > 0 { - seelog.Infof("Received a spot termination time (%s), setting state to DRAINING", termtime) + // this endpoint 404s unless an interruption has been set, so expect failure in most cases. 
+ resp, err := agent.ec2MetadataClient.SpotInstanceAction() + if err == nil { + type InstanceAction struct { + Time string + Action string + } + ia := InstanceAction{} + + err := json.Unmarshal([]byte(resp), &ia) + if err != nil { + seelog.Errorf("Invalid response from /spot/instance-action endpoint: %s Error: %s", resp, err) + return false + } + + switch ia.Action { + case "hibernate", "terminate", "stop": + default: + seelog.Errorf("Invalid response from /spot/instance-action endpoint: %s, Error: unrecognized action (%s)", resp, ia.Action) + return false + } + + seelog.Infof("Received a spot interruption (%s) scheduled for %s, setting state to DRAINING", ia.Action, ia.Time) err = client.UpdateContainerInstancesState(agent.containerInstanceARN, "DRAINING") if err != nil { seelog.Errorf("Error setting instance [ARN: %s] state to DRAINING: %s", agent.containerInstanceARN, err) diff --git a/agent/app/agent_test.go b/agent/app/agent_test.go index f57b73a8ecc..7ee35e3ea9e 100644 --- a/agent/app/agent_test.go +++ b/agent/app/agent_test.go @@ -1182,7 +1182,14 @@ func TestGetHostPublicIPv4AddressFromEC2MetadataFailWithError(t *testing.T) { assert.Empty(t, agent.getHostPublicIPv4AddressFromEC2Metadata()) } -func TestSpotTerminationTimeCheck_Yes(t *testing.T) { +func TestSpotInstanceActionCheck_Sunny(t *testing.T) { + tests := []struct { + jsonresp string + }{ + {jsonresp: `{"action": "terminate", "time": "2017-09-18T08:22:00Z"}`}, + {jsonresp: `{"action": "hibernate", "time": "2017-09-18T08:22:00Z"}`}, + {jsonresp: `{"action": "stop", "time": "2017-09-18T08:22:00Z"}`}, + } ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -1190,19 +1197,28 @@ func TestSpotTerminationTimeCheck_Yes(t *testing.T) { ec2Client := mock_ec2.NewMockClient(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - myARN := "myARN" - agent := &ecsAgent{ - ec2MetadataClient: ec2MetadataClient, - ec2Client: ec2Client, - containerInstanceARN: myARN, + for _, test := range tests { + myARN := "myARN" + 
agent := &ecsAgent{ + ec2MetadataClient: ec2MetadataClient, + ec2Client: ec2Client, + containerInstanceARN: myARN, + } + ec2MetadataClient.EXPECT().SpotInstanceAction().Return(test.jsonresp, nil) + ecsClient.EXPECT().UpdateContainerInstancesState(myARN, "DRAINING").Return(nil) + + assert.True(t, agent.spotInstanceDrainingPoller(ecsClient)) } - ec2MetadataClient.EXPECT().SpotTerminationTime().Return("2019-08-26T18:21:08Z", nil) - ecsClient.EXPECT().UpdateContainerInstancesState(myARN, "DRAINING").Return(nil) - - assert.True(t, agent.spotInstanceDrainingPoller(ecsClient)) } -func TestSpotTerminationTimeCheck_EmptyTimestamp(t *testing.T) { +func TestSpotInstanceActionCheck_Fail(t *testing.T) { + tests := []struct { + jsonresp string + }{ + {jsonresp: `{"action": "terminate" "time": "2017-09-18T08:22:00Z"}`}, // invalid json + {jsonresp: ``}, // empty json + {jsonresp: `{"action": "flip!", "time": "2017-09-18T08:22:00Z"}`}, // invalid action + } ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -1210,20 +1226,22 @@ func TestSpotTerminationTimeCheck_EmptyTimestamp(t *testing.T) { ec2Client := mock_ec2.NewMockClient(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) - myARN := "myARN" - agent := &ecsAgent{ - ec2MetadataClient: ec2MetadataClient, - ec2Client: ec2Client, - containerInstanceARN: myARN, + for _, test := range tests { + myARN := "myARN" + agent := &ecsAgent{ + ec2MetadataClient: ec2MetadataClient, + ec2Client: ec2Client, + containerInstanceARN: myARN, + } + ec2MetadataClient.EXPECT().SpotInstanceAction().Return(test.jsonresp, nil) + // Container state should NOT be updated because the instance-action response is invalid. + ecsClient.EXPECT().UpdateContainerInstancesState(gomock.Any(), gomock.Any()).Times(0) + + assert.False(t, agent.spotInstanceDrainingPoller(ecsClient)) } - ec2MetadataClient.EXPECT().SpotTerminationTime().Return("", nil) - // Container state should NOT be updated because the termination time field is empty. 
- ecsClient.EXPECT().UpdateContainerInstancesState(gomock.Any(), gomock.Any()).Times(0) - - assert.False(t, agent.spotInstanceDrainingPoller(ecsClient)) } -func TestSpotTerminationTimeCheck_No(t *testing.T) { +func TestSpotInstanceActionCheck_NoInstanceActionYet(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -1237,7 +1255,7 @@ func TestSpotTerminationTimeCheck_No(t *testing.T) { ec2Client: ec2Client, containerInstanceARN: myARN, } - ec2MetadataClient.EXPECT().SpotTerminationTime().Return("", fmt.Errorf("404")) + ec2MetadataClient.EXPECT().SpotInstanceAction().Return("", fmt.Errorf("404")) // Container state should NOT be updated because there is no termination time. ecsClient.EXPECT().UpdateContainerInstancesState(gomock.Any(), gomock.Any()).Times(0) diff --git a/agent/ec2/blackhole_ec2_metadata_client.go b/agent/ec2/blackhole_ec2_metadata_client.go index a44d6d57615..05c51523981 100644 --- a/agent/ec2/blackhole_ec2_metadata_client.go +++ b/agent/ec2/blackhole_ec2_metadata_client.go @@ -77,6 +77,6 @@ func (blackholeMetadataClient) PublicIPv4Address() (string, error) { return "", errors.New("blackholed") } -func (blackholeMetadataClient) SpotTerminationTime() (string, error) { +func (blackholeMetadataClient) SpotInstanceAction() (string, error) { return "", errors.New("blackholed") } diff --git a/agent/ec2/ec2_metadata_client.go b/agent/ec2/ec2_metadata_client.go index 6b429d07342..5e37d610a4d 100644 --- a/agent/ec2/ec2_metadata_client.go +++ b/agent/ec2/ec2_metadata_client.go @@ -33,7 +33,7 @@ const ( AllMacResource = "network/interfaces/macs" VPCIDResourceFormat = "network/interfaces/macs/%s/vpc-id" SubnetIDResourceFormat = "network/interfaces/macs/%s/subnet-id" - SpotTerminationTimeResource = "spot/termination-time" + SpotInstanceActionResource = "spot/instance-action" InstanceIDResource = "instance-id" PrivateIPv4Resource = "local-ipv4" PublicIPv4Resource = "public-ipv4" @@ -77,7 +77,7 @@ type EC2MetadataClient interface { Region() 
(string, error) PrivateIPv4Address() (string, error) PublicIPv4Address() (string, error) - SpotTerminationTime() (string, error) + SpotInstanceAction() (string, error) } type ec2MetadataClientImpl struct { @@ -187,9 +187,10 @@ func (c *ec2MetadataClientImpl) PrivateIPv4Address() (string, error) { return c.client.GetMetadata(PrivateIPv4Resource) } -// SpotTerminationTime returns the spot termination time, if it has been set. -// If the time has not been set (ie, the instance is not scheduled for termination) +// SpotInstanceAction returns the spot instance-action, if it has been set. +// If the action has not been set (ie, the instance is not scheduled for interruption) // then this function returns an error. -func (c *ec2MetadataClientImpl) SpotTerminationTime() (string, error) { - return c.client.GetMetadata(SpotTerminationTimeResource) +// see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-interruptions.html#using-spot-instances-managing-interruptions +func (c *ec2MetadataClientImpl) SpotInstanceAction() (string, error) { + return c.client.GetMetadata(SpotInstanceActionResource) } diff --git a/agent/ec2/ec2_metadata_client_test.go b/agent/ec2/ec2_metadata_client_test.go index ad808adf845..1983540a6e4 100644 --- a/agent/ec2/ec2_metadata_client_test.go +++ b/agent/ec2/ec2_metadata_client_test.go @@ -225,7 +225,7 @@ func TestPublicIPv4Address(t *testing.T) { assert.Equal(t, publicIP, publicIPResponse) } -func TestSpotTerminationTime(t *testing.T) { +func TestSpotInstanceAction(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -233,13 +233,13 @@ func TestSpotTerminationTime(t *testing.T) { testClient := ec2.NewEC2MetadataClient(mockGetter) mockGetter.EXPECT().GetMetadata( - ec2.SpotTerminationTimeResource).Return("2019-08-26T17:54:20Z", nil) - resp, err := testClient.SpotTerminationTime() + ec2.SpotInstanceActionResource).Return("{\"action\": \"terminate\", \"time\": \"2017-09-18T08:22:00Z\"}", nil) + resp, err := 
testClient.SpotInstanceAction() assert.NoError(t, err) - assert.Equal(t, "2019-08-26T17:54:20Z", resp) + assert.Equal(t, "{\"action\": \"terminate\", \"time\": \"2017-09-18T08:22:00Z\"}", resp) } -func TestSpotTerminationTimeError(t *testing.T) { +func TestSpotInstanceActionError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -247,8 +247,8 @@ func TestSpotTerminationTimeError(t *testing.T) { testClient := ec2.NewEC2MetadataClient(mockGetter) mockGetter.EXPECT().GetMetadata( - ec2.SpotTerminationTimeResource).Return("", fmt.Errorf("ERROR")) - resp, err := testClient.SpotTerminationTime() + ec2.SpotInstanceActionResource).Return("", fmt.Errorf("ERROR")) + resp, err := testClient.SpotInstanceAction() assert.Error(t, err) assert.Equal(t, "", resp) } diff --git a/agent/ec2/mocks/ec2_mocks.go b/agent/ec2/mocks/ec2_mocks.go index 6385bca3e15..8ffa9430d16 100644 --- a/agent/ec2/mocks/ec2_mocks.go +++ b/agent/ec2/mocks/ec2_mocks.go @@ -216,19 +216,19 @@ func (mr *MockEC2MetadataClientMockRecorder) Region() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Region", reflect.TypeOf((*MockEC2MetadataClient)(nil).Region)) } -// SpotTerminationTime mocks base method -func (m *MockEC2MetadataClient) SpotTerminationTime() (string, error) { +// SpotInstanceAction mocks base method +func (m *MockEC2MetadataClient) SpotInstanceAction() (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SpotTerminationTime") + ret := m.ctrl.Call(m, "SpotInstanceAction") ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } -// SpotTerminationTime indicates an expected call of SpotTerminationTime -func (mr *MockEC2MetadataClientMockRecorder) SpotTerminationTime() *gomock.Call { +// SpotInstanceAction indicates an expected call of SpotInstanceAction +func (mr *MockEC2MetadataClientMockRecorder) SpotInstanceAction() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpotTerminationTime", 
reflect.TypeOf((*MockEC2MetadataClient)(nil).SpotTerminationTime)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpotInstanceAction", reflect.TypeOf((*MockEC2MetadataClient)(nil).SpotInstanceAction)) } // SubnetID mocks base method