diff --git a/chain/backend.go b/chain/backend.go index 3a56e1b57..e514850e6 100644 --- a/chain/backend.go +++ b/chain/backend.go @@ -87,3 +87,24 @@ func (s *Syncer) ExistsLiveTickets(ctx context.Context, tickets []*chainhash.Has func (s *Syncer) UsedAddresses(ctx context.Context, addrs []stdaddr.Address) (bitset.Bytes, error) { return s.rpc.UsedAddresses(ctx, addrs) } + +func (s *Syncer) Done() <-chan struct{} { + s.doneMu.Lock() + c := s.done + s.doneMu.Unlock() + return c +} + +func (s *Syncer) Err() error { + s.doneMu.Lock() + c := s.done + err := s.err + s.doneMu.Unlock() + + select { + case <-c: + return err + default: + return nil + } +} diff --git a/chain/sync.go b/chain/sync.go index 82f16fad7..ca2996f51 100644 --- a/chain/sync.go +++ b/chain/sync.go @@ -55,6 +55,10 @@ type Syncer struct { relevantTxs map[chainhash.Hash][]*wire.MsgTx cb *Callbacks + + done chan struct{} + err error + doneMu sync.Mutex } // RPCOptions specifies the network and security settings for establishing a @@ -525,6 +529,17 @@ func (s *Syncer) Run(ctx context.Context) (err error) { } }() + s.doneMu.Lock() + s.done = make(chan struct{}) + s.err = nil + s.doneMu.Unlock() + defer func() { + s.doneMu.Lock() + close(s.done) + s.err = err + s.doneMu.Unlock() + }() + params := s.wallet.ChainParams() s.notifier = ¬ifier{ diff --git a/spv/backend.go b/spv/backend.go index 4a6d90397..ca6ff64c1 100644 --- a/spv/backend.go +++ b/spv/backend.go @@ -619,3 +619,24 @@ func (s *Syncer) Rescan(ctx context.Context, blockHashes []chainhash.Hash, save func (s *Syncer) StakeDifficulty(ctx context.Context) (dcrutil.Amount, error) { return 0, errors.E(errors.Invalid, "stake difficulty is not queryable over wire protocol") } + +func (s *Syncer) Done() <-chan struct{} { + s.doneMu.Lock() + c := s.done + s.doneMu.Unlock() + return c +} + +func (s *Syncer) Err() error { + s.doneMu.Lock() + c := s.done + err := s.err + s.doneMu.Unlock() + + select { + case <-c: + return err + default: + return nil + } +} diff --git a/spv/sync.go b/spv/sync.go index c626346c7..1cf9a57c3 100644 --- a/spv/sync.go +++ b/spv/sync.go @@ -91,6 +91,10 @@ type Syncer struct { // Mempool for non-wallet-relevant transactions. mempool sync.Map // k=chainhash.Hash v=*wire.MsgTx mempoolAdds chan *chainhash.Hash + + done chan struct{} + err error + doneMu sync.Mutex } // Notifications struct to contain all of the upcoming callbacks that will @@ -318,7 +322,18 @@ func (s *Syncer) setRequiredHeight(tipHeight int32) { // Run synchronizes the wallet, returning when synchronization fails or the // context is cancelled. -func (s *Syncer) Run(ctx context.Context) error { +func (s *Syncer) Run(ctx context.Context) (err error) { + s.doneMu.Lock() + s.done = make(chan struct{}) + s.err = nil + s.doneMu.Unlock() + defer func() { + s.doneMu.Lock() + close(s.done) + s.err = err + s.doneMu.Unlock() + }() + tipHash, tipHeight := s.wallet.MainChainTip(ctx) s.setRequiredHeight(tipHeight) rescanPoint, err := s.wallet.RescanPoint(ctx) diff --git a/ticketbuyer/tb.go b/ticketbuyer/tb.go index b74321fe5..a7b9f8a53 100644 --- a/ticketbuyer/tb.go +++ b/ticketbuyer/tb.go @@ -216,6 +216,8 @@ func (tb *TB) buy(ctx context.Context, passphrase []byte, tip *wire.BlockHeader, if err != nil { return err } + ctx, cancel := wallet.WrapNetworkBackendContext(n, ctx) + defer cancel() if len(passphrase) > 0 { // Ensure wallet is unlocked with the current passphrase. If the passphase diff --git a/wallet/mixing.go b/wallet/mixing.go index aa031a838..37c51e3ac 100644 --- a/wallet/mixing.go +++ b/wallet/mixing.go @@ -275,6 +275,13 @@ func (w *Wallet) MixOutput(ctx context.Context, output *wire.OutPoint, changeAcc return errors.E(op, errors.Invalid, s) } + nb, err := w.NetworkBackend() + if err != nil { + return err + } + ctx, cancel := WrapNetworkBackendContext(nb, ctx) + defer cancel() + sdiff, err := w.NextStakeDifficulty(ctx) if err != nil { return errors.E(op, err) diff --git a/wallet/network.go b/wallet/network.go index 511dbb587..6b94ea5b0 100644 --- a/wallet/network.go +++ b/wallet/network.go @@ -6,6 +6,7 @@ package wallet import ( "context" + "sync" "decred.org/dcrwallet/v5/errors" "github.com/decred/dcrd/chaincfg/chainhash" @@ -49,6 +50,12 @@ type NetworkBackend interface { // the wallet to the underlying network, and if not, it returns the // target height that it is attempting to sync to. Synced(ctx context.Context) (bool, int32) + + // Done return a channel that is closed after the syncer disconnects. + // The error (if any) can be returned via Err. + // These semantics match that of context.Context. + Done() <-chan struct{} + Err() error } // NetworkBackend returns the currently associated network backend of the @@ -73,6 +80,47 @@ func (w *Wallet) SetNetworkBackend(n NetworkBackend) { w.networkBackendMu.Unlock() } +type networkContext struct { + context.Context + err error + mu sync.Mutex +} + +func (c *networkContext) Err() error { + c.mu.Lock() + err := c.err + c.mu.Unlock() + + if err != nil { + return err + } + return c.Context.Err() +} + +// WrapNetworkBackendContext returns a derived context that is canceled when +// the NetworkBackend is disconnected. The cancel func must be called +// (e.g. using defer) otherwise a goroutine leak may occur. +func WrapNetworkBackendContext(nb NetworkBackend, ctx context.Context) (context.Context, context.CancelFunc) { + childCtx, cancel := context.WithCancel(ctx) + nbContext := &networkContext{ + Context: childCtx, + } + + go func() { + select { + case <-nb.Done(): + err := nb.Err() + nbContext.mu.Lock() + nbContext.err = err + nbContext.mu.Unlock() + case <-childCtx.Done(): + } + cancel() + }() + + return nbContext, cancel +} + // Caller provides a client interface to perform remote procedure calls. // Serialization and calling conventions are implementation-specific. type Caller interface { @@ -122,6 +170,20 @@ func (o OfflineNetworkBackend) Synced(ctx context.Context) (bool, int32) { return true, 0 } +var closedDone = make(chan struct{}) + +func init() { + close(closedDone) +} + +func (o OfflineNetworkBackend) Done() <-chan struct{} { + return closedDone +} + +func (o OfflineNetworkBackend) Err() error { + return errors.E("offline") +} + // Compile time check to ensure OfflineNetworkBackend fulfills the // NetworkBackend interface. var _ NetworkBackend = OfflineNetworkBackend{} diff --git a/wallet/network_test.go b/wallet/network_test.go index 10966f6c5..eb75a7202 100644 --- a/wallet/network_test.go +++ b/wallet/network_test.go @@ -35,3 +35,5 @@ func (mockNetwork) Rescan(ctx context.Context, blocks []chainhash.Hash, save fun } func (mockNetwork) StakeDifficulty(ctx context.Context) (dcrutil.Amount, error) { return 0, nil } func (mockNetwork) Synced(ctx context.Context) (bool, int32) { return false, 0 } +func (mockNetwork) Done() <-chan struct{} { return nil } +func (mockNetwork) Err() error { return nil } diff --git a/wallet/wallet.go b/wallet/wallet.go index c0eff9e76..0f9dda294 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -1588,6 +1588,9 @@ func (w *Wallet) PurchaseTickets(ctx context.Context, n NetworkBackend, const op errors.Op = "wallet.PurchaseTickets" + ctx, cancel := WrapNetworkBackendContext(n, ctx) + defer cancel() + resp, err := w.purchaseTickets(ctx, op, n, req) if err == nil || !errors.Is(err, errVSPFeeRequiresUTXOSplit) || req.DontSignTx { return resp, err