diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go index 978304d57cf..8b5ed661e18 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go @@ -294,8 +294,22 @@ func (client *ecsClient) setInstanceIdentity( registerRequest.InstanceIdentityDocument = &instanceIdentityDoc if iidRetrieved { - instanceIdentitySignature, err = client.ec2metadata. - GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) + ctx, cancel = context.WithTimeout(context.Background(), setInstanceIdRetryTimeOut) + defer cancel() + err = retry.RetryWithBackoffCtx(ctx, backoff, func() error { + var attemptErr error + logger.Debug("Attempting to get Instance Identity Signature") + instanceIdentitySignature, attemptErr = client.ec2metadata. + GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) + if attemptErr != nil { + logger.Debug("Unable to get instance identity signature, retrying", logger.Fields{ + field.Error: attemptErr, + }) + return apierrors.NewRetriableError(apierrors.NewRetriable(true), attemptErr) + } + logger.Debug("Successfully retrieved Instance Identity Signature") + return nil + }) if err != nil { logger.Error("Unable to get instance identity signature", logger.Fields{ field.Error: err, @@ -521,7 +535,7 @@ func (client *ecsClient) submitTaskStateChange(change ecs.TaskStateChange) error PullStartedAt: change.PullStartedAt, PullStoppedAt: change.PullStoppedAt, ExecutionStoppedAt: change.ExecutionStoppedAt, - ManagedAgents: change.ManagedAgents, + ManagedAgents: formatManagedAgents(change.ManagedAgents), Containers: formatContainers(change.Containers, client.shouldExcludeIPv6PortBinding, change.TaskARN), } @@ -752,18 +766,29 @@ func (client *ecsClient) UpdateContainerInstancesState(instanceARN string, statu return err } +func formatManagedAgents(managedAgents []*ecsmodel.ManagedAgentStateChange) []*ecsmodel.ManagedAgentStateChange { + var result []*ecsmodel.ManagedAgentStateChange + for _, m := range managedAgents { + if m.Reason != nil { + m.Reason = trimStringPtr(m.Reason, ecsMaxContainerReasonLength) + } + result = append(result, m) + } + return result +} + func formatContainers(containers []*ecsmodel.ContainerStateChange, shouldExcludeIPv6PortBinding bool, taskARN string) []*ecsmodel.ContainerStateChange { var result []*ecsmodel.ContainerStateChange for _, c := range containers { if c.RuntimeId != nil { - c.RuntimeId = aws.String(trimString(aws.StringValue(c.RuntimeId), ecsMaxRuntimeIDLength)) + c.RuntimeId = trimStringPtr(c.RuntimeId, ecsMaxRuntimeIDLength) } if c.Reason != nil { - c.Reason = aws.String(trimString(aws.StringValue(c.Reason), ecsMaxContainerReasonLength)) + c.Reason = trimStringPtr(c.Reason, ecsMaxContainerReasonLength) } if c.ImageDigest != nil { - c.ImageDigest = aws.String(trimString(aws.StringValue(c.ImageDigest), ecsMaxImageDigestLength)) + c.ImageDigest = trimStringPtr(c.ImageDigest, ecsMaxImageDigestLength) } if shouldExcludeIPv6PortBinding { c.NetworkBindings = excludeIPv6PortBindingFromNetworkBindings(c.NetworkBindings, @@ -791,6 +816,13 @@ func excludeIPv6PortBindingFromNetworkBindings(networkBindings []*ecsmodel.Netwo return result } +func trimStringPtr(inputStringPtr *string, maxLen int) *string { + if inputStringPtr == nil { + return nil + } + return aws.String(trimString(aws.StringValue(inputStringPtr), maxLen)) +} + func trimString(inputString string, maxLen int) string { if len(inputString) > maxLen { trimmed := inputString[0:maxLen] diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon/managed_daemon.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon/managed_daemon.go index dfa587b749d..3b615e40697 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon/managed_daemon.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon/managed_daemon.go @@ -62,6 +62,11 @@ type ManagedDaemon struct { linuxParameters *ecsacs.LinuxParameters privileged bool + + containerId string + + containerCGroup string + networkNameSpace string } // A valid managed daemon will require @@ -184,6 +189,11 @@ func (md *ManagedDaemon) GetCommand() []string { return md.command } +func (md *ManagedDaemon) SetCommand(command []string) { + md.command = make([]string, len(command)) + copy(md.command, command) +} + // returns list of mountpoints without the // agentCommunicationMount and applicationLogMount func (md *ManagedDaemon) GetFilteredMountPoints() []*MountPoint { @@ -211,6 +221,26 @@ func (md *ManagedDaemon) GetLoadedDaemonImageRef() string { return md.loadedDaemonImageRef } +func (md *ManagedDaemon) SetLoadedDaemonImageRef(loadedImageRef string) { + md.loadedDaemonImageRef = loadedImageRef +} + +func (md *ManagedDaemon) GetHealthCheckTest() []string { + return md.healthCheckTest +} + +func (md *ManagedDaemon) GetHealthCheckInterval() time.Duration { + return md.healthCheckInterval +} + +func (md *ManagedDaemon) GetHealthCheckTimeout() time.Duration { + return md.healthCheckTimeout +} + +func (md *ManagedDaemon) GetHealthCheckRetries() int { + return md.healthCheckRetries +} + func (md *ManagedDaemon) SetHealthCheck( healthCheckTest []string, healthCheckInterval time.Duration, @@ -277,14 +307,34 @@ func (md *ManagedDaemon) SetEnvironment(environment map[string]string) { } } -func (md *ManagedDaemon) SetLoadedDaemonImageRef(loadedImageRef string) { - md.loadedDaemonImageRef = loadedImageRef -} - func (md *ManagedDaemon) SetPrivileged(isPrivileged bool) { md.privileged = isPrivileged } +func (md *ManagedDaemon) GetContainerId() string { + return md.containerId +} + +func (md *ManagedDaemon) SetContainerId(containerId string) { + md.containerId = containerId +} + +func (md *ManagedDaemon) GetContainerCGroup() string { + return md.containerCGroup +} + +func (md *ManagedDaemon) SetContainerCGroup(containerCGroup string) { + md.containerCGroup = containerCGroup +} + +func (md *ManagedDaemon) GetNetworkNameSpace() string { + return md.networkNameSpace +} + +func (md *ManagedDaemon) SetNetworkNameSpace(networkNameSpace string) { + md.networkNameSpace = networkNameSpace +} + // AddMountPoint will add by MountPoint.SourceVolume // which is unique to the task and is a required field // and will throw an error if an existing diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface/networkinterface.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface/networkinterface.go index b3c333bbb4c..e6015018a8b 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface/networkinterface.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface/networkinterface.go @@ -55,13 +55,16 @@ type NetworkInterface struct { // InterfaceAssociationProtocol is the type of NetworkInterface, valid value: "default", "vlan" InterfaceAssociationProtocol string `json:",omitempty"` - Index int64 `json:"Index,omitempty"` - UserID uint32 `json:"UserID,omitempty"` - Name string `json:"Name,omitempty"` - DeviceName string `json:"DeviceName,omitempty"` - GuestNetNSName string `json:"GuestNetNSName,omitempty"` - KnownStatus status.NetworkStatus `json:"KnownStatus,omitempty"` - DesiredStatus status.NetworkStatus `json:"DesiredStatus,omitempty"` + Index int64 `json:"Index,omitempty"` + UserID uint32 `json:"UserID,omitempty"` + Name string `json:"Name,omitempty"` + DeviceName string `json:"DeviceName,omitempty"` + KnownStatus status.NetworkStatus `json:"KnownStatus,omitempty"` + DesiredStatus status.NetworkStatus `json:"DesiredStatus,omitempty"` + + // GuestNetNSName represents the interface's network namespace inside a guest OS if applicable. + // A sample use case is while running tasks inside Firecracker microVMs. + GuestNetNSName string `json:"GuestNetNSName,omitempty"` // InterfaceVlanProperties contains information for an interface // that is supposed to be used as a VLAN device @@ -413,6 +416,7 @@ func InterfaceFromACS(acsENI *ecsacs.ElasticNetworkInterface) (*NetworkInterface interfaceVlanProperties.TrunkInterfaceMacAddress = aws.StringValue(acsENI.InterfaceVlanProperties.TrunkInterfaceMacAddress) interfaceVlanProperties.VlanID = aws.StringValue(acsENI.InterfaceVlanProperties.VlanId) ni.InterfaceVlanProperties = &interfaceVlanProperties + ni.InterfaceAssociationProtocol = VLANInterfaceAssociationProtocol } for _, nameserverIP := range acsENI.DomainNameServers { @@ -474,7 +478,7 @@ func ValidateENI(acsENI *ecsacs.ElasticNetworkInterface) error { func New( acsENI *ecsacs.ElasticNetworkInterface, guestNetNSName string, - peerInterface *ecsacs.ElasticNetworkInterface, + ifaceList []*ecsacs.ElasticNetworkInterface, macToName map[string]string, ) (*NetworkInterface, error) { var err error @@ -494,7 +498,7 @@ func New( } case VETHInterfaceAssociationProtocol: - networkInterface, err = vethPairFromACS(acsENI, peerInterface) + networkInterface, err = vethPairFromACS(acsENI, ifaceList) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal interface veth properties") } @@ -520,7 +524,9 @@ func New( networkInterface.KnownStatus = status.NetworkNone networkInterface.DesiredStatus = status.NetworkReadyPull networkInterface.GuestNetNSName = guestNetNSName - networkInterface.setDeviceName(macToName) + if err = networkInterface.setDeviceName(macToName); err != nil { + return nil, err + } return networkInterface, nil } @@ -623,11 +629,20 @@ func v2nTunnelFromACS(acsENI *ecsacs.ElasticNetworkInterface) (*NetworkInterface // vethPairFromACS creates an NetworkInterface model with veth pair properties from the ACS NetworkInterface payload. func vethPairFromACS( acsENI *ecsacs.ElasticNetworkInterface, - peerInterface *ecsacs.ElasticNetworkInterface) (*NetworkInterface, error) { + ifaceList []*ecsacs.ElasticNetworkInterface) (*NetworkInterface, error) { if acsENI.InterfaceVethProperties == nil || acsENI.InterfaceVethProperties.PeerInterface == nil { return nil, errors.New("interface veth properties not found in payload") } + + peerName := aws.StringValue(acsENI.InterfaceVethProperties.PeerInterface) + var peerInterface *ecsacs.ElasticNetworkInterface + for _, iface := range ifaceList { + if aws.StringValue(iface.Name) == peerName { + peerInterface = iface + } + } + if aws.StringValue(peerInterface.InterfaceAssociationProtocol) == VETHInterfaceAssociationProtocol { return nil, errors.New("peer interface cannot be veth") } diff --git a/ecs-agent/api/ecs/client/ecs_client.go b/ecs-agent/api/ecs/client/ecs_client.go index 978304d57cf..8b5ed661e18 100644 --- a/ecs-agent/api/ecs/client/ecs_client.go +++ b/ecs-agent/api/ecs/client/ecs_client.go @@ -294,8 +294,22 @@ func (client *ecsClient) setInstanceIdentity( registerRequest.InstanceIdentityDocument = &instanceIdentityDoc if iidRetrieved { - instanceIdentitySignature, err = client.ec2metadata. - GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) + ctx, cancel = context.WithTimeout(context.Background(), setInstanceIdRetryTimeOut) + defer cancel() + err = retry.RetryWithBackoffCtx(ctx, backoff, func() error { + var attemptErr error + logger.Debug("Attempting to get Instance Identity Signature") + instanceIdentitySignature, attemptErr = client.ec2metadata. + GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) + if attemptErr != nil { + logger.Debug("Unable to get instance identity signature, retrying", logger.Fields{ + field.Error: attemptErr, + }) + return apierrors.NewRetriableError(apierrors.NewRetriable(true), attemptErr) + } + logger.Debug("Successfully retrieved Instance Identity Signature") + return nil + }) if err != nil { logger.Error("Unable to get instance identity signature", logger.Fields{ field.Error: err, @@ -521,7 +535,7 @@ func (client *ecsClient) submitTaskStateChange(change ecs.TaskStateChange) error PullStartedAt: change.PullStartedAt, PullStoppedAt: change.PullStoppedAt, ExecutionStoppedAt: change.ExecutionStoppedAt, - ManagedAgents: change.ManagedAgents, + ManagedAgents: formatManagedAgents(change.ManagedAgents), Containers: formatContainers(change.Containers, client.shouldExcludeIPv6PortBinding, change.TaskARN), } @@ -752,18 +766,29 @@ func (client *ecsClient) UpdateContainerInstancesState(instanceARN string, statu return err } +func formatManagedAgents(managedAgents []*ecsmodel.ManagedAgentStateChange) []*ecsmodel.ManagedAgentStateChange { + var result []*ecsmodel.ManagedAgentStateChange + for _, m := range managedAgents { + if m.Reason != nil { + m.Reason = trimStringPtr(m.Reason, ecsMaxContainerReasonLength) + } + result = append(result, m) + } + return result +} + func formatContainers(containers []*ecsmodel.ContainerStateChange, shouldExcludeIPv6PortBinding bool, taskARN string) []*ecsmodel.ContainerStateChange { var result []*ecsmodel.ContainerStateChange for _, c := range containers { if c.RuntimeId != nil { - c.RuntimeId = aws.String(trimString(aws.StringValue(c.RuntimeId), ecsMaxRuntimeIDLength)) + c.RuntimeId = trimStringPtr(c.RuntimeId, ecsMaxRuntimeIDLength) } if c.Reason != nil { - c.Reason = aws.String(trimString(aws.StringValue(c.Reason), ecsMaxContainerReasonLength)) + c.Reason = trimStringPtr(c.Reason, ecsMaxContainerReasonLength) } if c.ImageDigest != nil { - c.ImageDigest = aws.String(trimString(aws.StringValue(c.ImageDigest), ecsMaxImageDigestLength)) + c.ImageDigest = trimStringPtr(c.ImageDigest, ecsMaxImageDigestLength) } if shouldExcludeIPv6PortBinding { c.NetworkBindings = excludeIPv6PortBindingFromNetworkBindings(c.NetworkBindings, @@ -791,6 +816,13 @@ func excludeIPv6PortBindingFromNetworkBindings(networkBindings []*ecsmodel.Netwo return result } +func trimStringPtr(inputStringPtr *string, maxLen int) *string { + if inputStringPtr == nil { + return nil + } + return aws.String(trimString(aws.StringValue(inputStringPtr), maxLen)) +} + func trimString(inputString string, maxLen int) string { if len(inputString) > maxLen { trimmed := inputString[0:maxLen] diff --git a/ecs-agent/api/ecs/client/ecs_client_test.go b/ecs-agent/api/ecs/client/ecs_client_test.go index 095be1f9841..88601c3fb32 100644 --- a/ecs-agent/api/ecs/client/ecs_client_test.go +++ b/ecs-agent/api/ecs/client/ecs_client_test.go @@ -318,6 +318,10 @@ func TestRegisterContainerInstance(t *testing.T) { name: "basic case", mockCfgAccessorOverride: nil, }, + { + name: "retry GetDynamicData", + mockCfgAccessorOverride: nil, + }, { name: "no instance identity doc", mockCfgAccessorOverride: func(cfgAccessor *mock_config.MockAgentConfigAccessor) { @@ -386,6 +390,8 @@ func TestRegisterContainerInstance(t *testing.T) { Return("", errors.New("fake unit test error")), mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource). Return(expectedIID, nil), + mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource). + Return("", errors.New("fake unit test error")), mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource). Return(expectedIIDSig, nil), ) @@ -1361,6 +1367,25 @@ func TestWithIPv6PortBindingExcludedSetFalse(t *testing.T) { assert.NoError(t, err, "Unable to submit container state change") } +func TestTrimStringPtr(t *testing.T) { + const testMaxLen = 32 + testCases := []struct { + inputStringPtr *string + expectedOutput *string + name string + }{ + {nil, nil, "nil"}, + {aws.String("abc"), aws.String("abc"), "input does not exceed max length"}, + {aws.String("abcdefghijklmnopqrstuvwxyz1234567890"), + aws.String("abcdefghijklmnopqrstuvwxyz123456"), "input exceeds max length"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedOutput, trimStringPtr(tc.inputStringPtr, testMaxLen)) + }) + } +} + func extractTagsMapFromRegisterContainerInstanceInput(req *ecsmodel.RegisterContainerInstanceInput) map[string]string { tagsMap := make(map[string]string, len(req.Tags)) for i := range req.Tags { diff --git a/ecs-agent/manageddaemon/managed_daemon.go b/ecs-agent/manageddaemon/managed_daemon.go index dfa587b749d..3b615e40697 100644 --- a/ecs-agent/manageddaemon/managed_daemon.go +++ b/ecs-agent/manageddaemon/managed_daemon.go @@ -62,6 +62,11 @@ type ManagedDaemon struct { linuxParameters *ecsacs.LinuxParameters privileged bool + + containerId string + + containerCGroup string + networkNameSpace string } // A valid managed daemon will require @@ -184,6 +189,11 @@ func (md *ManagedDaemon) GetCommand() []string { return md.command } +func (md *ManagedDaemon) SetCommand(command []string) { + md.command = make([]string, len(command)) + copy(md.command, command) +} + // returns list of mountpoints without the // agentCommunicationMount and applicationLogMount func (md *ManagedDaemon) GetFilteredMountPoints() []*MountPoint { @@ -211,6 +221,26 @@ func (md *ManagedDaemon) GetLoadedDaemonImageRef() string { return md.loadedDaemonImageRef } +func (md *ManagedDaemon) SetLoadedDaemonImageRef(loadedImageRef string) { + md.loadedDaemonImageRef = loadedImageRef +} + +func (md *ManagedDaemon) GetHealthCheckTest() []string { + return md.healthCheckTest +} + +func (md *ManagedDaemon) GetHealthCheckInterval() time.Duration { + return md.healthCheckInterval +} + +func (md *ManagedDaemon) GetHealthCheckTimeout() time.Duration { + return md.healthCheckTimeout +} + +func (md *ManagedDaemon) GetHealthCheckRetries() int { + return md.healthCheckRetries +} + func (md *ManagedDaemon) SetHealthCheck( healthCheckTest []string, healthCheckInterval time.Duration, @@ -277,14 +307,34 @@ func (md *ManagedDaemon) SetEnvironment(environment map[string]string) { } } -func (md *ManagedDaemon) SetLoadedDaemonImageRef(loadedImageRef string) { - md.loadedDaemonImageRef = loadedImageRef -} - func (md *ManagedDaemon) SetPrivileged(isPrivileged bool) { md.privileged = isPrivileged } +func (md *ManagedDaemon) GetContainerId() string { + return md.containerId +} + +func (md *ManagedDaemon) SetContainerId(containerId string) { + md.containerId = containerId +} + +func (md *ManagedDaemon) GetContainerCGroup() string { + return md.containerCGroup +} + +func (md *ManagedDaemon) SetContainerCGroup(containerCGroup string) { + md.containerCGroup = containerCGroup +} + +func (md *ManagedDaemon) GetNetworkNameSpace() string { + return md.networkNameSpace +} + +func (md *ManagedDaemon) SetNetworkNameSpace(networkNameSpace string) { + md.networkNameSpace = networkNameSpace +} + // AddMountPoint will add by MountPoint.SourceVolume // which is unique to the task and is a required field // and will throw an error if an existing diff --git a/ecs-agent/manageddaemon/managed_daemon_test.go b/ecs-agent/manageddaemon/managed_daemon_test.go index 000e126edb8..60e000a02cc 100644 --- a/ecs-agent/manageddaemon/managed_daemon_test.go +++ b/ecs-agent/manageddaemon/managed_daemon_test.go @@ -333,3 +333,31 @@ func TestIsValidManagedDaemon(t *testing.T) { }) } } + +func TestSetContainerId(t *testing.T) { + testContainerId := "testContainerId" + tmd := NewManagedDaemon(TestImageName, TestImageTag) + tmd.SetContainerId(testContainerId) + assert.Equal(t, testContainerId, tmd.GetContainerId(), "Wrong value for set ContainerId") +} + +func TestSetContainerCGroup(t *testing.T) { + testContainerCGroup := "testContainerCGroup" + tmd := NewManagedDaemon(TestImageName, TestImageTag) + tmd.SetContainerCGroup(testContainerCGroup) + assert.Equal(t, testContainerCGroup, tmd.GetContainerCGroup(), "Wrong value for set ContainerCGroup") +} + +func TestSetNetworkNameSpace(t *testing.T) { + testNetworkNameSpace := "testNetworkNameSpace" + tmd := NewManagedDaemon(TestImageName, TestImageTag) + tmd.SetNetworkNameSpace(testNetworkNameSpace) + assert.Equal(t, testNetworkNameSpace, tmd.GetNetworkNameSpace(), "Wrong value for set NetworkNameSpace") +} + +func TestSetCommand(t *testing.T) { + testCommand := []string{"testCommand1", "testCommand2"} + tmd := NewManagedDaemon(TestImageName, TestImageTag) + tmd.SetCommand(testCommand) + assert.Equal(t, testCommand, tmd.GetCommand(), "Wrong value for set Command") +} diff --git a/ecs-agent/netlib/common_test.go b/ecs-agent/netlib/common_test.go index e1f60ebcc9b..328dd913cb9 100644 --- a/ecs-agent/netlib/common_test.go +++ b/ecs-agent/netlib/common_test.go @@ -31,6 +31,8 @@ const ( eniName = "f05c89a3ab01" eniMAC2 = "f0:5c:89:a3:ab:02" eniName2 = "f05c89a3ab02" + trunkMAC = "f0:5c:89:a3:ab:03" + vlanID = "133" eniID = "eni-abdf1234" eniID2 = "eni-abdf12342" dnsName = "amazon.com" @@ -45,13 +47,18 @@ const ( netNSNamePattern = "%s-%s" searchDomainName = "us-west-2.test.compute.internal" netNSPathDir = "/var/run/netns/" + tunnelID = "1a2b3c" + destinationIP = "10.176.1.19" + primaryIfaceName = "primary" + secondaryIfaceName = "secondary" + vethIfaceName = "veth" ) // getSingleNetNSAWSVPCTestData returns a task payload and a task network config // to be used the input and reference result for tests. The reference object will // has only one network namespace and network interface. func getSingleNetNSAWSVPCTestData(testTaskID string) (*ecsacs.Task, tasknetworkconfig.TaskNetworkConfig) { - enis, netIfs := getTestInterfacesData() + enis, netIfs := getTestInterfacesData_Containerd() taskPayload := &ecsacs.Task{ NetworkMode: aws.String(ecs.NetworkModeAwsvpc), ElasticNetworkInterfaces: []*ecsacs.ElasticNetworkInterface{enis[0]}, @@ -81,7 +88,7 @@ func getSingleNetNSAWSVPCTestData(testTaskID string) (*ecsacs.Task, tasknetworkc // getSingleNetNSMultiIfaceAWSVPCTestData returns test data for EKS like use cases. func getSingleNetNSMultiIfaceAWSVPCTestData(testTaskID string) (*ecsacs.Task, tasknetworkconfig.TaskNetworkConfig) { taskPayload, taskNetConfig := getSingleNetNSAWSVPCTestData(testTaskID) - enis, netIfs := getTestInterfacesData() + enis, netIfs := getTestInterfacesData_Containerd() secondIFPayload := enis[1] secondIF := &netIfs[1] taskPayload.ElasticNetworkInterfaces = append(taskPayload.ElasticNetworkInterfaces, secondIFPayload) @@ -95,7 +102,7 @@ func getSingleNetNSMultiIfaceAWSVPCTestData(testTaskID string) (*ecsacs.Task, ta func getMultiNetNSMultiIfaceAWSVPCTestData(testTaskID string) (*ecsacs.Task, tasknetworkconfig.TaskNetworkConfig) { ifName1 := "primary-eni" ifName2 := "secondary-eni" - enis, netIfs := getTestInterfacesData() + enis, netIfs := getTestInterfacesData_Containerd() enis[0].Name = aws.String(ifName1) enis[1].Name = aws.String(ifName2) @@ -156,7 +163,7 @@ func getMultiNetNSMultiIfaceAWSVPCTestData(testTaskID string) (*ecsacs.Task, tas return taskPayload, taskNetConfig } -func getTestInterfacesData() ([]*ecsacs.ElasticNetworkInterface, []networkinterface.NetworkInterface) { +func getTestInterfacesData_Containerd() ([]*ecsacs.ElasticNetworkInterface, []networkinterface.NetworkInterface) { // interfacePayloads have multiple interfaces as they are sent by ACS // that can be used as input data for tests. interfacePayloads := []*ecsacs.ElasticNetworkInterface{ @@ -260,3 +267,168 @@ func getTestInterfacesData() ([]*ecsacs.ElasticNetworkInterface, []networkinterf return interfacePayloads, networkInterfaces } + +// getV2NTestData returns a test task payload with a V2N interface to be used as test input and the +// task network config object as the expected output. +func getV2NTestData(testTaskID string) (*ecsacs.Task, tasknetworkconfig.TaskNetworkConfig) { + enis, netIfs := getTestInterfacesData_Firecracker() + taskPayload := &ecsacs.Task{ + NetworkMode: aws.String(ecs.NetworkModeAwsvpc), + ElasticNetworkInterfaces: enis, + Containers: []*ecsacs.Container{ + { + NetworkInterfaceNames: []*string{aws.String(primaryIfaceName)}, + }, + { + NetworkInterfaceNames: []*string{aws.String(primaryIfaceName)}, + }, + { + NetworkInterfaceNames: []*string{aws.String(secondaryIfaceName), aws.String(vethIfaceName)}, + }, + { + NetworkInterfaceNames: []*string{aws.String(secondaryIfaceName), aws.String(vethIfaceName)}, + }, + }, + } + + netNSName := fmt.Sprintf(netNSNamePattern, testTaskID, primaryIfaceName) + netNSPath := netNSPathDir + netNSName + + taskNetConfig := tasknetworkconfig.TaskNetworkConfig{ + NetworkMode: ecs.NetworkModeAwsvpc, + NetworkNamespaces: []*tasknetworkconfig.NetworkNamespace{ + { + Name: netNSName, + Path: netNSPath, + Index: 0, + NetworkInterfaces: netIfs, + KnownState: status.NetworkNone, + DesiredState: status.NetworkReadyPull, + }, + }, + } + + return taskPayload, taskNetConfig +} + +func getTestInterfacesData_Firecracker() ([]*ecsacs.ElasticNetworkInterface, []*networkinterface.NetworkInterface) { + // interfacePayloads have multiple interfaces as they are sent by ACS + // that can be used as input data for tests. + interfacePayloads := []*ecsacs.ElasticNetworkInterface{ + { + Name: aws.String(primaryIfaceName), + Ec2Id: aws.String(eniID), + MacAddress: aws.String(eniMAC), + PrivateDnsName: aws.String(dnsName), + DomainNameServers: []*string{aws.String(nameServer)}, + Index: aws.Int64(0), + Ipv4Addresses: []*ecsacs.IPv4AddressAssignment{ + { + Primary: aws.Bool(true), + PrivateAddress: aws.String(ipv4Addr), + }, + }, + Ipv6Addresses: []*ecsacs.IPv6AddressAssignment{ + { + Address: aws.String(ipv6Addr), + }, + }, + SubnetGatewayIpv4Address: aws.String(subnetGatewayCIDR), + InterfaceAssociationProtocol: aws.String(networkinterface.VLANInterfaceAssociationProtocol), + DomainName: []*string{aws.String(searchDomainName)}, + InterfaceVlanProperties: &ecsacs.NetworkInterfaceVlanProperties{ + TrunkInterfaceMacAddress: aws.String(trunkMAC), + VlanId: aws.String(vlanID), + }, + }, + { + Name: aws.String(secondaryIfaceName), + PrivateDnsName: aws.String(dnsName), + DomainNameServers: []*string{aws.String(nameServer2)}, + Index: aws.Int64(1), + SubnetGatewayIpv4Address: aws.String(subnetGatewayCIDR2), + InterfaceAssociationProtocol: aws.String(networkinterface.V2NInterfaceAssociationProtocol), + DomainName: []*string{aws.String(searchDomainName)}, + InterfaceTunnelProperties: &ecsacs.NetworkInterfaceTunnelProperties{ + TunnelId: aws.String(tunnelID), + InterfaceIpAddress: aws.String(destinationIP), + }, + }, + { + Name: aws.String(vethIfaceName), + InterfaceAssociationProtocol: aws.String(networkinterface.VETHInterfaceAssociationProtocol), + InterfaceVethProperties: &ecsacs.NetworkInterfaceVethProperties{ + PeerInterface: aws.String("primary"), + }, + }, + } + + // networkInterfaces represents the desired structure of the network interfaces list + // in the task network configuration object for the payload above. + networkInterfaces := []*networkinterface.NetworkInterface{ + { + ID: eniID, + MacAddress: eniMAC, + Name: primaryIfaceName, + IPV4Addresses: []*networkinterface.IPV4Address{ + { + Primary: true, + Address: ipv4Addr, + }, + }, + IPV6Addresses: []*networkinterface.IPV6Address{ + { + Address: ipv6Addr, + }, + }, + SubnetGatewayIPV4Address: subnetGatewayCIDR, + DomainNameServers: []string{nameServer}, + DomainNameSearchList: []string{searchDomainName}, + PrivateDNSName: dnsName, + InterfaceAssociationProtocol: networkinterface.VLANInterfaceAssociationProtocol, + Index: int64(0), + Default: true, + KnownStatus: status.NetworkNone, + DesiredStatus: status.NetworkReadyPull, + InterfaceVlanProperties: &networkinterface.InterfaceVlanProperties{ + TrunkInterfaceMacAddress: trunkMAC, + VlanID: vlanID, + }, + DeviceName: "eth1.133", + }, + { + Name: secondaryIfaceName, + IPV4Addresses: []*networkinterface.IPV4Address{ + { + Address: networkinterface.DefaultGeneveInterfaceIPAddress, + }, + }, + SubnetGatewayIPV4Address: networkinterface.DefaultGeneveInterfaceGateway, + DomainNameServers: []string{nameServer2}, + DomainNameSearchList: []string{searchDomainName}, + InterfaceAssociationProtocol: networkinterface.V2NInterfaceAssociationProtocol, + Index: int64(1), + KnownStatus: status.NetworkNone, + DesiredStatus: status.NetworkReadyPull, + TunnelProperties: &networkinterface.TunnelProperties{ + ID: tunnelID, + DestinationIPAddress: destinationIP, + }, + GuestNetNSName: secondaryIfaceName, + }, + { + Name: vethIfaceName, + InterfaceAssociationProtocol: networkinterface.VETHInterfaceAssociationProtocol, + DomainNameServers: []string{nameServer}, + DomainNameSearchList: []string{searchDomainName}, + VETHProperties: &networkinterface.VETHProperties{ + PeerInterfaceName: primaryIfaceName, + }, + GuestNetNSName: secondaryIfaceName, + KnownStatus: status.NetworkNone, + DesiredStatus: status.NetworkReadyPull, + }, + } + + return interfacePayloads, networkInterfaces +} diff --git a/ecs-agent/netlib/model/networkinterface/networkinterface.go b/ecs-agent/netlib/model/networkinterface/networkinterface.go index b3c333bbb4c..e6015018a8b 100644 --- a/ecs-agent/netlib/model/networkinterface/networkinterface.go +++ b/ecs-agent/netlib/model/networkinterface/networkinterface.go @@ -55,13 +55,16 @@ type NetworkInterface struct { // InterfaceAssociationProtocol is the type of NetworkInterface, valid value: "default", "vlan" InterfaceAssociationProtocol string `json:",omitempty"` - Index int64 `json:"Index,omitempty"` - UserID uint32 `json:"UserID,omitempty"` - Name string `json:"Name,omitempty"` - DeviceName string `json:"DeviceName,omitempty"` - GuestNetNSName string `json:"GuestNetNSName,omitempty"` - KnownStatus status.NetworkStatus `json:"KnownStatus,omitempty"` - DesiredStatus status.NetworkStatus `json:"DesiredStatus,omitempty"` + Index int64 `json:"Index,omitempty"` + UserID uint32 `json:"UserID,omitempty"` + Name string `json:"Name,omitempty"` + DeviceName string `json:"DeviceName,omitempty"` + KnownStatus status.NetworkStatus `json:"KnownStatus,omitempty"` + DesiredStatus status.NetworkStatus `json:"DesiredStatus,omitempty"` + + // GuestNetNSName represents the interface's network namespace inside a guest OS if applicable. + // A sample use case is while running tasks inside Firecracker microVMs. + GuestNetNSName string `json:"GuestNetNSName,omitempty"` // InterfaceVlanProperties contains information for an interface // that is supposed to be used as a VLAN device @@ -413,6 +416,7 @@ func InterfaceFromACS(acsENI *ecsacs.ElasticNetworkInterface) (*NetworkInterface interfaceVlanProperties.TrunkInterfaceMacAddress = aws.StringValue(acsENI.InterfaceVlanProperties.TrunkInterfaceMacAddress) interfaceVlanProperties.VlanID = aws.StringValue(acsENI.InterfaceVlanProperties.VlanId) ni.InterfaceVlanProperties = &interfaceVlanProperties + ni.InterfaceAssociationProtocol = VLANInterfaceAssociationProtocol } for _, nameserverIP := range acsENI.DomainNameServers { @@ -474,7 +478,7 @@ func ValidateENI(acsENI *ecsacs.ElasticNetworkInterface) error { func New( acsENI *ecsacs.ElasticNetworkInterface, guestNetNSName string, - peerInterface *ecsacs.ElasticNetworkInterface, + ifaceList []*ecsacs.ElasticNetworkInterface, macToName map[string]string, ) (*NetworkInterface, error) { var err error @@ -494,7 +498,7 @@ func New( } case VETHInterfaceAssociationProtocol: - networkInterface, err = vethPairFromACS(acsENI, peerInterface) + networkInterface, err = vethPairFromACS(acsENI, ifaceList) if err != nil { return nil, errors.Wrap(err, "failed to unmarshal interface veth properties") } @@ -520,7 +524,9 @@ func New( networkInterface.KnownStatus = status.NetworkNone networkInterface.DesiredStatus = status.NetworkReadyPull networkInterface.GuestNetNSName = guestNetNSName - networkInterface.setDeviceName(macToName) + if err = networkInterface.setDeviceName(macToName); err != nil { + return nil, err + } return networkInterface, nil } @@ -623,11 +629,20 @@ func v2nTunnelFromACS(acsENI *ecsacs.ElasticNetworkInterface) (*NetworkInterface // vethPairFromACS creates an NetworkInterface model with veth pair properties from the ACS NetworkInterface payload. func vethPairFromACS( acsENI *ecsacs.ElasticNetworkInterface, - peerInterface *ecsacs.ElasticNetworkInterface) (*NetworkInterface, error) { + ifaceList []*ecsacs.ElasticNetworkInterface) (*NetworkInterface, error) { if acsENI.InterfaceVethProperties == nil || acsENI.InterfaceVethProperties.PeerInterface == nil { return nil, errors.New("interface veth properties not found in payload") } + + peerName := aws.StringValue(acsENI.InterfaceVethProperties.PeerInterface) + var peerInterface *ecsacs.ElasticNetworkInterface + for _, iface := range ifaceList { + if aws.StringValue(iface.Name) == peerName { + peerInterface = iface + } + } + if aws.StringValue(peerInterface.InterfaceAssociationProtocol) == VETHInterfaceAssociationProtocol { return nil, errors.New("peer interface cannot be veth") } diff --git a/ecs-agent/netlib/network_builder_linux_test.go b/ecs-agent/netlib/network_builder_linux_test.go index e2885cc0b69..789f5995d05 100644 --- a/ecs-agent/netlib/network_builder_linux_test.go +++ b/ecs-agent/netlib/network_builder_linux_test.go @@ -56,9 +56,13 @@ func TestNewNetworkBuilder(t *testing.T) { // the network builder is able to translate the input task payload into the desired // network data models. func TestNetworkBuilder_BuildTaskNetworkConfiguration(t *testing.T) { - t.Run("containerd-default", getTestFunc(getSingleNetNSAWSVPCTestData)) - t.Run("containerd-multi-interface", getTestFunc(getSingleNetNSMultiIfaceAWSVPCTestData)) - t.Run("containerd-multi-netns", getTestFunc(getMultiNetNSMultiIfaceAWSVPCTestData)) + // Warmpool test cases. + t.Run("containerd-default", getTestFunc(getSingleNetNSAWSVPCTestData, platform.WarmpoolPlatform)) + t.Run("containerd-multi-interface", getTestFunc(getSingleNetNSMultiIfaceAWSVPCTestData, platform.WarmpoolPlatform)) + t.Run("containerd-multi-netns", getTestFunc(getMultiNetNSMultiIfaceAWSVPCTestData, platform.WarmpoolPlatform)) + + // Firecracker test cases. + t.Run("firecracker-v2n-veth", getTestFunc(getV2NTestData, platform.FirecrackerPlatform)) } func TestNetworkBuilder_Start(t *testing.T) { @@ -72,7 +76,10 @@ func TestNetworkBuilder_Stop(t *testing.T) { // getTestFunc returns a test function that verifies the capability of the networkBuilder // to translate a given input task payload into desired network data models. -func getTestFunc(dataGenF func(string) (input *ecsacs.Task, expected tasknetworkconfig.TaskNetworkConfig)) func(*testing.T) { +func getTestFunc( + dataGenF func(string) (input *ecsacs.Task, expected tasknetworkconfig.TaskNetworkConfig), + plt string, +) func(*testing.T) { return func(t *testing.T) { ctrl := gomock.NewController(t) @@ -80,7 +87,7 @@ func getTestFunc(dataGenF func(string) (input *ecsacs.Task, expected tasknetwork // Create a networkBuilder for the warmpool platform. mockNet := mock_netwrapper.NewMockNet(ctrl) - platformAPI, err := platform.NewPlatform(platform.WarmpoolPlatform, nil, "", mockNet) + platformAPI, err := platform.NewPlatform(plt, nil, "", mockNet) require.NoError(t, err) netBuilder := &networkBuilder{ platformAPI: platformAPI, @@ -89,10 +96,27 @@ func getTestFunc(dataGenF func(string) (input *ecsacs.Task, expected tasknetwork // Generate input task payload and a reference to verify the output with. taskPayload, expectedConfig := dataGenF(taskID) + // The agent expects the regular / trunk ENI to be present on the host. + // We should mock the net.Interfaces() method to return a list of interfaces + // on the host accordingly. var ifaces []net.Interface idx := 1 for _, eni := range taskPayload.ElasticNetworkInterfaces { - hw, err := net.ParseMAC(aws.StringValue(eni.MacAddress)) + var mac string + // In case of regular ENIs, the agent expects to find an interface with + // the ACS ENI's MAC address on the host. In case of branch ENIs (which + // use VLAN ID), the agent expects to find a trunk interface with the MAC + // address specified in the VLAN properties of the ACS ENI. + if eni.InterfaceVlanProperties == nil { + mac = aws.StringValue(eni.MacAddress) + } else { + mac = aws.StringValue(eni.InterfaceVlanProperties.TrunkInterfaceMacAddress) + } + // Veth and V2N interfaces will not have a MAC address associated with them. + if mac == "" { + continue + } + hw, err := net.ParseMAC(mac) require.NoError(t, err) ifaces = append(ifaces, net.Interface{ HardwareAddr: hw, diff --git a/ecs-agent/netlib/platform/common_linux.go b/ecs-agent/netlib/platform/common_linux.go index 05f010e4802..019c5c194b4 100644 --- a/ecs-agent/netlib/platform/common_linux.go +++ b/ecs-agent/netlib/platform/common_linux.go @@ -100,12 +100,16 @@ func NewPlatform( net: netWrapper, } - // TODO: implement remaining platforms - FoF, windows. + // TODO: implement remaining platforms - windows. switch platformString { case WarmpoolPlatform: return &containerd{ common: commonPlatform, }, nil + case FirecrackerPlatform: + return &firecraker{ + common: commonPlatform, + }, nil } return nil, errors.New("invalid platform: " + platformString) } @@ -114,13 +118,16 @@ func NewPlatform( // into the task network configuration data structure internal to the agent. func (c *common) buildTaskNetworkConfiguration( taskID string, - taskPayload *ecsacs.Task) (*tasknetworkconfig.TaskNetworkConfig, error) { + taskPayload *ecsacs.Task, + singleNetNS bool, + ifaceToGuestNetNS map[string]string, +) (*tasknetworkconfig.TaskNetworkConfig, error) { mode := aws.StringValue(taskPayload.NetworkMode) var netNSs []*tasknetworkconfig.NetworkNamespace var err error switch mode { case ecs.NetworkModeAwsvpc: - netNSs, err = c.buildAWSVPCNetworkNamespaces(taskID, taskPayload) + netNSs, err = c.buildAWSVPCNetworkNamespaces(taskID, taskPayload, singleNetNS, ifaceToGuestNetNS) if err != nil { return nil, errors.Wrap(err, "failed to translate network configuration") } @@ -145,16 +152,19 @@ func (c *common) GetNetNSPath(netNSName string) string { } // buildAWSVPCNetworkNamespaces returns list of NetworkNamespace which will be used to -// create the task's network configuration. All cases except those for FoF is covered by -// this method. FoF requires a separate specific implementation because the network setup -// is different due to the presence of the microVM. +// create the task's network configuration. // Use cases covered by this method are: // 1. Single interface, network namespace (the only externally available config). // 2. Single netns, multiple interfaces (For a non-managed multi-ENI experience. Eg EKS use case). // 3. Multiple netns, multiple interfaces (future use case for internal customer who need // a managed multi-ENI experience). -func (c *common) buildAWSVPCNetworkNamespaces(taskID string, - taskPayload *ecsacs.Task) ([]*tasknetworkconfig.NetworkNamespace, error) { +// 4. Single netns, multiple interfaces (for V2N tasks on FoF). +func (c *common) buildAWSVPCNetworkNamespaces( + taskID string, + taskPayload *ecsacs.Task, + singleNetNS bool, + ifaceToGuestNetNS map[string]string, +) ([]*tasknetworkconfig.NetworkNamespace, error) { if len(taskPayload.ElasticNetworkInterfaces) == 0 { return nil, errors.New("interfaces list cannot be empty") } @@ -163,19 +173,17 @@ func (c *common) buildAWSVPCNetworkNamespaces(taskID string, if err != nil { return nil, err } - // If task payload has only one interface, the network configuration is - // straight forward. It will have only one network namespace containing - // the corresponding network interface. - // Empty Name fields in network interface names indicate that all - // interfaces share the same network namespace. This use case is - // utilized by certain internal teams like EKS on Fargate. - if len(taskPayload.ElasticNetworkInterfaces) == 1 || + // If we require all interfaces to be in one single netns, the network configuration is straight forward. + // This case is identified if the singleNetNS flag is set, or if the ENIs have an empty 'Name' field, + // or if there is only on ENI in the payload. + if singleNetNS || len(taskPayload.ElasticNetworkInterfaces) == 1 || aws.StringValue(taskPayload.ElasticNetworkInterfaces[0].Name) == "" { primaryNetNS, err := c.buildNetNS(taskID, 0, taskPayload.ElasticNetworkInterfaces, taskPayload.ProxyConfiguration, - macToNames) + macToNames, + ifaceToGuestNetNS) if err != nil { return nil, err } @@ -226,7 +234,7 @@ func (c *common) buildAWSVPCNetworkNamespaces(taskID string, continue } - netNS, err := c.buildNetNS(taskID, nsIndex, ifaces, nil, macToNames) + netNS, err := c.buildNetNS(taskID, nsIndex, ifaces, nil, macToNames, nil) if err != nil { return nil, err } @@ -237,17 +245,21 @@ func (c *common) buildAWSVPCNetworkNamespaces(taskID string, return netNSs, nil } +// buildNetNS creates a single network namespace object using the input network config data. func (c *common) buildNetNS( taskID string, index int, networkInterfaces []*ecsacs.ElasticNetworkInterface, proxyConfig *ecsacs.ProxyConfiguration, - macToName map[string]string) (*tasknetworkconfig.NetworkNamespace, error) { + macToName map[string]string, + ifaceToGuestNetNS map[string]string, +) (*tasknetworkconfig.NetworkNamespace, error) { var primaryIF *networkinterface.NetworkInterface var ifaces []*networkinterface.NetworkInterface lowestIdx := int64(indexHighValue) for _, ni := range networkInterfaces { - iface, err := networkinterface.New(ni, "", nil, macToName) + guestNetNS := ifaceToGuestNetNS[aws.StringValue(ni.Name)] + iface, err := networkinterface.New(ni, guestNetNS, networkInterfaces, macToName) if err != nil { return nil, err } diff --git a/ecs-agent/netlib/platform/containerd_linux.go b/ecs-agent/netlib/platform/containerd_linux.go index 5867bff871a..789d0486ea3 100644 --- a/ecs-agent/netlib/platform/containerd_linux.go +++ b/ecs-agent/netlib/platform/containerd_linux.go @@ -32,7 +32,7 @@ func (c *containerd) BuildTaskNetworkConfiguration( taskID string, taskPayload *ecsacs.Task) (*tasknetworkconfig.TaskNetworkConfig, error) { - return c.common.buildTaskNetworkConfiguration(taskID, taskPayload) + return c.common.buildTaskNetworkConfiguration(taskID, taskPayload, false, nil) } func (c *containerd) CreateDNSConfig(taskID string, netNS *tasknetworkconfig.NetworkNamespace) error { diff --git a/ecs-agent/netlib/platform/firecracker_linux.go b/ecs-agent/netlib/platform/firecracker_linux.go new file mode 100644 index 00000000000..936c0083aee --- /dev/null +++ b/ecs-agent/netlib/platform/firecracker_linux.go @@ -0,0 +1,135 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package platform + +import ( + "context" + "errors" + "fmt" + + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/appmesh" + "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" + "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/serviceconnect" + "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/tasknetworkconfig" + + "github.com/aws/aws-sdk-go/aws" +) + +type firecraker struct { + common +} + +func (f *firecraker) BuildTaskNetworkConfiguration( + taskID string, + taskPayload *ecsacs.Task) (*tasknetworkconfig.TaskNetworkConfig, error) { + + // On Firecracker, there is always only one task network namespace on the bare metal host. + // Inside the microVM, a dedicated netns will also be created to separate primary interface + // and secondary interface(s) of the task. The following method invocation inspects the + // container-to-interface mapping to decide which interface resides in which namespace inside + // the microVM. + i2n, err := assignInterfacesToNamespaces(taskPayload) + if err != nil { + return nil, err + } + + return f.common.buildTaskNetworkConfiguration(taskID, taskPayload, true, i2n) +} + +func (f *firecraker) CreateDNSConfig(taskID string, netNS *tasknetworkconfig.NetworkNamespace) error { + return f.common.createDNSConfig(taskID, false, netNS) +} + +func (f *firecraker) ConfigureInterface( + ctx context.Context, netNSPath string, iface *networkinterface.NetworkInterface) error { + return f.common.configureInterface(ctx, netNSPath, iface) +} + +func (f *firecraker) ConfigureAppMesh(ctx context.Context, netNSPath string, cfg *appmesh.AppMesh) error { + return errors.New("not implemented") +} + +func (f *firecraker) ConfigureServiceConnect( + ctx context.Context, + netNSPath string, + primaryIf *networkinterface.NetworkInterface, + scConfig *serviceconnect.ServiceConnectConfig, +) error { + return errors.New("not implemented") +} + +// assignInterfacesToNamespaces computes how many network namespaces the task needs and assigns +// each network interface to a network namespace. +func assignInterfacesToNamespaces(taskPayload *ecsacs.Task) (map[string]string, error) { + // The task payload has a list of containers, a list of network interface names, and a list of + // which interface(s) each container should have access to. For this schema to work, the set of + // interface(s) used by one or more containers need to be grouped into network namespaces. Then + // the container runtime needs to be told to launch each container in its designated network + // namespace. This function computes how many network namespaces are needed, and then returns a + // map of network interface names to network namespace names. + i2n := make(map[string]string) + + // Optimization for the common case: If the task has a single interface, there is nothing to do. + if len(taskPayload.ElasticNetworkInterfaces) == 1 { + return i2n, nil + } + + for _, c := range taskPayload.Containers { + // containerNetNS keeps track of the netns assigned to this container. + containerNetNS := "" + + for _, i := range c.NetworkInterfaceNames { + ifName := aws.StringValue(i) + + netnsName, ok := i2n[ifName] + if !ok { + // This interface was not assigned to a netns yet. + // Create a new netns for this container if it doesn't have one. + if containerNetNS == "" { + // Use the container's first interface's name as the netns name. + // This naming isn't strictly necessary, just convenient when debugging. + containerNetNS = ifName + } + // Assign the interface to this container's netns. + i2n[ifName] = containerNetNS + } else { + // This interface was already assigned to a netns in a previous iteration. + // Assign the interface's netns to this container. + if containerNetNS == "" { + containerNetNS = netnsName + } + // All interfaces for a given container must be in the same netns. + if netnsName != containerNetNS { + return nil, fmt.Errorf("invalid task netns config") + } + } + } + } + + // The logic above names each netns after the first network interface placed in it. However the + // first (primary) netns should always be named "" so that it maps to the default netns. + for _, e := range taskPayload.ElasticNetworkInterfaces { + if *e.Index == int64(0) { + for ifName, netNSName := range i2n { + if netNSName == aws.StringValue(e.Name) { + i2n[ifName] = "" + } + } + break + } + } + + return i2n, nil +}