diff --git a/nomad/host_volume_endpoint.go b/nomad/host_volume_endpoint.go
index fcdd8c887b8..193d141ab59 100644
--- a/nomad/host_volume_endpoint.go
+++ b/nomad/host_volume_endpoint.go
@@ -4,11 +4,13 @@
 package nomad
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"net/http"
 	"regexp"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/armon/go-metrics"
@@ -28,6 +30,9 @@ type HostVolume struct {
 	srv    *Server
 	ctx    *RPCContext
 	logger hclog.Logger
+
+	// volOps is used to serialize operations per volume ID
+	volOps sync.Map
 }
 
 func NewHostVolumeEndpoint(srv *Server, ctx *RPCContext) *HostVolume {
@@ -263,25 +268,31 @@ func (v *HostVolume) Create(args *structs.HostVolumeCreateRequest, reply *struct
 		return err
 	}
 
-	// Attempt to create the volume on the client.
-	//
-	// NOTE: creating the volume on the client via the plugin can't be made
-	// atomic with the registration, and creating the volume provides values we
-	// want to write on the Volume in raft anyways.
-	err = v.createVolume(vol)
-	if err != nil {
-		return err
-	}
+	// serialize client RPC and raft write per volume ID
+	index, err := v.serializeCall(vol.ID, "create", func() (uint64, error) {
+		// Attempt to create the volume on the client.
+		//
+		// NOTE: creating the volume on the client via the plugin can't be made
+		// atomic with the registration, and creating the volume provides values
+		// we want to write on the Volume in raft anyways.
+		if err = v.createVolume(vol); err != nil {
+			return 0, err
+		}
 
-	// Write a newly created or modified volume to raft. We create a new request
-	// here because we've likely mutated the volume.
-	_, index, err := v.srv.raftApply(structs.HostVolumeRegisterRequestType,
-		&structs.HostVolumeRegisterRequest{
-			Volume:       vol,
-			WriteRequest: args.WriteRequest,
-		})
+		// Write a newly created or modified volume to raft. We create a new
+		// request here because we've likely mutated the volume.
+		_, idx, err := v.srv.raftApply(structs.HostVolumeRegisterRequestType,
+			&structs.HostVolumeRegisterRequest{
+				Volume:       vol,
+				WriteRequest: args.WriteRequest,
+			})
+		if err != nil {
+			v.logger.Error("raft apply failed", "error", err, "method", "register")
+			return 0, err
+		}
+		return idx, nil
+	})
 	if err != nil {
-		v.logger.Error("raft apply failed", "error", err, "method", "register")
 		return err
 	}
 
@@ -356,24 +367,30 @@ func (v *HostVolume) Register(args *structs.HostVolumeRegisterRequest, reply *st
 		return err
 	}
 
-	// Attempt to register the volume on the client.
-	//
-	// NOTE: registering the volume on the client via the plugin can't be made
-	// atomic with the registration.
-	err = v.registerVolume(vol)
-	if err != nil {
-		return err
-	}
+	// serialize client RPC and raft write per volume ID
+	index, err := v.serializeCall(vol.ID, "register", func() (uint64, error) {
+		// Attempt to register the volume on the client.
+		//
+		// NOTE: registering the volume on the client via the plugin can't be made
+		// atomic with the registration.
+		if err = v.registerVolume(vol); err != nil {
+			return 0, err
+		}
 
-	// Write a newly created or modified volume to raft. We create a new request
-	// here because we've likely mutated the volume.
-	_, index, err := v.srv.raftApply(structs.HostVolumeRegisterRequestType,
-		&structs.HostVolumeRegisterRequest{
-			Volume:       vol,
-			WriteRequest: args.WriteRequest,
-		})
+		// Write a newly created or modified volume to raft. We create a new
+		// request here because we've likely mutated the volume.
+		_, idx, err := v.srv.raftApply(structs.HostVolumeRegisterRequestType,
+			&structs.HostVolumeRegisterRequest{
+				Volume:       vol,
+				WriteRequest: args.WriteRequest,
+			})
+		if err != nil {
+			v.logger.Error("raft apply failed", "error", err, "method", "register")
+			return 0, err
+		}
+		return idx, nil
+	})
 	if err != nil {
-		v.logger.Error("raft apply failed", "error", err, "method", "register")
 		return err
 	}
 
@@ -608,8 +625,6 @@ func (v *HostVolume) Delete(args *structs.HostVolumeDeleteRequest, reply *struct
 		return fmt.Errorf("missing volume ID to delete")
 	}
-	var index uint64
-
 	snap, err := v.srv.State().Snapshot()
 	if err != nil {
 		return err
 	}
@@ -631,14 +646,19 @@ func (v *HostVolume) Delete(args *structs.HostVolumeDeleteRequest, reply *struct
 		return fmt.Errorf("volume %s in use by allocations: %v", id, allocIDs)
 	}
 
-	err = v.deleteVolume(vol)
-	if err != nil {
-		return err
-	}
-
-	_, index, err = v.srv.raftApply(structs.HostVolumeDeleteRequestType, args)
+	// serialize client RPC and raft write per volume ID
+	index, err := v.serializeCall(vol.ID, "delete", func() (uint64, error) {
+		if err := v.deleteVolume(vol); err != nil {
+			return 0, err
+		}
+		_, idx, err := v.srv.raftApply(structs.HostVolumeDeleteRequestType, args)
+		if err != nil {
+			v.logger.Error("raft apply failed", "error", err, "method", "delete")
+			return 0, err
+		}
+		return idx, nil
+	})
 	if err != nil {
-		v.logger.Error("raft apply failed", "error", err, "method", "delete")
 		return err
 	}
 
@@ -665,3 +685,46 @@ func (v *HostVolume) deleteVolume(vol *structs.HostVolume) error {
 
 	return nil
 }
+
+// serializeCall serializes fn() per volume, so DHV plugins can assume that
+// Nomad will not run concurrent operations for the same volume, and for us
+// to avoid interleaving client RPCs with raft writes.
+// Concurrent calls should all run eventually (or timeout, or server shutdown),
+// but there is no guarantee that they will run in the order received.
+// The passed fn is expected to return a raft index and error.
+func (v *HostVolume) serializeCall(volumeID, op string, fn func() (uint64, error)) (uint64, error) {
+	timeout := 2 * time.Minute // 2x the client RPC timeout
+	for {
+		ctx, done := context.WithTimeout(v.srv.shutdownCtx, timeout)
+
+		loaded, occupied := v.volOps.LoadOrStore(volumeID, ctx)
+
+		if !occupied {
+			v.logger.Trace("HostVolume RPC running", "operation", op)
+			// run the fn!
+			index, err := fn()
+
+			// done() must come after Delete, so that other unblocked requests
+			// will Store a fresh context when they continue.
+			v.volOps.Delete(volumeID)
+			done()
+
+			return index, err
+		}
+
+		// another one is running; wait for it to finish.
+		v.logger.Trace("HostVolume RPC waiting", "operation", op)
+
+		// cancel the tentative context; we'll use the one we pulled from
+		// volOps (set by another RPC call) instead.
+		done()
+
+		otherCtx := loaded.(context.Context)
+		select {
+		case <-otherCtx.Done():
+			continue
+		case <-v.srv.shutdownCh:
+			return 0, structs.ErrNoLeader
+		}
+	}
+}
diff --git a/nomad/host_volume_endpoint_test.go b/nomad/host_volume_endpoint_test.go
index ae5bf08c924..7fbdd457e9b 100644
--- a/nomad/host_volume_endpoint_test.go
+++ b/nomad/host_volume_endpoint_test.go
@@ -11,6 +11,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/hashicorp/go-multierror"
+	"github.com/hashicorp/go-set/v3"
 	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
 	"github.com/hashicorp/nomad/ci"
 	"github.com/hashicorp/nomad/client"
@@ -23,6 +25,7 @@ import (
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/testutil"
 	"github.com/hashicorp/nomad/version"
+	"github.com/shoenig/test"
 	"github.com/shoenig/test/must"
 	"github.com/shoenig/test/wait"
 )
@@ -763,6 +766,193 @@ func TestHostVolumeEndpoint_placeVolume(t *testing.T) {
 	}
 }
 
