diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index edc6230..ba73e12 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -159,10 +159,17 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish return nil, status.Error(codes.InvalidArgument, "request missing required target path") } - if err := unmount(req.TargetPath); err != nil { - return nil, status.Errorf(codes.Internal, "unable to unmount %q: %v", req.TargetPath, err) + // Check if target is a valid mount and issue unmount request + if ok, err := isMountPoint(req.TargetPath); err != nil { + return nil, status.Errorf(codes.Internal, "unable to verify mount point %q: %v", req.TargetPath, err) + } else if ok { + if err := unmount(req.TargetPath); err != nil { + return nil, status.Errorf(codes.Internal, "unable to unmount %q: %v", req.TargetPath, err) + } } - if err := os.Remove(req.TargetPath); err != nil { + + // Check and remove the mount path if present, report an error otherwise + if err := os.Remove(req.TargetPath); err != nil && !errors.Is(err, os.ErrNotExist) { return nil, status.Errorf(codes.Internal, "unable to remove target path %q: %v", req.TargetPath, err) } diff --git a/pkg/driver/driver_test.go b/pkg/driver/driver_test.go index 90e1196..5251d31 100644 --- a/pkg/driver/driver_test.go +++ b/pkg/driver/driver_test.go @@ -24,7 +24,13 @@ import ( ) const ( - testNodeID = "nodeID" + testNodeID = "nodeID" + unmountFailureTest = "unmount failure" + isMountFailureTest = "isMount failure" +) + +var ( + testDescription string ) func init() { @@ -34,6 +40,17 @@ func init() { unmount = func(dst string) error { return os.Remove(metaPath(dst)) } + isMountPoint = func(dst string) (bool, error) { + if testDescription == unmountFailureTest { + return true, nil + } + + if testDescription == isMountFailureTest { + return false, fmt.Errorf("mock invalid mount point") + } + + return true, nil + } } func TestNew(t *testing.T) { @@ -323,7 +340,12 @@ func TestNodeUnpublishVolume(t *testing.T) { expectMsgPrefix: "request missing required target path", }, { - desc: "unmount failure", + desc: isMountFailureTest, + expectCode: codes.Internal, + expectMsgPrefix: "unable to verify mount point", + }, + { + desc: unmountFailureTest, mungeTargetPath: func(t *testing.T, targetPath string) { // Removing the meta file to simulate that it wasn't mounted require.NoError(t, os.Remove(metaPath(targetPath))) @@ -365,6 +387,7 @@ func TestNodeUnpublishVolume(t *testing.T) { if tt.mutateReq != nil { tt.mutateReq(req) } + registerTestDescription(tt.desc) dumpIt(t, "BEFORE", targetPathBase) resp, err := client.NodeUnpublishVolume(context.Background(), req) dumpIt(t, "AFTER", targetPathBase) @@ -379,6 +402,10 @@ func TestNodeUnpublishVolume(t *testing.T) { } } +func registerTestDescription(desc string) { + testDescription = desc +} + func requireGRPCStatusPrefix(tb testing.TB, err error, code codes.Code, msgPrefix string, msgAndArgs ...interface{}) { st := status.Convert(err) if code != st.Code() || !strings.HasPrefix(st.Message(), msgPrefix) {