Skip to content

Commit

Permalink
[credential provider] Add a flag mirrorMapping
Browse files Browse the repository at this point in the history
This flag is to mirror registry A to B when fetching credential.

Signed-off-by: Zhecheng Li <zhechengli@microsoft.com>
  • Loading branch information
lzhecheng committed Aug 28, 2024
1 parent 4c72a56 commit 3d563f5
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 25 deletions.
8 changes: 7 additions & 1 deletion cmd/acr-credential-provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ import (

func main() {
rand.Seed(time.Now().UnixNano())

var mirrorMappingStr string

command := &cobra.Command{
Use: "acr-credential-provider configFile",
Short: "Acr credential provider for Kubelet",
Expand All @@ -44,7 +47,7 @@ func main() {
os.Exit(1)
}

acrProvider, err := credentialprovider.NewAcrProviderFromConfig(args[0])
acrProvider, err := credentialprovider.NewAcrProviderFromConfig(args[0], mirrorMappingStr)
if err != nil {
klog.Errorf("Failed to initialize ACR provider: %v", err)
os.Exit(1)
Expand All @@ -60,6 +63,9 @@ func main() {
logs.InitLogs()
defer logs.FlushLogs()

// Flags
command.Flags().StringVarP(&mirrorMappingStr, "mirror-mapping", "m", "", "mirror mapping to use")

if err := command.Execute(); err != nil {
os.Exit(1)
}
Expand Down
2 changes: 2 additions & 0 deletions examples/out-of-tree/credential-provider-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ providers:
- "*.azurecr.cn"
- "*.azurecr.de"
- "*.azurecr.us"
- "mcr.microsoft.com"
args:
- /etc/kubernetes/azure.json
- --mirror-mapping=mcr.microsoft.com:xxx.azurecr.io
80 changes: 60 additions & 20 deletions pkg/credentialprovider/azure_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ type CredentialProvider interface {

// acrProvider implements the credential provider interface for Azure Container Registry.
type acrProvider struct {
config *providerconfig.AzureAuthConfig
environment *azclient.Environment
credential azcore.TokenCredential
config *providerconfig.AzureAuthConfig
environment *azclient.Environment
credential azcore.TokenCredential
mirrorMapping map[string]string // Mirror mapping relation: source registry -> target registry
}

func NewAcrProvider(config *providerconfig.AzureAuthConfig, environment *azclient.Environment, credential azcore.TokenCredential) CredentialProvider {
Expand All @@ -74,7 +75,7 @@ func NewAcrProvider(config *providerconfig.AzureAuthConfig, environment *azclien
}

// NewAcrProvider creates a new instance of the ACR provider.
func NewAcrProviderFromConfig(configFile string) (CredentialProvider, error) {
func NewAcrProviderFromConfig(configFile string, mirrorMappingStr string) (CredentialProvider, error) {
if len(configFile) == 0 {
return nil, errors.New("no azure credential file is provided")
}
Expand Down Expand Up @@ -120,15 +121,16 @@ func NewAcrProviderFromConfig(configFile string) (CredentialProvider, error) {
}

return &acrProvider{
config: config,
credential: managedIdentityCredential,
environment: &envConfig,
config: config,
credential: managedIdentityCredential,
environment: &envConfig,
mirrorMapping: processMirrorMapping(mirrorMappingStr),
}, nil
}

func (a *acrProvider) GetCredentials(ctx context.Context, image string, _ []string) (*v1.CredentialProviderResponse, error) {
loginServer := a.parseACRLoginServerFromImage(image)
if loginServer == "" {
targetloginServer, sourceloginServer := a.parseACRLoginServerFromImage(image)
if targetloginServer == "" {
klog.V(2).Infof("image(%s) is not from ACR, return empty authentication", image)
return &v1.CredentialProviderResponse{
CacheKeyType: v1.RegistryPluginCacheKeyType,
Expand All @@ -150,16 +152,22 @@ func (a *acrProvider) GetCredentials(ctx context.Context, image string, _ []stri
}

if a.config.UseManagedIdentityExtension {
username, password, err := a.getFromACR(ctx, loginServer)
username, password, err := a.getFromACR(ctx, targetloginServer)
if err != nil {
klog.Errorf("error getting credentials from ACR for %s: %s", loginServer, err)
klog.Errorf("error getting credentials from ACR for %s: %s", targetloginServer, err)
return nil, err
}

response.Auth[loginServer] = v1.AuthConfig{
response.Auth[targetloginServer] = v1.AuthConfig{
Username: username,
Password: password,
}
if sourceloginServer != "" {
response.Auth[sourceloginServer] = v1.AuthConfig{
Username: username,
Password: password,
}
}
} else {
// Add our entry for each of the supported container registry URLs
for _, url := range containerRegistryUrls {
Expand Down Expand Up @@ -229,27 +237,59 @@ func (a *acrProvider) getFromACR(ctx context.Context, loginServer string) (strin
return dockerTokenLoginUsernameGUID, registryRefreshToken, nil
}

// parseACRLoginServerFromImage takes image as parameter and returns login server of it.
// Parameter `image` is expected in following format: foo.azurecr.io/bar/imageName:version
// parseACRLoginServerFromImage inputs an image URL and outputs login servers of target registry and source registry if --mirrorMapping is set.
// Input is expected in following format: foo.azurecr.io/bar/imageName:version
// If the provided image is not an acr image, this function will return an empty string.
func (a *acrProvider) parseACRLoginServerFromImage(image string) string {
match := acrRE.FindAllString(image, -1)
func (a *acrProvider) parseACRLoginServerFromImage(image string) (string, string) {
targetImage, sourceRegistry := a.processImageWithMirrorMapping(image)

match := acrRE.FindAllString(targetImage, -1)
if len(match) == 1 {
return match[0]
targetRegistry := match[0]
return targetRegistry, sourceRegistry
}

// handle the custom cloud case
if a != nil && a.environment != nil {
cloudAcrSuffix := a.environment.ContainerRegistryDNSSuffix
cloudAcrSuffixLength := len(cloudAcrSuffix)
if cloudAcrSuffixLength > 0 {
customAcrSuffixIndex := strings.Index(image, cloudAcrSuffix)
customAcrSuffixIndex := strings.Index(targetImage, cloudAcrSuffix)
if customAcrSuffixIndex != -1 {
endIndex := customAcrSuffixIndex + cloudAcrSuffixLength
return image[0:endIndex]
return targetImage[0:endIndex], sourceRegistry
}
}
}

return ""
return "", ""
}

// With acrProvider mirror mapping, e.g. {"mcr.microsoft.com": "abc.azurecr.io"}
// processImageWithMirrorMapping input format: "mcr.microsoft.com/bar/image:version"
// output format: "abc.azurecr.io/bar/image:version", "mcr.microsoft.com"
func (a *acrProvider) processImageWithMirrorMapping(image string) (string, string) {
for sourceRegistry, targetRegistry := range a.mirrorMapping {
if strings.HasPrefix(image, sourceRegistry) {
return strings.Replace(image, sourceRegistry, targetRegistry, 1), sourceRegistry
}
}
return image, ""
}

// processMirrorMapping input format: "--mirror-mapping=aaa:bbb,ccc:ddd"
// output format: map[string]string{"aaa": "bbb", "ccc": "ddd"}
func processMirrorMapping(mirrorMappingStr string) map[string]string {
mirrorMapping := map[string]string{}

mirrorMappingStr = strings.Trim(mirrorMappingStr, " ")
for _, mapping := range strings.Split(mirrorMappingStr, ",") {
parts := strings.Split(mapping, ":")
if len(parts) != 2 {
klog.Errorf("Invalid mirror mapping format: %s", mapping)
continue
}
mirrorMapping[parts[0]] = parts[1]
}
return mirrorMapping
}
88 changes: 84 additions & 4 deletions pkg/credentialprovider/azure_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -137,7 +138,7 @@ func TestGetCredentialsConfig(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error when closing temp file: %v", err)
}
provider, err := NewAcrProviderFromConfig(configFile.Name())
provider, err := NewAcrProviderFromConfig(configFile.Name(), "")
if err != nil && !test.expectError {
t.Fatalf("Unexpected error when creating new acr provider: %v", err)
}
Expand All @@ -163,6 +164,53 @@ func TestGetCredentialsConfig(t *testing.T) {
}
}

func TestProcessImageWithMirrorMapping(t *testing.T) {
configStr := `
{
"aadClientId": "foo",
"aadClientSecret": "bar"
}`

configFile, err := os.CreateTemp(".", "config.json")
assert.Nilf(t, err, "Unexpected error when creating temp file")
defer os.Remove(configFile.Name())
_, err = configFile.WriteString(configStr)
assert.Nilf(t, err, "Unexpected error when writing to temp file")
assert.Nilf(t, configFile.Close(), "Unexpected error when closing temp file")

provider, err := NewAcrProviderFromConfig(configFile.Name(), "mcr.microsoft.com:abc.azurecr.io")
assert.Nilf(t, err, "Unexpected error when creating new acr provider")
acrProvider := provider.(*acrProvider)

testcases := []struct {
description string
image string
expectedLoginServer string
expectedLoginServerMirror string
}{
{
description: "image in mirror mapping map",
image: "mcr.microsoft.com/bar/image:version",
expectedLoginServer: "abc.azurecr.io",
expectedLoginServerMirror: "mcr.microsoft.com",
},
{
description: "image not in mirror mapping map",
image: "foo.azurecr.io/bar/image:version",
expectedLoginServer: "foo.azurecr.io",
expectedLoginServerMirror: "",
},
}

for _, test := range testcases {
t.Run(test.description, func(t *testing.T) {
targetloginServer, sourceloginServer := acrProvider.parseACRLoginServerFromImage(test.image)
assert.Equal(t, targetloginServer, test.expectedLoginServer)
assert.Equal(t, sourceloginServer, test.expectedLoginServerMirror)
})
}
}

func TestParseACRLoginServerFromImage(t *testing.T) {

providerInterface := NewAcrProvider(&config.AzureAuthConfig{
Expand Down Expand Up @@ -215,8 +263,40 @@ func TestParseACRLoginServerFromImage(t *testing.T) {
},
}
for _, test := range tests {
if loginServer := provider.parseACRLoginServerFromImage(test.image); loginServer != test.expected {
t.Errorf("function parseACRLoginServerFromImage returns \"%s\" for image %s, expected \"%s\"", loginServer, test.image, test.expected)
}
t.Run(test.image, func(t *testing.T) {
targetloginServer, _ := provider.parseACRLoginServerFromImage(test.image)
assert.Equal(t, targetloginServer, test.expected)
})
}
}

func TestProcessMirrorMapping(t *testing.T) {
testcases := []struct {
description string
mirrorMappingStr string
expected map[string]string
}{
{
"multiple",
"aaa:bbb,ccc:ddd",
map[string]string{
"aaa": "bbb",
"ccc": "ddd",
},
},
{
"single",
"aaa:bbb",
map[string]string{
"aaa": "bbb",
},
},
}

for _, tc := range testcases {
t.Run(tc.description, func(t *testing.T) {
result := processMirrorMapping(tc.mirrorMappingStr)
assert.True(t, reflect.DeepEqual(result, tc.expected))
})
}
}

0 comments on commit 3d563f5

Please sign in to comment.