+// TestHostVolumeEndpoint_concurrency checks that create/register/delete RPC
+// calls cannot run concurrently for a single volume.
+func TestHostVolumeEndpoint_concurrency(t *testing.T) {
+	ci.Parallel(t)
+
+	srv, cleanup := TestServer(t, func(c *Config) { c.NumSchedulers = 0 })
+	t.Cleanup(cleanup)
+	testutil.WaitForLeader(t, srv.RPC)
+
+	c, node := newMockHostVolumeClient(t, srv, "default")
+
+	vol := &structs.HostVolume{
+		Name:      "test-vol",
+		Namespace: "default",
+		NodeID:    node.ID,
+		PluginID:  "mkdir",
+		HostPath:  "/pretend/path",
+		RequestedCapabilities: []*structs.HostVolumeCapability{
+			{
+				AttachmentMode: structs.HostVolumeAttachmentModeFilesystem,
+				AccessMode:     structs.HostVolumeAccessModeSingleNodeWriter,
+			},
+		},
+	}
+	wr := structs.WriteRequest{Region: srv.Region()}
+
+	// tell the mock client how it should respond to create calls
+	c.setCreate(&cstructs.ClientHostVolumeCreateResponse{
+		VolumeName: "test-vol",
+		HostPath:   "/pretend/path",
+	}, nil)
+
+	// create the volume for us to attempt concurrent operations on
+	cVol := vol.Copy() // copy because HostPath gets mutated
+	cVol.Parameters = map[string]string{"created": "initial"}
+	createReq := &structs.HostVolumeCreateRequest{
+		Volume:       cVol,
+		WriteRequest: wr,
+	}
+	var createResp structs.HostVolumeCreateResponse
+	must.NoError(t, srv.RPC("HostVolume.Create", createReq, &createResp))
+	got, err := srv.State().HostVolumeByID(nil, vol.Namespace, createResp.Volume.ID, false)
+	must.NoError(t, err)
+	must.Eq(t, map[string]string{"created": "initial"}, got.Parameters)
+
+	// warning: below here be (concurrency) dragons. if this test fails,
+	// it is rather difficult to troubleshoot. sorry!
+
+	// this is critical -- everything needs to use the same volume ID,
+	// because that's what the serialization is based on.
+	vol.ID = createResp.Volume.ID
+
+	// "create" volume #2 (same vol except for parameters)
+	cVol2 := vol.Copy()
+	cVol2.Parameters = map[string]string{"created": "again"}
+	// "register" volume
+	rVol := vol.Copy()
+	rVol.Parameters = map[string]string{"registered": "yup"}
+
+	// prepare the mock client to block its calls, and get a CancelFunc
+	// to make sure we don't get any deadlocked client RPCs.
+	cancelClientRPCBlocks, err := c.setBlockChan()
+	must.NoError(t, err)
+
+	// each operation goroutine will put its name in here when it completes,
+	// so we can wait until the whole RPC completes before checking state.
+	rpcDoneCh := make(chan string)
+	rpcDone := func(op string) {
+		t.Helper()
+		select {
+		case rpcDoneCh <- op:
+		case <-time.After(time.Second):
+			t.Errorf("timed out writing %q to rpcDoneCh", op)
+		}
+	}
+
+	// start all the RPCs concurrently
+	var funcs multierror.Group
+	// create
+	funcs.Go(func() error {
+		createReq = &structs.HostVolumeCreateRequest{
+			Volume:       cVol2,
+			WriteRequest: wr,
+		}
+		createResp = structs.HostVolumeCreateResponse{}
+		err := srv.RPC("HostVolume.Create", createReq, &createResp)
+		rpcDone("create")
+		return err
+	})
+	// register
+	funcs.Go(func() error {
+		registerReq := &structs.HostVolumeRegisterRequest{
+			Volume:       rVol,
+			WriteRequest: wr,
+		}
+		var registerResp structs.HostVolumeRegisterResponse
+		err := srv.RPC("HostVolume.Register", registerReq, &registerResp)
+		rpcDone("register")
+		return err
+	})
+	// delete
+	funcs.Go(func() error {
+		deleteReq := &structs.HostVolumeDeleteRequest{
+			VolumeID:     vol.ID,
+			WriteRequest: wr,
+		}
+		var deleteResp structs.HostVolumeDeleteResponse
+		err := srv.RPC("HostVolume.Delete", deleteReq, &deleteResp)
+		rpcDone("delete")
+		return err
+	})
+
+	// NOTE: below here, we avoid `must` methods, because a t.Fatal causes all
+	// the above goroutines to halt with confusing errors.
+
+	// keep track of which operations have completed
+	opSet := set.From([]string{"create", "register", "delete"})
+
+LOOP:
+	for {
+		if opSet.Empty() {
+			break // all done!
+		}
+
+		// unblock a client RPC; it will tell us which one it let through.
+		op, err := c.unblockCurrent()
+		if err != nil {
+			t.Errorf("error unblocking client RPC: %v", err)
+			break
+		}
+
+		if !opSet.Remove(op) {
+			t.Errorf("mystery unblocked RPC operation: %q", op)
+			break
+		}
+
+		// make sure the server RPC has totally completed (and written state),
+		// and that the server RPC matches the unblocked client RPC.
+		select {
+		case serverOp := <-rpcDoneCh:
+			if serverOp != op {
+				t.Errorf("client RPC says %q; server RPC says %q", op, serverOp)
+				continue
+			}
+		case <-time.After(time.Second):
+			t.Error("timeout waiting for an RPC to finish")
+			break LOOP
+		}
+
+		// get the volume to check
+		got, err := srv.State().HostVolumeByID(nil, vol.Namespace, vol.ID, false)
+		if err != nil {
+			t.Errorf("error reading state: %v", err)
+			break
+		}
+
+		switch op {
+
+		case "create":
+			if got == nil {
+				t.Error("volume should not be nil after create RPC")
+				continue
+			}
+			test.Eq(t, cVol2.Parameters, got.Parameters)
+
+		case "register":
+			if got == nil {
+				t.Error("volume should not be nil after register RPC")
+				continue
+			}
+			test.Eq(t, rVol.Parameters, got.Parameters)
+
+		case "delete":
+			test.Nil(t, got, test.Sprint(""))
+		}
+	}
+
+	// everything should be done by now, but just in case.
+	cancelClientRPCBlocks()
+
+	mErr := funcs.Wait()
+	test.NoError(t, helper.FlattenMultierror(mErr))
+
+	// all of 'em should have happened!
+	test.Eq(t, []string{}, opSet.Slice(), test.Sprint("remaining opSet should be empty"))
+}
+
 // mockHostVolumeClient models client RPCs that have side-effects on the
 // client host
 type mockHostVolumeClient struct {
@@ -771,6 +961,11 @@ type mockHostVolumeClient struct {
 	nextCreateErr   error
 	nextRegisterErr error
 	nextDeleteErr   error
+	// blockChan is used to test server->client RPC serialization.
+	// do not block on this channel while the main lock is held.
+	blockChan chan string
+	// shutdownCtx is an escape hatch to release any/all blocked RPCs
+	shutdownCtx context.Context
 }
 
 // newMockHostVolumeClient configures a RPC-only Nomad test agent and returns a
@@ -822,6 +1017,11 @@ func (v *mockHostVolumeClient) setDelete(errMsg string) {
 func (v *mockHostVolumeClient) Create(
 	req *cstructs.ClientHostVolumeCreateRequest,
 	resp *cstructs.ClientHostVolumeCreateResponse) error {
+
+	if err := v.block("create"); err != nil {
+		return err
+	}
+
 	v.lock.Lock()
 	defer v.lock.Unlock()
 	if v.nextCreateResponse == nil {
@@ -834,6 +1034,11 @@ func (v *mockHostVolumeClient) Register(
 	req *cstructs.ClientHostVolumeRegisterRequest,
 	resp *cstructs.ClientHostVolumeRegisterResponse) error {
+
+	if err := v.block("register"); err != nil {
+		return err
+	}
+
 	v.lock.Lock()
 	defer v.lock.Unlock()
 	*resp = cstructs.ClientHostVolumeRegisterResponse{}
@@ -843,7 +1048,62 @@ func (v *mockHostVolumeClient) Delete(
 	req *cstructs.ClientHostVolumeDeleteRequest,
 	resp *cstructs.ClientHostVolumeDeleteResponse) error {
+
+	if err := v.block("delete"); err != nil {
+		return err
+	}
+
 	v.lock.Lock()
 	defer v.lock.Unlock()
 	return v.nextDeleteErr
 }
+
+func (v *mockHostVolumeClient) setBlockChan() (context.CancelFunc, error) {
+	v.lock.Lock()
+	defer v.lock.Unlock()
+	if v.blockChan != nil {
+		return nil, errors.New("blockChan already set")
+	}
+	v.blockChan = make(chan string) // no buffer to ensure blockage
+	// timeout context to ensure blockage is not endless
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	v.shutdownCtx = ctx
+	return cancel, nil
+}
+
+func (v *mockHostVolumeClient) getBlockChan() chan string {
+	v.lock.Lock()
+	defer v.lock.Unlock()
+	return v.blockChan
+}
+
+// block stalls the RPC until something (a test) runs unblockCurrent,
+// if something (a test) had previously run setBlockChan to set it up.
+func (v *mockHostVolumeClient) block(op string) error {
+	bc := v.getBlockChan()
+	if bc == nil {
+		return nil
+	}
+	select {
+	case bc <- op:
+		return nil
+	case <-v.shutdownCtx.Done():
+		// if this happens, it'll be because unblockCurrent was not run enough
+		return fmt.Errorf("shutdownCtx done before blockChan unblocked: %w", v.shutdownCtx.Err())
+	}
+}
+
+// unblockCurrent reads from blockChan to unblock a running RPC.
+// it must be run once per RPC that is started.
+func (v *mockHostVolumeClient) unblockCurrent() (string, error) {
+	bc := v.getBlockChan()
+	if bc == nil {
+		return "", errors.New("no blockChan")
+	}
+	select {
+	case current := <-bc:
+		return current, nil
+	case <-time.After(time.Second):
+		return "", errors.New("unblockCurrent timeout")
+	}
+}