diff --git a/pkg/azurefile/azure_common_darwin.go b/pkg/azurefile/azure_common_darwin.go index 127c06e734..e640e4dff4 100644 --- a/pkg/azurefile/azure_common_darwin.go +++ b/pkg/azurefile/azure_common_darwin.go @@ -31,6 +31,10 @@ func SMBMount(m *mount.SafeFormatAndMount, source, target, fsType string, option return nil } +func SMBUnmount(m *mount.SafeFormatAndMount, target string, _, _ bool) error { + return nil +} + func CleanupMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool) error { return nil } diff --git a/pkg/azurefile/azure_common_linux.go b/pkg/azurefile/azure_common_linux.go index 43be1a3999..b3e98cf4af 100644 --- a/pkg/azurefile/azure_common_linux.go +++ b/pkg/azurefile/azure_common_linux.go @@ -33,6 +33,10 @@ func SMBMount(m *mount.SafeFormatAndMount, source, target, fsType string, option return m.MountSensitive(source, target, fsType, options, sensitiveMountOptions) } +func SMBUnmount(m *mount.SafeFormatAndMount, target string, _, _ bool) error { + return mount.CleanupMountPoint(target, m.Interface, true /*extensiveMountPointCheck*/) +} + func CleanupMountPoint(m *mount.SafeFormatAndMount, target string, _ bool) error { return mount.CleanupMountPoint(target, m.Interface, true /*extensiveMountPointCheck*/) } diff --git a/pkg/azurefile/azure_common_windows.go b/pkg/azurefile/azure_common_windows.go index 9a90225b6f..236a811be6 100644 --- a/pkg/azurefile/azure_common_windows.go +++ b/pkg/azurefile/azure_common_windows.go @@ -39,6 +39,16 @@ func SMBMount(m *mount.SafeFormatAndMount, source, target, fsType string, mountO return fmt.Errorf("could not cast to csi proxy class") } +func SMBUnmount(m *mount.SafeFormatAndMount, target string, extensiveMountCheck, removeSMBMountOnWindows bool) error { + if proxy, ok := m.Interface.(mounter.CSIProxyMounter); ok { + if removeSMBMountOnWindows { + return proxy.Unmount(target) + } + return proxy.Rmdir(target) + } + return fmt.Errorf("could not cast to csi proxy class") +} + func CleanupMountPoint(m *mount.SafeFormatAndMount, target string, extensiveMountCheck bool) error { if proxy, ok := m.Interface.(mounter.CSIProxyMounter); ok { return proxy.Rmdir(target) diff --git a/pkg/azurefile/azurefile.go b/pkg/azurefile/azurefile.go index 57ed55c901..c672c1a80e 100644 --- a/pkg/azurefile/azurefile.go +++ b/pkg/azurefile/azurefile.go @@ -220,6 +220,7 @@ type Driver struct { kubeAPIQPS float64 kubeAPIBurst int enableWindowsHostProcess bool + removeSMBMountOnWindows bool appendClosetimeoOption bool appendNoShareSockOption bool appendNoResvPortOption bool @@ -287,6 +288,7 @@ func NewDriver(options *DriverOptions) *Driver { driver.kubeAPIQPS = options.KubeAPIQPS driver.kubeAPIBurst = options.KubeAPIBurst driver.enableWindowsHostProcess = options.EnableWindowsHostProcess + driver.removeSMBMountOnWindows = options.RemoveSMBMountOnWindows driver.appendClosetimeoOption = options.AppendClosetimeoOption driver.appendNoShareSockOption = options.AppendNoShareSockOption driver.appendNoResvPortOption = options.AppendNoResvPortOption diff --git a/pkg/azurefile/azurefile_options.go b/pkg/azurefile/azurefile_options.go index b9f06478e5..a0ea5808f5 100644 --- a/pkg/azurefile/azurefile_options.go +++ b/pkg/azurefile/azurefile_options.go @@ -37,6 +37,7 @@ type DriverOptions struct { KubeAPIQPS float64 KubeAPIBurst int EnableWindowsHostProcess bool + RemoveSMBMountOnWindows bool AppendClosetimeoOption bool AppendNoShareSockOption bool AppendNoResvPortOption bool @@ -72,6 +73,7 @@ func (o *DriverOptions) AddFlags() *flag.FlagSet { fs.Float64Var(&o.KubeAPIQPS, "kube-api-qps", 25.0, "QPS to use while communicating with the kubernetes apiserver.") fs.IntVar(&o.KubeAPIBurst, "kube-api-burst", 50, "Burst to use while communicating with the kubernetes apiserver.") fs.BoolVar(&o.EnableWindowsHostProcess, "enable-windows-host-process", false, "enable windows host process") + fs.BoolVar(&o.RemoveSMBMountOnWindows, "remove-smb-mount-on-windows", true, "remove smb global mapping on windows during unmount") fs.BoolVar(&o.AppendClosetimeoOption, "append-closetimeo-option", false, "Whether appending closetimeo=0 option to smb mount command") fs.BoolVar(&o.AppendNoShareSockOption, "append-nosharesock-option", true, "Whether appending nosharesock option to smb mount command") fs.BoolVar(&o.AppendNoResvPortOption, "append-noresvport-option", true, "Whether appending noresvport option to nfs mount command") diff --git a/pkg/azurefile/nodeserver.go b/pkg/azurefile/nodeserver.go index abbc05f57a..6699bee693 100644 --- a/pkg/azurefile/nodeserver.go +++ b/pkg/azurefile/nodeserver.go @@ -423,15 +423,17 @@ func (d *Driver) NodeUnstageVolume(_ context.Context, req *csi.NodeUnstageVolume mc.ObserveOperationWithResult(isOperationSucceeded, VolumeID, volumeID) }() - klog.V(2).Infof("NodeUnstageVolume: CleanupMountPoint volume %s on %s", volumeID, stagingTargetPath) - if err := CleanupMountPoint(d.mounter, stagingTargetPath, true /*extensiveMountPointCheck*/); err != nil { + klog.V(2).Infof("NodeUnstageVolume: unmount volume %s on %s", volumeID, stagingTargetPath) + if err := SMBUnmount(d.mounter, stagingTargetPath, true /*extensiveMountPointCheck*/, d.removeSMBMountOnWindows); err != nil { return nil, status.Errorf(codes.Internal, "failed to unmount staging target %s: %v", stagingTargetPath, err) } - targetPath := filepath.Join(filepath.Dir(stagingTargetPath), proxyMount) - klog.V(2).Infof("NodeUnstageVolume: CleanupMountPoint volume %s on %s", volumeID, targetPath) - if err := CleanupMountPoint(d.mounter, targetPath, false); err != nil { - return nil, status.Errorf(codes.Internal, "failed to unmount staging target %s: %v", targetPath, err) + if runtime.GOOS != "windows" { + targetPath := filepath.Join(filepath.Dir(stagingTargetPath), proxyMount) + klog.V(2).Infof("NodeUnstageVolume: CleanupMountPoint volume %s on %s", volumeID, targetPath) + if err := CleanupMountPoint(d.mounter, targetPath, false); err != nil { + return nil, status.Errorf(codes.Internal, "failed to unmount staging target %s: %v", targetPath, err) + } } klog.V(2).Infof("NodeUnstageVolume: unmount volume %s on %s successfully", volumeID, stagingTargetPath) diff --git a/pkg/mounter/safe_mounter_host_process_windows.go b/pkg/mounter/safe_mounter_host_process_windows.go index 77e51f8068..6030cb5b73 100644 --- a/pkg/mounter/safe_mounter_host_process_windows.go +++ b/pkg/mounter/safe_mounter_host_process_windows.go @@ -33,6 +33,8 @@ import ( "sigs.k8s.io/azurefile-csi-driver/pkg/os/smb" ) +var driverGlobalMountPath = "C:\\var\\lib\\kubelet\\plugins\\kubernetes.io\\csi\\file.csi.azure.com" + var _ CSIProxyMounter = &winMounter{} type winMounter struct{} @@ -116,11 +118,6 @@ func (mounter *winMounter) SMBMount(source, target, fsType string, mountOptions, return nil } -func (mounter *winMounter) SMBUnmount(target string) error { - klog.V(2).Infof("SMBUnmount: local path: %s", target) - return mounter.Rmdir(target) -} - // Mount just creates a soft link at target pointing to source. func (mounter *winMounter) Mount(source, target, fstype string, options []string) error { return filesystem.LinkPath(normalizeWindowsPath(source), normalizeWindowsPath(target)) @@ -133,7 +130,28 @@ func (mounter *winMounter) Rmdir(path string) error { // Unmount - Removes the directory - equivalent to unmount on Linux. func (mounter *winMounter) Unmount(target string) error { - klog.V(2).Infof("Unmount: %s", target) + target = normalizeWindowsPath(target) + remoteServer, err := smb.GetRemoteServerFromTarget(target) + if err == nil { + klog.V(2).Infof("remote server path: %s, local path: %s", remoteServer, target) + hasDupSMBMount, err := smb.CheckForDuplicateSMBMounts(driverGlobalMountPath, target, remoteServer) + if err == nil { + if !hasDupSMBMount { + remoteServer = strings.Replace(remoteServer, "UNC\\", "\\\\", 1) + if err := smb.RemoveSmbGlobalMapping(remoteServer); err != nil { + klog.Errorf("RemoveSmbGlobalMapping(%s) failed with %v", target, err) + } + } else { + klog.V(2).Infof("skip unmount as there are other SMB mounts on the same remote server %s", remoteServer) + } + } else { + klog.Errorf("CheckForDuplicateSMBMounts(%s, %s) failed with %v", target, remoteServer, err) + } + } else { + klog.Errorf("GetRemoteServerFromTarget(%s) failed with %v", target, err) + } + + klog.V(2).Infof("Unmount: remote path: %s local path: %s", remoteServer, target) return mounter.Rmdir(target) } diff --git a/pkg/mounter/safe_mounter_v1beta_windows.go b/pkg/mounter/safe_mounter_v1beta_windows.go index b202cc4d4a..409dcce0eb 100644 --- a/pkg/mounter/safe_mounter_v1beta_windows.go +++ b/pkg/mounter/safe_mounter_v1beta_windows.go @@ -91,13 +91,6 @@ func (mounter *csiProxyMounterV1Beta) SMBMount(source, target, fsType string, mo return nil } -func (mounter *csiProxyMounterV1Beta) SMBUnmount(target string) error { - klog.V(4).Infof("SMBUnmount: local path: %s", target) - // TODO: We need to remove the SMB mapping. The change to remove the - // directory brings the CSI code in parity with the in-tree. - return mounter.Rmdir(target) -} - // Mount just creates a soft link at target pointing to source. func (mounter *csiProxyMounterV1Beta) Mount(source string, target string, fstype string, options []string) error { klog.V(4).Infof("Mount: old name: %s. new name: %s", source, target) diff --git a/pkg/os/smb/smb.go b/pkg/os/smb/smb.go index be9809620b..30fb9cd479 100644 --- a/pkg/os/smb/smb.go +++ b/pkg/os/smb/smb.go @@ -18,6 +18,8 @@ package smb import ( "fmt" + "os" + "path/filepath" "strings" "k8s.io/klog/v2" @@ -62,3 +64,46 @@ func RemoveSmbGlobalMapping(remotePath string) error { } return nil } + +// GetRemoteServerFromTarget- gets the remote server path given a mount point, the function is recursive until it find the remote server or errors out +func GetRemoteServerFromTarget(mount string) (string, error) { + cmd := "(Get-Item -Path $Env:mount).Target" + out, err := util.RunPowershellCmd(cmd, fmt.Sprintf("mount=%s", mount)) + if err != nil || len(out) == 0 { + return "", fmt.Errorf("error getting volume from mount. cmd: %s, output: %s, error: %v", cmd, string(out), err) + } + return strings.TrimSpace(string(out)), nil +} + +// CheckForDuplicateSMBMounts checks if there is any other SMB mount exists on the same remote server +func CheckForDuplicateSMBMounts(dir, mount, remoteServer string) (bool, error) { + files, err := os.ReadDir(dir) + if err != nil { + return false, err + } + + for _, file := range files { + klog.V(6).Infof("checking file %s", file.Name()) + if file.IsDir() { + globalMountPath := filepath.Join(dir, file.Name(), "globalmount") + if strings.EqualFold(filepath.Clean(globalMountPath), filepath.Clean(mount)) { + klog.V(2).Infof("skip current mount path %s", mount) + } else { + fileInfo, err := os.Lstat(globalMountPath) + // check if the file is a symlink, if yes, check if it is pointing to the same remote server + if err == nil && fileInfo.Mode()&os.ModeSymlink != 0 { + remoteServerPath, err := GetRemoteServerFromTarget(globalMountPath) + klog.V(2).Infof("checking remote server path %s on local path %s", remoteServerPath, globalMountPath) + if err == nil { + if remoteServerPath == remoteServer { + return true, nil + } + } else { + klog.Errorf("GetRemoteServerFromTarget(%s) failed with %v", globalMountPath, err) + } + } + } + } + } + return false, err +} diff --git a/pkg/os/smb/smb_test.go b/pkg/os/smb/smb_test.go new file mode 100644 index 0000000000..98dc67406c --- /dev/null +++ b/pkg/os/smb/smb_test.go @@ -0,0 +1,59 @@ +//go:build windows +// +build windows + +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package smb + +import ( + "fmt" + "testing" +) + +func TestCheckForDuplicateSMBMounts(t *testing.T) { + tests := []struct { + name string + dir string + mount string + remoteServer string + expectedResult bool + expectedError error + }{ + { + name: "directory does not exist", + dir: "non-existing-mount", + expectedResult: false, + expectedError: fmt.Errorf("open non-existing-mount: The system cannot find the file specified."), + }, + } + + for _, test := range tests { + result, err := CheckForDuplicateSMBMounts(test.dir, test.mount, test.remoteServer) + if result != test.expectedResult { + t.Errorf("Expected %v, got %v", test.expectedResult, result) + } + if err == nil && test.expectedError != nil { + t.Errorf("Expected error %v, got nil", test.expectedError) + } + if err != nil && test.expectedError == nil { + t.Errorf("Expected nil, got %v", err) + } + if err != nil && test.expectedError != nil && err.Error() != test.expectedError.Error() { + t.Errorf("Expected error %v, got %v", test.expectedError, err) + } + } +}