From b12bb67d8939c9cb90f1d844f0f31bdd0a8b7b23 Mon Sep 17 00:00:00 2001 From: Zhecheng Li Date: Thu, 8 Aug 2024 14:19:00 +0000 Subject: [PATCH] [credential provider] Add a flag --registry-mirror With registry mirror, e.g. --registry-mirror=mcr.microsoft.com:xxx.azurecr.io credential provider will return credentials of xxx.azurecr.io for mcr.microsoft.com. In this way, an image with URL prefix mcr.microsoft.com, but actually in xxx.azurecr.io, can be successfully pulled. Signed-off-by: Zhecheng Li --- cmd/acr-credential-provider/main.go | 9 +- .../credential-provider-config.yaml | 2 + pkg/credentialprovider/azure_credentials.go | 78 +++++++++++---- .../azure_credentials_test.go | 96 ++++++++++++++++++- 4 files changed, 160 insertions(+), 25 deletions(-) diff --git a/cmd/acr-credential-provider/main.go b/cmd/acr-credential-provider/main.go index 93a78925bc..9cb1840ec0 100644 --- a/cmd/acr-credential-provider/main.go +++ b/cmd/acr-credential-provider/main.go @@ -33,6 +33,9 @@ import ( func main() { rand.Seed(time.Now().UnixNano()) + + var RegistryMirrorStr string + command := &cobra.Command{ Use: "acr-credential-provider configFile", Short: "Acr credential provider for Kubelet", @@ -44,7 +47,7 @@ func main() { os.Exit(1) } - acrProvider, err := credentialprovider.NewAcrProviderFromConfig(args[0]) + acrProvider, err := credentialprovider.NewAcrProviderFromConfig(args[0], RegistryMirrorStr) if err != nil { klog.Errorf("Failed to initialize ACR provider: %v", err) os.Exit(1) @@ -60,6 +63,10 @@ func main() { logs.InitLogs() defer logs.FlushLogs() + // Flags + command.Flags().StringVarP(&RegistryMirrorStr, "registry-mirror", "r", "", + "Mirror a source registry host to a target registry host, and image pull credential will be requested to the target registry host when the image is from source registry host") + if err := command.Execute(); err != nil { os.Exit(1) } diff --git a/examples/out-of-tree/credential-provider-config.yaml b/examples/out-of-tree/credential-provider-config.yaml index f079b05392..f9f6e3b6b0 100644 --- a/examples/out-of-tree/credential-provider-config.yaml +++ b/examples/out-of-tree/credential-provider-config.yaml @@ -9,5 +9,7 @@ providers: - "*.azurecr.cn" - "*.azurecr.de" - "*.azurecr.us" + - "mcr.microsoft.com" args: - /etc/kubernetes/azure.json + - --registry-mirror=mcr.microsoft.com:xxx.azurecr.io diff --git a/pkg/credentialprovider/azure_credentials.go b/pkg/credentialprovider/azure_credentials.go index 2b929b9ab7..a9cf3d27a5 100644 --- a/pkg/credentialprovider/azure_credentials.go +++ b/pkg/credentialprovider/azure_credentials.go @@ -60,9 +60,10 @@ type CredentialProvider interface { // acrProvider implements the credential provider interface for Azure Container Registry. type acrProvider struct { - config *providerconfig.AzureAuthConfig - environment *azclient.Environment - credential azcore.TokenCredential + config *providerconfig.AzureAuthConfig + environment *azclient.Environment + credential azcore.TokenCredential + registryMirror map[string]string // Registry mirror relation: source registry -> target registry } func NewAcrProvider(config *providerconfig.AzureAuthConfig, environment *azclient.Environment, credential azcore.TokenCredential) CredentialProvider { @@ -74,7 +75,7 @@ func NewAcrProvider(config *providerconfig.AzureAuthConfig, environment *azclien } // NewAcrProvider creates a new instance of the ACR provider. -func NewAcrProviderFromConfig(configFile string) (CredentialProvider, error) { +func NewAcrProviderFromConfig(configFile string, registryMirrorStr string) (CredentialProvider, error) { if len(configFile) == 0 { return nil, errors.New("no azure credential file is provided") } @@ -120,15 +121,16 @@ func NewAcrProviderFromConfig(configFile string) (CredentialProvider, error) { } return &acrProvider{ - config: config, - credential: managedIdentityCredential, - environment: &envConfig, + config: config, + credential: managedIdentityCredential, + environment: &envConfig, + registryMirror: parseRegistryMirror(registryMirrorStr), }, nil } func (a *acrProvider) GetCredentials(ctx context.Context, image string, _ []string) (*v1.CredentialProviderResponse, error) { - loginServer := a.parseACRLoginServerFromImage(image) - if loginServer == "" { + targetloginServer, sourceloginServer := a.parseACRLoginServerFromImage(image) + if targetloginServer == "" { klog.V(2).Infof("image(%s) is not from ACR, return empty authentication", image) return &v1.CredentialProviderResponse{ CacheKeyType: v1.RegistryPluginCacheKeyType, @@ -150,16 +152,20 @@ func (a *acrProvider) GetCredentials(ctx context.Context, image string, _ []stri } if a.config.UseManagedIdentityExtension { - username, password, err := a.getFromACR(ctx, loginServer) + username, password, err := a.getFromACR(ctx, targetloginServer) if err != nil { - klog.Errorf("error getting credentials from ACR for %s: %s", loginServer, err) + klog.Errorf("error getting credentials from ACR for %s: %s", targetloginServer, err) return nil, err } - response.Auth[loginServer] = v1.AuthConfig{ + authConfig := v1.AuthConfig{ Username: username, Password: password, } + response.Auth[targetloginServer] = authConfig + if sourceloginServer != "" { + response.Auth[sourceloginServer] = authConfig + } } else { // Add our entry for each of the supported container registry URLs for _, url := range containerRegistryUrls { @@ -229,13 +235,16 @@ func (a *acrProvider) getFromACR(ctx context.Context, loginServer string) (strin return dockerTokenLoginUsernameGUID, registryRefreshToken, nil } -// parseACRLoginServerFromImage takes image as parameter and returns login server of it. -// Parameter `image` is expected in following format: foo.azurecr.io/bar/imageName:version +// parseACRLoginServerFromImage inputs an image URL and outputs login servers of target registry and source registry if --registry-mirror is set. +// Input is expected in following format: foo.azurecr.io/bar/imageName:version // If the provided image is not an acr image, this function will return an empty string. -func (a *acrProvider) parseACRLoginServerFromImage(image string) string { - match := acrRE.FindAllString(image, -1) +func (a *acrProvider) parseACRLoginServerFromImage(image string) (string, string) { + targetImage, sourceRegistry := a.processImageWithRegistryMirror(image) + + match := acrRE.FindAllString(targetImage, -1) if len(match) == 1 { - return match[0] + targetRegistry := match[0] + return targetRegistry, sourceRegistry } // handle the custom cloud case @@ -243,13 +252,42 @@ func (a *acrProvider) parseACRLoginServerFromImage(image string) string { cloudAcrSuffix := a.environment.ContainerRegistryDNSSuffix cloudAcrSuffixLength := len(cloudAcrSuffix) if cloudAcrSuffixLength > 0 { - customAcrSuffixIndex := strings.Index(image, cloudAcrSuffix) + customAcrSuffixIndex := strings.Index(targetImage, cloudAcrSuffix) if customAcrSuffixIndex != -1 { endIndex := customAcrSuffixIndex + cloudAcrSuffixLength - return image[0:endIndex] + return targetImage[0:endIndex], sourceRegistry } } } - return "" + return "", "" +} + +// With acrProvider registry mirror, e.g. {"mcr.microsoft.com": "abc.azurecr.io"} +// processImageWithRegistryMirror input format: "mcr.microsoft.com/bar/image:version" +// output format: "abc.azurecr.io/bar/image:version", "mcr.microsoft.com" +func (a *acrProvider) processImageWithRegistryMirror(image string) (string, string) { + for sourceRegistry, targetRegistry := range a.registryMirror { + if strings.HasPrefix(image, sourceRegistry) { + return strings.Replace(image, sourceRegistry, targetRegistry, 1), sourceRegistry + } + } + return image, "" +} + +// parseRegistryMirror input format: "--registry-mirror=aaa:bbb,ccc:ddd" +// output format: map[string]string{"aaa": "bbb", "ccc": "ddd"} +func parseRegistryMirror(registryMirrorStr string) map[string]string { + registryMirror := map[string]string{} + + registryMirrorStr = strings.ReplaceAll(registryMirrorStr, " ", "") + for _, mapping := range strings.Split(registryMirrorStr, ",") { + parts := strings.Split(mapping, ":") + if len(parts) != 2 { + klog.Errorf("Invalid registry mirror format: %s", mapping) + continue + } + registryMirror[parts[0]] = parts[1] + } + return registryMirror } diff --git a/pkg/credentialprovider/azure_credentials_test.go b/pkg/credentialprovider/azure_credentials_test.go index 5097e5ff39..2dd8d6918c 100644 --- a/pkg/credentialprovider/azure_credentials_test.go +++ b/pkg/credentialprovider/azure_credentials_test.go @@ -21,6 +21,7 @@ import ( "net/http" "net/http/httptest" "os" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -137,7 +138,7 @@ func TestGetCredentialsConfig(t *testing.T) { if err != nil { t.Fatalf("Unexpected error when closing temp file: %v", err) } - provider, err := NewAcrProviderFromConfig(configFile.Name()) + provider, err := NewAcrProviderFromConfig(configFile.Name(), "") if err != nil && !test.expectError { t.Fatalf("Unexpected error when creating new acr provider: %v", err) } @@ -163,6 +164,53 @@ func TestGetCredentialsConfig(t *testing.T) { } } +func TestProcessImageWithMirrorMapping(t *testing.T) { + configStr := ` + { + "aadClientId": "foo", + "aadClientSecret": "bar" + }` + + configFile, err := os.CreateTemp(".", "config.json") + assert.Nilf(t, err, "Unexpected error when creating temp file") + defer os.Remove(configFile.Name()) + _, err = configFile.WriteString(configStr) + assert.Nilf(t, err, "Unexpected error when writing to temp file") + assert.Nilf(t, configFile.Close(), "Unexpected error when closing temp file") + + provider, err := NewAcrProviderFromConfig(configFile.Name(), "mcr.microsoft.com:abc.azurecr.io") + assert.Nilf(t, err, "Unexpected error when creating new acr provider") + acrProvider := provider.(*acrProvider) + + testcases := []struct { + description string + image string + expectedLoginServer string + expectedLoginServerMirror string + }{ + { + description: "image in registry mirror map", + image: "mcr.microsoft.com/bar/image:version", + expectedLoginServer: "abc.azurecr.io", + expectedLoginServerMirror: "mcr.microsoft.com", + }, + { + description: "image not in registry mirror map", + image: "foo.azurecr.io/bar/image:version", + expectedLoginServer: "foo.azurecr.io", + expectedLoginServerMirror: "", + }, + } + + for _, test := range testcases { + t.Run(test.description, func(t *testing.T) { + targetloginServer, sourceloginServer := acrProvider.parseACRLoginServerFromImage(test.image) + assert.Equal(t, targetloginServer, test.expectedLoginServer) + assert.Equal(t, sourceloginServer, test.expectedLoginServerMirror) + }) + } +} + func TestParseACRLoginServerFromImage(t *testing.T) { providerInterface := NewAcrProvider(&config.AzureAuthConfig{ @@ -215,8 +263,48 @@ func TestParseACRLoginServerFromImage(t *testing.T) { }, } for _, test := range tests { - if loginServer := provider.parseACRLoginServerFromImage(test.image); loginServer != test.expected { - t.Errorf("function parseACRLoginServerFromImage returns \"%s\" for image %s, expected \"%s\"", loginServer, test.image, test.expected) - } + t.Run(test.image, func(t *testing.T) { + targetloginServer, _ := provider.parseACRLoginServerFromImage(test.image) + assert.Equal(t, targetloginServer, test.expected) + }) + } +} + +func TestProcessMirrorMapping(t *testing.T) { + testcases := []struct { + description string + mirrorMappingStr string + expected map[string]string + }{ + { + "multiple", + "aaa:bbb,ccc:ddd", + map[string]string{ + "aaa": "bbb", + "ccc": "ddd", + }, + }, + { + "multiple with some spaces", + "aaa: bbb, ccc:ddd", + map[string]string{ + "aaa": "bbb", + "ccc": "ddd", + }, + }, + { + "single", + "aaa:bbb", + map[string]string{ + "aaa": "bbb", + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.description, func(t *testing.T) { + result := parseRegistryMirror(tc.mirrorMappingStr) + assert.True(t, reflect.DeepEqual(result, tc.expected)) + }) } }