Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cvvz committed Dec 27, 2023
1 parent 44c4812 commit 8388a01
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pkg/azurefile/azurefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r

// if client id is specified, we only use service account token to get account key
if clientID != "" {
klog.Info("clientID is specified, use service account token to get account key")
klog.V(2).Infof("clientID(%s) is specified, use service account token to get account key", clientID)
accountKey, err := d.cloud.GetStorageAccesskeyFromServiceAccountToken(ctx, subsID, accountName, rgName, clientID, tenantID, serviceAccountToken)
return rgName, accountName, accountKey, fileShareName, diskName, subsID, err
}
Expand Down
14 changes: 7 additions & 7 deletions pkg/azurefile/nodeserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu
context := req.GetVolumeContext()
if context != nil {
// token request
if context[serviceAccountTokenField] != "" && hasClientID(context) {
klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with service account token, VolumeContext: %v", volumeID, target, context)
if context[serviceAccountTokenField] != "" && getClientID(context) != "" {
klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with service account token, clientID: %s", volumeID, target, getClientID(context))
_, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{
StagingTargetPath: target,
VolumeContext: context,
Expand Down Expand Up @@ -169,8 +169,8 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe
volumeID := req.GetVolumeId()
context := req.GetVolumeContext()

if hasClientID(context) && context[serviceAccountTokenField] == "" {
klog.V(2).Infof("Skip NodeStageVolume for volume(%s) since clientID is provided but service account token is empty", volumeID)
if getClientID(context) != "" && context[serviceAccountTokenField] == "" {
klog.V(2).Infof("Skip NodeStageVolume for volume(%s) since clientID %s is provided but service account token is empty", volumeID, getClientID(context))
return &csi.NodeStageVolumeResponse{}, nil
}

Expand Down Expand Up @@ -613,11 +613,11 @@ func checkGidPresentInMountFlags(mountFlags []string) bool {
return false
}

func hasClientID(context map[string]string) bool {
func getClientID(context map[string]string) string {
for k, v := range context {
if strings.EqualFold(k, clientIDField) && v != "" {
return true
return v
}
}
return false
return ""
}
18 changes: 9 additions & 9 deletions pkg/azurefile/nodeserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1068,23 +1068,23 @@ func makeFakeOutput(output string, err error) testingexec.FakeAction {
}
}

func Test_hasClientID(t *testing.T) {
func Test_getClientID(t *testing.T) {
type args struct {
context map[string]string
}
tests := []struct {
name string
args args
want bool
want string
}{
{
name: "has client id",
name: "get client id",
args: args{
context: map[string]string{
clientIDField: "test-client-id",
},
},
want: true,
want: "test-client-id",
},
{
name: "case not sensitive client id",
Expand All @@ -1093,14 +1093,14 @@ func Test_hasClientID(t *testing.T) {
"ClientId": "test-client-id",
},
},
want: true,
want: "test-client-id",
},
{
name: "no client id",
args: args{
context: map[string]string{},
},
want: false,
want: "",
},
{
name: "client id empty",
Expand All @@ -1109,13 +1109,13 @@ func Test_hasClientID(t *testing.T) {
clientIDField: "",
},
},
want: false,
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := hasClientID(tt.args.context); got != tt.want {
t.Errorf("hasClientID() = %v, want %v", got, tt.want)
if got := getClientID(tt.args.context); got != tt.want {
t.Errorf("getClientID() = %v, want %v", got, tt.want)
}
})
}
Expand Down

0 comments on commit 8388a01

Please sign in to comment.