From 64f159121d83ce31f52acc98a36cdcc5e9d0d467 Mon Sep 17 00:00:00 2001 From: Say Cheong Date: Mon, 18 Nov 2024 04:49:38 +0900 Subject: [PATCH 1/3] Removed ReplyMessageAsync; changed to use internal channels --- src/Libplanet.Net/AsyncDelegate.cs | 16 ++- src/Libplanet.Net/Consensus/Gossip.cs | 133 ++++++++++-------- .../Protocols/KademliaProtocol.cs | 22 ++- src/Libplanet.Net/Swarm.Evidence.cs | 5 +- src/Libplanet.Net/Swarm.MessageHandlers.cs | 63 ++++----- src/Libplanet.Net/Transports/ITransport.cs | 19 +-- .../Transports/NetMQTransport.cs | 30 +++- .../Consensus/GossipTest.cs | 23 ++- .../Protocols/TestTransport.cs | 71 ++++++---- .../SwarmTest.Broadcast.cs | 12 +- .../Transports/NetMQTransportTest.cs | 10 +- .../Transports/TransportTest.cs | 30 ++-- 12 files changed, 209 insertions(+), 225 deletions(-) diff --git a/src/Libplanet.Net/AsyncDelegate.cs b/src/Libplanet.Net/AsyncDelegate.cs index 85d7846e3d5..3ac23ba4f0c 100644 --- a/src/Libplanet.Net/AsyncDelegate.cs +++ b/src/Libplanet.Net/AsyncDelegate.cs @@ -1,20 +1,22 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Channels; using System.Threading.Tasks; +using Libplanet.Net.Messages; namespace Libplanet.Net { - public class AsyncDelegate + public class AsyncDelegate { - private IEnumerable> _functions; + private IEnumerable, Task>> _functions; public AsyncDelegate() { - _functions = new List>(); + _functions = new List, Task>>(); } - public void Register(Func func) + public void Register(Func, Task> func) { #pragma warning disable PC002 // Usage of a .NET Standard API that isn’t available on the .NET Framework 4.6.1 @@ -26,14 +28,14 @@ public void Register(Func func) #pragma warning restore PC002 } - public void Unregister(Func func) + public void Unregister(Func, Task> func) { _functions = _functions.Where(f => !f.Equals(func)); } - public async Task InvokeAsync(T arg) + public async Task InvokeAsync(Message message, Channel channel) { - IEnumerable tasks = _functions.Select(f => f(arg)); + IEnumerable tasks = _functions.Select(f => f(message, channel)); await Task.WhenAll(tasks).ConfigureAwait(false); } } diff --git a/src/Libplanet.Net/Consensus/Gossip.cs b/src/Libplanet.Net/Consensus/Gossip.cs index 3e0a0031eb0..36526863491 100644 --- a/src/Libplanet.Net/Consensus/Gossip.cs +++ b/src/Libplanet.Net/Consensus/Gossip.cs @@ -4,6 +4,7 @@ using System.Collections.Immutable; using System.Linq; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Dasync.Collections; using Libplanet.Net.Messages; @@ -145,8 +146,8 @@ public async Task StartAsync(CancellationToken ctx) nameof(StartAsync)); } - _transport.ProcessMessageHandler.Register( - HandleMessageAsync(_cancellationTokenSource.Token)); + _transport.ProcessMessageHandler.Register((message, channel) => + HandleMessageAsync(message, channel, _cancellationTokenSource.Token)); _logger.Debug("All peers are alive. Starting gossip..."); Running = true; await Task.WhenAny( @@ -307,50 +308,57 @@ private IEnumerable PeersToBroadcast( /// /// Handle a message received from . /// - /// A cancellation token used to propagate notification - /// that this operation should be canceled. /// A function with parameter of /// and return . - private Func HandleMessageAsync(CancellationToken ctx) => async msg => + private async Task HandleMessageAsync( + Message message, + Channel channel, + CancellationToken cancellationToken) { - _logger.Verbose("HandleMessage: {Message}", msg); + _logger.Verbose("HandleMessage: {Message}", message.Content); - if (_denySet.Contains(msg.Remote)) + if (_denySet.Contains(message.Remote)) { - _logger.Verbose("Message from denied peer, rejecting: {Message}", msg); - await ReplyMessagePongAsync(msg, ctx); + _logger.Verbose("Message from denied peer, rejecting: {Message}", message.Content); + await channel.Writer + .WriteAsync(new PongMsg(), cancellationToken) + .ConfigureAwait(false); return; } try { - _validateMessageToReceive(msg); + _validateMessageToReceive(message); } catch (Exception e) { _logger.Error( - "Invalid message, rejecting: {Message}, {Exception}", msg, e.Message); + e, + "Invalid message, rejecting: {Message}", + message.Content); return; } - switch (msg.Content) + switch (message.Content) { case PingMsg _: case FindNeighborsMsg _: // Ignore protocol related messages, Kadmelia Protocol will handle it. - break; + return; case HaveMessage _: - await HandleHaveAsync(msg, ctx); - break; + await HandleHaveAsync(message, channel, cancellationToken); + return; case WantMessage _: - await HandleWantAsync(msg, ctx); - break; + await HandleWantAsync(message, channel, cancellationToken); + return; default: - await ReplyMessagePongAsync(msg, ctx); - AddMessage(msg.Content); - break; + await channel.Writer + .WriteAsync(new PongMsg(), cancellationToken) + .ConfigureAwait(false); + AddMessage(message.Content); + return; } - }; + } /// /// A lifecycle task which will run in every . @@ -380,15 +388,23 @@ private async Task HeartbeatTask(CancellationToken ctx) /// A function handling . /// /// - /// Target . - /// A cancellation token used to propagate notification + /// Target . + /// The to write + /// reply messages. + /// A cancellation token used to propagate notification /// that this operation should be canceled. /// An awaitable task without value. - private async Task HandleHaveAsync(Message msg, CancellationToken ctx) + private async Task HandleHaveAsync( + Message message, + Channel channel, + CancellationToken cancellationToken) { - var haveMessage = (HaveMessage)msg.Content; + var haveMessage = (HaveMessage)message.Content; + + await channel.Writer + .WriteAsync(new PongMsg(), cancellationToken) + .ConfigureAwait(false); - await ReplyMessagePongAsync(msg, ctx); MessageId[] idsToGet = _cache.DiffFrom(haveMessage.Ids); _logger.Verbose( "Handle HaveMessage. {Total}/{Count} messages to get.", @@ -400,15 +416,15 @@ private async Task HandleHaveAsync(Message msg, CancellationToken ctx) } _logger.Verbose("Ids to receive: {Ids}", idsToGet); - if (!_haveDict.ContainsKey(msg.Remote)) + if (!_haveDict.ContainsKey(message.Remote)) { - _haveDict.TryAdd(msg.Remote, new HashSet(idsToGet)); + _haveDict.TryAdd(message.Remote, new HashSet(idsToGet)); } else { - List list = _haveDict[msg.Remote].ToList(); + List list = _haveDict[message.Remote].ToList(); list.AddRange(idsToGet.Where(id => !list.Contains(id))); - _haveDict[msg.Remote] = new HashSet(list); + _haveDict[message.Remote] = new HashSet(list); } } @@ -485,14 +501,19 @@ await optimized.ParallelForEachAsync( /// A function handling . /// /// - /// Target . - /// A cancellation token used to propagate notification + /// Target . + /// The to write + /// reply messages. + /// A cancellation token used to propagate notification /// that this operation should be canceled. /// An awaitable task without value. - private async Task HandleWantAsync(Message msg, CancellationToken ctx) + private async Task HandleWantAsync( + Message message, + Channel channel, + CancellationToken cancellationToken) { // FIXME: Message may have been discarded. - var wantMessage = (WantMessage)msg.Content; + WantMessage wantMessage = (WantMessage)message.Content; MessageContent[] contents = wantMessage.Ids.Select(id => _cache.Get(id)).ToArray(); MessageId[] ids = contents.Select(c => c.Id).ToArray(); @@ -502,23 +523,23 @@ private async Task HandleWantAsync(Message msg, CancellationToken ctx) ids, contents.Select(content => (content.Type, content.Id))); - await contents.ParallelForEachAsync( - async c => + foreach (var content in contents) + { + try { - try - { - _validateMessageToSend(c); - await _transport.ReplyMessageAsync(c, msg.Identity, ctx); - } - catch (Exception e) - { - _logger.Error( - "Invalid message, rejecting: {Message}, {Exception}", msg, e.Message); - } - }, ctx); - - var id = msg is { Identity: null } ? "unknown" : new Guid(msg.Identity).ToString(); - _logger.Debug("Finished replying WantMessage. {RequestId}", id); + _validateMessageToSend(content); + await channel.Writer + .WriteAsync(content, cancellationToken) + .ConfigureAwait(false); + } + catch (Exception e) + { + _logger.Error( + e, + "Invalid message, rejecting {Message}", + message.Content); + } + } } /// @@ -575,17 +596,5 @@ private async Task RefreshTableAsync(CancellationToken ctx) } } } - - /// - /// Replies a of received . - /// - /// A message to replies. - /// A cancellation token used to propagate notification - /// that this operation should be canceled. - /// An awaitable task without value. - private async Task ReplyMessagePongAsync(Message message, CancellationToken ctx) - { - await _transport.ReplyMessageAsync(new PongMsg(), message.Identity, ctx); - } } } diff --git a/src/Libplanet.Net/Protocols/KademliaProtocol.cs b/src/Libplanet.Net/Protocols/KademliaProtocol.cs index 2a0e9f035ac..3c03ae3c6a4 100644 --- a/src/Libplanet.Net/Protocols/KademliaProtocol.cs +++ b/src/Libplanet.Net/Protocols/KademliaProtocol.cs @@ -4,6 +4,7 @@ using System.Collections.Immutable; using System.Linq; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Dasync.Collections; using Libplanet.Crypto; @@ -447,19 +448,19 @@ internal async Task PingAsync( } } - private async Task ProcessMessageHandler(Message message) + private async Task ProcessMessageHandler(Message message, Channel channel) { switch (message.Content) { case PingMsg ping: { - await ReceivePingAsync(message).ConfigureAwait(false); + await ReceivePingAsync(message, channel).ConfigureAwait(false); break; } case FindNeighborsMsg findNeighbors: { - await ReceiveFindPeerAsync(message).ConfigureAwait(false); + await ReceiveFindPeerAsync(message, channel).ConfigureAwait(false); break; } } @@ -635,7 +636,7 @@ private async Task> GetNeighbors( } // Send pong back to remote - private async Task ReceivePingAsync(Message message) + private async Task ReceivePingAsync(Message message, Channel channel) { var ping = (PingMsg)message.Content; if (message.Remote.Address.Equals(_address)) @@ -643,10 +644,7 @@ private async Task ReceivePingAsync(Message message) throw new InvalidMessageContentException("Cannot receive ping from self.", ping); } - var pong = new PongMsg(); - - await _transport.ReplyMessageAsync(pong, message.Identity, default) - .ConfigureAwait(false); + await channel.Writer.WriteAsync(new PongMsg()).ConfigureAwait(false); } /// @@ -769,16 +767,12 @@ private async Task ProcessFoundAsync( // FIXME: this method is not safe from amplification attack // maybe ping/pong/ping/pong is required - private async Task ReceiveFindPeerAsync(Message message) + private async Task ReceiveFindPeerAsync(Message message, Channel channel) { var findNeighbors = (FindNeighborsMsg)message.Content; IEnumerable found = _table.Neighbors(findNeighbors.Target, _table.BucketSize, true); - - var neighbors = new NeighborsMsg(found); - - await _transport.ReplyMessageAsync(neighbors, message.Identity, default) - .ConfigureAwait(false); + await channel.Writer.WriteAsync(new NeighborsMsg(found)).ConfigureAwait(false); } } } diff --git a/src/Libplanet.Net/Swarm.Evidence.cs b/src/Libplanet.Net/Swarm.Evidence.cs index 1b191bddde3..c25a5c7eeb5 100644 --- a/src/Libplanet.Net/Swarm.Evidence.cs +++ b/src/Libplanet.Net/Swarm.Evidence.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Libplanet.Blockchain; using Libplanet.Crypto; @@ -132,7 +133,7 @@ private void BroadcastEvidenceIds(Address? except, IEnumerable evide BroadcastMessage(except, message); } - private async Task TransferEvidenceAsync(Message message) + private async Task TransferEvidenceAsync(Message message, Channel channel) { if (!await _transferEvidenceSemaphore.WaitAsync(TimeSpan.Zero, _cancellationToken)) { @@ -158,7 +159,7 @@ private async Task TransferEvidenceAsync(Message message) } MessageContent response = new EvidenceMsg(ev.Serialize()); - await Transport.ReplyMessageAsync(response, message.Identity, default); + await channel.Writer.WriteAsync(response).ConfigureAwait(false); } catch (KeyNotFoundException) { diff --git a/src/Libplanet.Net/Swarm.MessageHandlers.cs b/src/Libplanet.Net/Swarm.MessageHandlers.cs index f49fafa8b31..113cec44fd2 100644 --- a/src/Libplanet.Net/Swarm.MessageHandlers.cs +++ b/src/Libplanet.Net/Swarm.MessageHandlers.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Channels; using System.Threading.Tasks; using Libplanet.Net.Messages; using Libplanet.Types.Blocks; @@ -15,13 +16,15 @@ public partial class Swarm private readonly NullableSemaphore _transferTxsSemaphore; private readonly NullableSemaphore _transferEvidenceSemaphore; - private Task ProcessMessageHandlerAsync(Message message) + private async Task ProcessMessageHandlerAsync( + Message message, + Channel channel) { switch (message.Content) { case PingMsg _: case FindNeighborsMsg _: - return Task.CompletedTask; + return; case GetChainStatusMsg getChainStatus: { @@ -38,10 +41,8 @@ private Task ProcessMessageHandlerAsync(Message message) tip.Hash ); - return Transport.ReplyMessageAsync( - chainStatus, - message.Identity, - default); + await channel.Writer.WriteAsync(chainStatus).ConfigureAwait(false); + return; } case GetBlockHashesMsg getBlockHashes: @@ -60,53 +61,47 @@ private Task ProcessMessageHandlerAsync(Message message) getBlockHashes.Locator.Hash); var reply = new BlockHashesMsg(hashes); - return Transport.ReplyMessageAsync(reply, message.Identity, default); + await channel.Writer.WriteAsync(reply).ConfigureAwait(false); + return; } case GetBlocksMsg getBlocksMsg: - return TransferBlocksAsync(message); + await TransferBlocksAsync(message, channel); + return; case GetTxsMsg getTxs: - return TransferTxsAsync(message); + await TransferTxsAsync(message, channel); + return; case GetEvidenceMsg getTxs: - return TransferEvidenceAsync(message); + await TransferEvidenceAsync(message, channel); + return; case TxIdsMsg txIds: ProcessTxIds(message); - return Transport.ReplyMessageAsync( - new PongMsg(), - message.Identity, - default - ); + await channel.Writer.WriteAsync(new PongMsg()).ConfigureAwait(false); + return; case EvidenceIdsMsg evidenceIds: ProcessEvidenceIds(message); - return Transport.ReplyMessageAsync( - new PongMsg(), - message.Identity, - default - ); + await channel.Writer.WriteAsync(new PongMsg()).ConfigureAwait(false); + return; case BlockHashesMsg _: _logger.Error( "{MessageType} messages are only for IBD", nameof(BlockHashesMsg)); - return Task.CompletedTask; + return; case BlockHeaderMsg blockHeader: ProcessBlockHeader(message); - return Transport.ReplyMessageAsync( - new PongMsg(), - message.Identity, - default - ); + await channel.Writer.WriteAsync(new PongMsg()).ConfigureAwait(false); + return; default: throw new InvalidMessageContentException( $"Failed to handle message: {message.Content}", - message.Content - ); + message.Content); } } @@ -192,7 +187,7 @@ private void ProcessBlockHeader(Message message) } } - private async Task TransferTxsAsync(Message message) + private async Task TransferTxsAsync(Message message, Channel channel) { if (!await _transferTxsSemaphore.WaitAsync(TimeSpan.Zero, _cancellationToken)) { @@ -218,7 +213,7 @@ private async Task TransferTxsAsync(Message message) } MessageContent response = new TxMsg(tx.Serialize()); - await Transport.ReplyMessageAsync(response, message.Identity, default); + await channel.Writer.WriteAsync(response).ConfigureAwait(false); } catch (KeyNotFoundException) { @@ -252,7 +247,7 @@ private void ProcessTxIds(Message message) TxCompletion.Demand(message.Remote, txIdsMsg.Ids); } - private async Task TransferBlocksAsync(Message message) + private async Task TransferBlocksAsync(Message message, Channel channel) { if (!await _transferBlocksSemaphore.WaitAsync(TimeSpan.Zero, _cancellationToken)) { @@ -304,7 +299,7 @@ private async Task TransferBlocksAsync(Message message) count, total ); - await Transport.ReplyMessageAsync(response, message.Identity, default); + await channel.Writer.WriteAsync(response).ConfigureAwait(false); payloads.Clear(); } } @@ -317,7 +312,7 @@ private async Task TransferBlocksAsync(Message message) count, total, reqId); - await Transport.ReplyMessageAsync(response, message.Identity, default); + await channel.Writer.WriteAsync(response).ConfigureAwait(false); } if (count == 0) @@ -328,7 +323,7 @@ private async Task TransferBlocksAsync(Message message) count, total, reqId); - await Transport.ReplyMessageAsync(response, message.Identity, default); + await channel.Writer.WriteAsync(response).ConfigureAwait(false); } _logger.Debug("{Count} blocks were transferred to {Identity}", count, reqId); diff --git a/src/Libplanet.Net/Transports/ITransport.cs b/src/Libplanet.Net/Transports/ITransport.cs index a9804cff7a0..5dcde3f7e18 100644 --- a/src/Libplanet.Net/Transports/ITransport.cs +++ b/src/Libplanet.Net/Transports/ITransport.cs @@ -25,7 +25,7 @@ public interface ITransport : IDisposable /// The list of tasks invoked when a message that is not /// a reply is received. /// - AsyncDelegate ProcessMessageHandler { get; } + AsyncDelegate ProcessMessageHandler { get; } /// /// The current representation of . @@ -154,22 +154,5 @@ Task> SendMessageAsync( /// Thrown when instance /// is already disposed. void BroadcastMessage(IEnumerable peers, MessageContent content); - - /// - /// Sends a as a reply. - /// - /// The to send as a reply. - /// The byte array that represents identification of the - /// to respond. - /// - /// A cancellation token used to propagate notification that this - /// operation should be canceled. - /// An awaitable task without value. - /// - /// Thrown when instance is already disposed. - Task ReplyMessageAsync( - MessageContent content, - byte[] identity, - CancellationToken cancellationToken); } } diff --git a/src/Libplanet.Net/Transports/NetMQTransport.cs b/src/Libplanet.Net/Transports/NetMQTransport.cs index 4c197c6c11a..b32391320a3 100644 --- a/src/Libplanet.Net/Transports/NetMQTransport.cs +++ b/src/Libplanet.Net/Transports/NetMQTransport.cs @@ -122,11 +122,11 @@ private NetMQTransport( ); _runningEvent = new AsyncManualResetEvent(); - ProcessMessageHandler = new AsyncDelegate(); + ProcessMessageHandler = new AsyncDelegate(); } /// - public AsyncDelegate ProcessMessageHandler { get; } + public AsyncDelegate ProcessMessageHandler { get; } /// public BoundPeer AsPeer => _turnClient is TurnClient turnClient @@ -526,8 +526,7 @@ await boundPeers.ParallelForEachAsync( ); } - /// - public async Task ReplyMessageAsync( + private async Task ReplyMessageAsync( MessageContent content, byte[] identity, CancellationToken cancellationToken) @@ -650,7 +649,28 @@ private void ReceiveMessage(object? sender, NetMQSocketEventArgs e) { _messageValidator.ValidateTimestamp(message); _messageValidator.ValidateAppProtocolVersion(message); - await ProcessMessageHandler.InvokeAsync(message); + Channel channel = + Channel.CreateUnbounded(); + try + { + await ProcessMessageHandler.InvokeAsync( + message, + channel); + } + finally + { + channel.Writer.Complete(); + } + + await foreach ( + var messageContent in channel.Reader.ReadAllAsync( + _runtimeCancellationTokenSource.Token)) + { + await ReplyMessageAsync( + messageContent, + message.Identity ?? Array.Empty(), + _runtimeCancellationTokenSource.Token); + } } catch (InvalidMessageTimestampException imte) { diff --git a/test/Libplanet.Net.Tests/Consensus/GossipTest.cs b/test/Libplanet.Net.Tests/Consensus/GossipTest.cs index 2ff3bcf8671..a02fac035cf 100644 --- a/test/Libplanet.Net.Tests/Consensus/GossipTest.cs +++ b/test/Libplanet.Net.Tests/Consensus/GossipTest.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Net; +using System.Threading.Channels; using System.Threading.Tasks; using Libplanet.Crypto; using Libplanet.Net.Consensus; @@ -223,7 +224,7 @@ public async void AddPeerWithHaveMessage() var receivedEvent = new AsyncAutoResetEvent(); var transport1 = CreateTransport(key1, 6001); - async Task HandleMessage(Message message) + async Task HandleMessage(Message message, Channel channel) { received = true; receivedEvent.Set(); @@ -269,14 +270,10 @@ await transport2.SendMessageAsync( public async void DoNotBroadcastToSeedPeers() { bool received = false; - async Task ProcessMessage(Message msg) + Task ProcessMessage(Message message, Channel channel) { - if (msg.Content is HaveMessage) - { - received = true; - } - - await Task.CompletedTask; + received = received | message.Content is HaveMessage; + return Task.CompletedTask; } ITransport seed = CreateTransport(); @@ -308,14 +305,10 @@ async Task ProcessMessage(Message msg) public async void DoNotSendDuplicateMessageRequest() { int received = 0; - async Task ProcessMessage(Message msg) + Task ProcessMessage(Message message, Channel channel) { - if (msg.Content is WantMessage) - { - received++; - } - - await Task.CompletedTask; + received += message.Content is WantMessage ? 1 : 0; + return Task.CompletedTask; } Gossip receiver = CreateGossip(_ => { }); diff --git a/test/Libplanet.Net.Tests/Protocols/TestTransport.cs b/test/Libplanet.Net.Tests/Protocols/TestTransport.cs index 730e94d32fb..4972f29641c 100644 --- a/test/Libplanet.Net.Tests/Protocols/TestTransport.cs +++ b/test/Libplanet.Net.Tests/Protocols/TestTransport.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Net; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Libplanet.Common; using Libplanet.Crypto; @@ -63,12 +64,12 @@ public TestTransport( _ignoreTestMessageWithData = new List(); _random = new Random(); Table = new RoutingTable(Address, tableSize, bucketSize); - ProcessMessageHandler = new AsyncDelegate(); + ProcessMessageHandler = new AsyncDelegate(); Protocol = new KademliaProtocol(Table, this, Address); MessageHistory = new FixedSizedQueue(30); } - public AsyncDelegate ProcessMessageHandler { get; } + public AsyncDelegate ProcessMessageHandler { get; } public AsyncAutoResetEvent MessageReceived { get; } @@ -436,33 +437,6 @@ await SendMessageAsync(peer, content, timeout, cancellationToken), }; } - public async Task ReplyMessageAsync( - MessageContent content, - byte[] identity, - CancellationToken cancellationToken) - { - if (_disposed) - { - throw new ObjectDisposedException(nameof(TestTransport)); - } - - if (!Running) - { - throw new TransportException("Start transport before use."); - } - - _logger.Debug("Replying {Content}...", content); - var message = new Message( - content, - AppProtocolVersion, - AsPeer, - DateTimeOffset.UtcNow, - identity); - await Task.Delay(_networkDelay, cancellationToken); - _transports[_peersToReply[identity]].ReceiveReply(message); - _peersToReply.TryRemove(identity, out Address addr); - } - public async Task WaitForTestMessageWithData( string data, CancellationToken token = default) @@ -495,7 +469,9 @@ public bool ReceivedTestMessageOfData(string data) .Any(c => c.Data == data); } - private void ReceiveMessage(Message message) +#pragma warning disable S4457 // Split the method. + private async void ReceiveMessage(Message message) +#pragma warning restore S4457 { if (_swarmCancellationTokenSource.IsCancellationRequested) { @@ -532,8 +508,41 @@ private void ReceiveMessage(Message message) LastMessageTimestamp = DateTimeOffset.UtcNow; ReceivedMessages.Add(message); - _ = ProcessMessageHandler.InvokeAsync(message); MessageReceived.Set(); + Channel channel = Channel.CreateUnbounded(); + await ProcessMessageHandler.InvokeAsync(message, channel); + channel.Writer.TryComplete(); + await foreach (var content in channel.Reader.ReadAllAsync()) + { + await ReplyMessageAsync(content, message.Identity, default); + } + } + + private async Task ReplyMessageAsync( + MessageContent content, + byte[] identity, + CancellationToken cancellationToken) + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(TestTransport)); + } + + if (!Running) + { + throw new TransportException("Start transport before use."); + } + + _logger.Debug("Replying {Content}...", content); + var message = new Message( + content, + AppProtocolVersion, + AsPeer, + DateTimeOffset.UtcNow, + identity); + await Task.Delay(_networkDelay, cancellationToken); + _transports[_peersToReply[identity]].ReceiveReply(message); + _peersToReply.TryRemove(identity, out Address addr); } private void ReceiveReply(Message message) diff --git a/test/Libplanet.Net.Tests/SwarmTest.Broadcast.cs b/test/Libplanet.Net.Tests/SwarmTest.Broadcast.cs index 30d26cf20eb..94dd0f126e5 100644 --- a/test/Libplanet.Net.Tests/SwarmTest.Broadcast.cs +++ b/test/Libplanet.Net.Tests/SwarmTest.Broadcast.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Net; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Libplanet.Action; using Libplanet.Action.Loader; @@ -1006,21 +1007,18 @@ public async Task DoNotSpawnMultipleTaskForSinglePeer() Array.Empty())); int requestCount = 0; - async Task MessageHandler(Message message) + async Task MessageHandler(Message message, Channel channel) { _logger.Debug("Received message: {Content}", message); switch (message.Content) { case PingMsg ping: - await mockTransport.ReplyMessageAsync( - new PongMsg(), - message.Identity, - default); - break; + await channel.Writer.WriteAsync(new PongMsg()).ConfigureAwait(false); + return; case GetBlockHashesMsg gbhm: requestCount++; - break; + return; } } diff --git a/test/Libplanet.Net.Tests/Transports/NetMQTransportTest.cs b/test/Libplanet.Net.Tests/Transports/NetMQTransportTest.cs index 0f2e50e0e0b..5ecd836f464 100644 --- a/test/Libplanet.Net.Tests/Transports/NetMQTransportTest.cs +++ b/test/Libplanet.Net.Tests/Transports/NetMQTransportTest.cs @@ -64,15 +64,7 @@ public async Task SendMessageAsyncNetMQSocketLeak() new HostOptions(IPAddress.Loopback.ToString(), new IceServer[] { }, 0) ).ConfigureAwait(false); transport.ProcessMessageHandler.Register( - async m => - { - await transport.ReplyMessageAsync( - new PongMsg(), - m.Identity, - CancellationToken.None - ); - } - ); + async (m, c) => await c.Writer.WriteAsync(new PongMsg())); await InitializeAsync(transport); string invalidHost = Guid.NewGuid().ToString(); diff --git a/test/Libplanet.Net.Tests/Transports/TransportTest.cs b/test/Libplanet.Net.Tests/Transports/TransportTest.cs index 9b8a3b2a241..f1739c30677 100644 --- a/test/Libplanet.Net.Tests/Transports/TransportTest.cs +++ b/test/Libplanet.Net.Tests/Transports/TransportTest.cs @@ -6,6 +6,7 @@ using System.Net; using System.Net.Sockets; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Libplanet.Crypto; using Libplanet.Net.Messages; @@ -109,11 +110,6 @@ await Assert.ThrowsAsync( default)); Assert.Throws( () => transport.BroadcastMessage(null, message)); - await Assert.ThrowsAsync( - async () => await transport.ReplyMessageAsync( - message, - Array.Empty(), - default)); // To check multiple Dispose() throws error or not. transport.Dispose(); @@ -152,14 +148,11 @@ public async Task SendMessageAsync() ITransport transportA = await CreateTransportAsync().ConfigureAwait(false); ITransport transportB = await CreateTransportAsync().ConfigureAwait(false); - transportB.ProcessMessageHandler.Register(async message => + transportB.ProcessMessageHandler.Register(async (message, channel) => { if (message.Content is PingMsg) { - await transportB.ReplyMessageAsync( - new PongMsg(), - message.Identity, - CancellationToken.None); + await channel.Writer.WriteAsync(new PongMsg()).ConfigureAwait(false); } }); @@ -221,18 +214,12 @@ public async Task SendMessageMultipleRepliesAsync() ITransport transportA = await CreateTransportAsync().ConfigureAwait(false); ITransport transportB = await CreateTransportAsync().ConfigureAwait(false); - transportB.ProcessMessageHandler.Register(async message => + transportB.ProcessMessageHandler.Register(async (message, channel) => { if (message.Content is PingMsg) { - await transportB.ReplyMessageAsync( - new PingMsg(), - message.Identity, - default); - await transportB.ReplyMessageAsync( - new PongMsg(), - message.Identity, - default); + await channel.Writer.WriteAsync(new PingMsg()).ConfigureAwait(false); + await channel.Writer.WriteAsync(new PongMsg()).ConfigureAwait(false); } }); @@ -370,9 +357,10 @@ public async Task BroadcastMessage() transportC.ProcessMessageHandler.Register(MessageHandler(tcsC)); transportD.ProcessMessageHandler.Register(MessageHandler(tcsD)); - Func MessageHandler(TaskCompletionSource tcs) + Func, Task> + MessageHandler(TaskCompletionSource tcs) { - return async message => + return async (message, channel) => { if (message.Content is PingMsg) { From dd93cf468d84643f1e6a11b01dad3667c0988b50 Mon Sep 17 00:00:00 2001 From: Say Cheong Date: Mon, 25 Nov 2024 17:43:31 +0900 Subject: [PATCH 2/3] Changelog --- CHANGES.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index 3b8078aeb5d..497f1c2f91f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -13,6 +13,8 @@ To be released. - Changed `IMessageCodec.Encode(MessageContent, PrivateKey, AppProtocolVersion, BoundPeer, DateTimeOffset, byte[]?)` to `IMessageCodec.Encode(Message, PrivateKey)`. [[#3997]] + - Removed `ITransport.ReplyMessageAsync()` interface method. [[#3998]] + - Changed `AsyncDelegate` class to `AsyncDelegate`. [[#3998]] ### Backward-incompatible network protocol changes @@ -29,6 +31,7 @@ To be released. ### CLI tools [#3997]: https://github.com/planetarium/libplanet/pull/3997 +[#3998]: https://github.com/planetarium/libplanet/pull/3998 Version 5.4.0 From 7c26f68c63739f548620969a479ab72f7a46bea8 Mon Sep 17 00:00:00 2001 From: Say Cheong Date: Tue, 26 Nov 2024 01:23:03 +0900 Subject: [PATCH 3/3] Workaround for ReadAllAsync() --- src/Libplanet.Net/Transports/NetMQTransport.cs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/Libplanet.Net/Transports/NetMQTransport.cs b/src/Libplanet.Net/Transports/NetMQTransport.cs index b32391320a3..443bcf985fd 100644 --- a/src/Libplanet.Net/Transports/NetMQTransport.cs +++ b/src/Libplanet.Net/Transports/NetMQTransport.cs @@ -662,10 +662,26 @@ await ProcessMessageHandler.InvokeAsync( channel.Writer.Complete(); } +#if NETCOREAPP3_0 || NETCOREAPP3_1 || NET await foreach ( var messageContent in channel.Reader.ReadAllAsync( _runtimeCancellationTokenSource.Token)) { +#else + while (true) + { + MessageContent messageContent; + try + { + messageContent = await channel.Reader.ReadAsync( + _runtimeCancellationTokenSource.Token); + } + catch (ChannelClosedException) + { + break; + } + +#endif await ReplyMessageAsync( messageContent, message.Identity ?? Array.Empty(),