Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AsyncDelegate<T> #3998

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ To be released.
- Changed `IMessageCodec.Encode(MessageContent, PrivateKey,
AppProtocolVersion, BoundPeer, DateTimeOffset, byte[]?)` to
`IMessageCodec.Encode(Message, PrivateKey)`. [[#3997]]
- Removed `ITransport.ReplyMessageAsync()` interface method. [[#3998]]
- Changed `AsyncDelegate<T>` class to `AsyncDelegate`. [[#3998]]

### Backward-incompatible network protocol changes

Expand All @@ -29,6 +31,7 @@ To be released.
### CLI tools

[#3997]: https://github.com/planetarium/libplanet/pull/3997
[#3998]: https://github.com/planetarium/libplanet/pull/3998


Version 5.4.0
Expand Down
16 changes: 9 additions & 7 deletions src/Libplanet.Net/AsyncDelegate.cs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Channels;
using System.Threading.Tasks;
using Libplanet.Net.Messages;

namespace Libplanet.Net
{
public class AsyncDelegate<T>
public class AsyncDelegate
{
private IEnumerable<Func<T, Task>> _functions;
private IEnumerable<Func<Message, Channel<MessageContent>, Task>> _functions;

public AsyncDelegate()
{
_functions = new List<Func<T, Task>>();
_functions = new List<Func<Message, Channel<MessageContent>, Task>>();
}

public void Register(Func<T, Task> func)
public void Register(Func<Message, Channel<MessageContent>, Task> func)
{
#pragma warning disable PC002
// Usage of a .NET Standard API that isn’t available on the .NET Framework 4.6.1
Expand All @@ -26,14 +28,14 @@ public void Register(Func<T, Task> func)
#pragma warning restore PC002
}

public void Unregister(Func<T, Task> func)
public void Unregister(Func<Message, Channel<MessageContent>, Task> func)
{
_functions = _functions.Where(f => !f.Equals(func));
}

public async Task InvokeAsync(T arg)
public async Task InvokeAsync(Message message, Channel<MessageContent> channel)
{
IEnumerable<Task> tasks = _functions.Select(f => f(arg));
IEnumerable<Task> tasks = _functions.Select(f => f(message, channel));
await Task.WhenAll(tasks).ConfigureAwait(false);
}
}
Expand Down
133 changes: 71 additions & 62 deletions src/Libplanet.Net/Consensus/Gossip.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Dasync.Collections;
using Libplanet.Net.Messages;
Expand Down Expand Up @@ -145,8 +146,8 @@ public async Task StartAsync(CancellationToken ctx)
nameof(StartAsync));
}

_transport.ProcessMessageHandler.Register(
HandleMessageAsync(_cancellationTokenSource.Token));
_transport.ProcessMessageHandler.Register((message, channel) =>
HandleMessageAsync(message, channel, _cancellationTokenSource.Token));
_logger.Debug("All peers are alive. Starting gossip...");
Running = true;
await Task.WhenAny(
Expand Down Expand Up @@ -307,50 +308,57 @@ private IEnumerable<BoundPeer> PeersToBroadcast(
/// <summary>
/// Handle a message received from <see cref="ITransport.ProcessMessageHandler"/>.
/// </summary>
/// <param name="ctx">A cancellation token used to propagate notification
/// that this operation should be canceled.</param>
/// <returns>A function with parameter of <see cref="Message"/>
/// and return <see cref="Task"/>.</returns>
private Func<Message, Task> HandleMessageAsync(CancellationToken ctx) => async msg =>
private async Task HandleMessageAsync(
Message message,
Channel<MessageContent> channel,
CancellationToken cancellationToken)
{
_logger.Verbose("HandleMessage: {Message}", msg);
_logger.Verbose("HandleMessage: {Message}", message.Content);

if (_denySet.Contains(msg.Remote))
if (_denySet.Contains(message.Remote))
{
_logger.Verbose("Message from denied peer, rejecting: {Message}", msg);
await ReplyMessagePongAsync(msg, ctx);
_logger.Verbose("Message from denied peer, rejecting: {Message}", message.Content);
await channel.Writer
.WriteAsync(new PongMsg(), cancellationToken)
.ConfigureAwait(false);
return;
}

try
{
_validateMessageToReceive(msg);
_validateMessageToReceive(message);
}
catch (Exception e)
{
_logger.Error(
"Invalid message, rejecting: {Message}, {Exception}", msg, e.Message);
e,
"Invalid message, rejecting: {Message}",
message.Content);
return;
}

switch (msg.Content)
switch (message.Content)
{
case PingMsg _:
case FindNeighborsMsg _:
// Ignore protocol related messages, Kadmelia Protocol will handle it.
break;
return;
case HaveMessage _:
await HandleHaveAsync(msg, ctx);
break;
await HandleHaveAsync(message, channel, cancellationToken);
return;
case WantMessage _:
await HandleWantAsync(msg, ctx);
break;
await HandleWantAsync(message, channel, cancellationToken);
return;
default:
await ReplyMessagePongAsync(msg, ctx);
AddMessage(msg.Content);
break;
await channel.Writer
.WriteAsync(new PongMsg(), cancellationToken)
.ConfigureAwait(false);
AddMessage(message.Content);
return;
}
};
}

/// <summary>
/// A lifecycle task which will run in every <see cref="_heartbeatInterval"/>.
Expand Down Expand Up @@ -380,15 +388,23 @@ private async Task HeartbeatTask(CancellationToken ctx)
/// A function handling <see cref="HaveMessage"/>.
/// <seealso cref="HandleMessageAsync"/>
/// </summary>
/// <param name="msg">Target <see cref="HaveMessage"/>.</param>
/// <param name="ctx">A cancellation token used to propagate notification
/// <param name="message">Target <see cref="HaveMessage"/>.</param>
/// <param name="channel">The <see cref="Channel{T}"/> to write
/// reply messages.</param>
/// <param name="cancellationToken">A cancellation token used to propagate notification
/// that this operation should be canceled.</param>
/// <returns>An awaitable task without value.</returns>
private async Task HandleHaveAsync(Message msg, CancellationToken ctx)
private async Task HandleHaveAsync(
Message message,
Channel<MessageContent> channel,
CancellationToken cancellationToken)
{
var haveMessage = (HaveMessage)msg.Content;
var haveMessage = (HaveMessage)message.Content;

await channel.Writer
.WriteAsync(new PongMsg(), cancellationToken)
.ConfigureAwait(false);

await ReplyMessagePongAsync(msg, ctx);
MessageId[] idsToGet = _cache.DiffFrom(haveMessage.Ids);
_logger.Verbose(
"Handle HaveMessage. {Total}/{Count} messages to get.",
Expand All @@ -400,15 +416,15 @@ private async Task HandleHaveAsync(Message msg, CancellationToken ctx)
}

_logger.Verbose("Ids to receive: {Ids}", idsToGet);
if (!_haveDict.ContainsKey(msg.Remote))
if (!_haveDict.ContainsKey(message.Remote))
{
_haveDict.TryAdd(msg.Remote, new HashSet<MessageId>(idsToGet));
_haveDict.TryAdd(message.Remote, new HashSet<MessageId>(idsToGet));
}
else
{
List<MessageId> list = _haveDict[msg.Remote].ToList();
List<MessageId> list = _haveDict[message.Remote].ToList();
list.AddRange(idsToGet.Where(id => !list.Contains(id)));
_haveDict[msg.Remote] = new HashSet<MessageId>(list);
_haveDict[message.Remote] = new HashSet<MessageId>(list);
}
}

Expand Down Expand Up @@ -485,14 +501,19 @@ await optimized.ParallelForEachAsync(
/// A function handling <see cref="WantMessage"/>.
/// <seealso cref="HandleMessageAsync"/>
/// </summary>
/// <param name="msg">Target <see cref="WantMessage"/>.</param>
/// <param name="ctx">A cancellation token used to propagate notification
/// <param name="message">Target <see cref="WantMessage"/>.</param>
/// <param name="channel">The <see cref="Channel{T}"/> to write
/// reply messages.</param>
/// <param name="cancellationToken">A cancellation token used to propagate notification
/// that this operation should be canceled.</param>
/// <returns>An awaitable task without value.</returns>
private async Task HandleWantAsync(Message msg, CancellationToken ctx)
private async Task HandleWantAsync(
Message message,
Channel<MessageContent> channel,
CancellationToken cancellationToken)
{
// FIXME: Message may have been discarded.
var wantMessage = (WantMessage)msg.Content;
WantMessage wantMessage = (WantMessage)message.Content;
MessageContent[] contents = wantMessage.Ids.Select(id => _cache.Get(id)).ToArray();
MessageId[] ids = contents.Select(c => c.Id).ToArray();

Expand All @@ -502,23 +523,23 @@ private async Task HandleWantAsync(Message msg, CancellationToken ctx)
ids,
contents.Select(content => (content.Type, content.Id)));

await contents.ParallelForEachAsync(
async c =>
foreach (var content in contents)
{
try
{
try
{
_validateMessageToSend(c);
await _transport.ReplyMessageAsync(c, msg.Identity, ctx);
}
catch (Exception e)
{
_logger.Error(
"Invalid message, rejecting: {Message}, {Exception}", msg, e.Message);
}
}, ctx);

var id = msg is { Identity: null } ? "unknown" : new Guid(msg.Identity).ToString();
_logger.Debug("Finished replying WantMessage. {RequestId}", id);
_validateMessageToSend(content);
await channel.Writer
.WriteAsync(content, cancellationToken)
.ConfigureAwait(false);
}
catch (Exception e)
{
_logger.Error(
e,
"Invalid message, rejecting {Message}",
message.Content);
}
}
}

/// <summary>
Expand Down Expand Up @@ -575,17 +596,5 @@ private async Task RefreshTableAsync(CancellationToken ctx)
}
}
}

/// <summary>
/// Replies a <see cref="PongMsg"/> of received <paramref name="message"/>.
/// </summary>
/// <param name="message">A message to replies.</param>
/// <param name="ctx">A cancellation token used to propagate notification
/// that this operation should be canceled.</param>
/// <returns>An awaitable task without value.</returns>
private async Task ReplyMessagePongAsync(Message message, CancellationToken ctx)
{
await _transport.ReplyMessageAsync(new PongMsg(), message.Identity, ctx);
}
}
}
22 changes: 8 additions & 14 deletions src/Libplanet.Net/Protocols/KademliaProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Dasync.Collections;
using Libplanet.Crypto;
Expand Down Expand Up @@ -447,19 +448,19 @@ internal async Task PingAsync(
}
}

private async Task ProcessMessageHandler(Message message)
private async Task ProcessMessageHandler(Message message, Channel<MessageContent> channel)
{
switch (message.Content)
{
case PingMsg ping:
{
await ReceivePingAsync(message).ConfigureAwait(false);
await ReceivePingAsync(message, channel).ConfigureAwait(false);
break;
}

case FindNeighborsMsg findNeighbors:
{
await ReceiveFindPeerAsync(message).ConfigureAwait(false);
await ReceiveFindPeerAsync(message, channel).ConfigureAwait(false);
break;
}
}
Expand Down Expand Up @@ -635,18 +636,15 @@ private async Task<IEnumerable<BoundPeer>> GetNeighbors(
}

// Send pong back to remote
private async Task ReceivePingAsync(Message message)
private async Task ReceivePingAsync(Message message, Channel<MessageContent> channel)
{
var ping = (PingMsg)message.Content;
if (message.Remote.Address.Equals(_address))
{
throw new InvalidMessageContentException("Cannot receive ping from self.", ping);
}

var pong = new PongMsg();

await _transport.ReplyMessageAsync(pong, message.Identity, default)
.ConfigureAwait(false);
await channel.Writer.WriteAsync(new PongMsg()).ConfigureAwait(false);
}

/// <summary>
Expand Down Expand Up @@ -769,16 +767,12 @@ private async Task ProcessFoundAsync(

// FIXME: this method is not safe from amplification attack
// maybe ping/pong/ping/pong is required
private async Task ReceiveFindPeerAsync(Message message)
private async Task ReceiveFindPeerAsync(Message message, Channel<MessageContent> channel)
{
var findNeighbors = (FindNeighborsMsg)message.Content;
IEnumerable<BoundPeer> found =
_table.Neighbors(findNeighbors.Target, _table.BucketSize, true);

var neighbors = new NeighborsMsg(found);

await _transport.ReplyMessageAsync(neighbors, message.Identity, default)
.ConfigureAwait(false);
await channel.Writer.WriteAsync(new NeighborsMsg(found)).ConfigureAwait(false);
}
}
}
5 changes: 3 additions & 2 deletions src/Libplanet.Net/Swarm.Evidence.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Libplanet.Blockchain;
using Libplanet.Crypto;
Expand Down Expand Up @@ -132,7 +133,7 @@ private void BroadcastEvidenceIds(Address? except, IEnumerable<EvidenceId> evide
BroadcastMessage(except, message);
}

private async Task TransferEvidenceAsync(Message message)
private async Task TransferEvidenceAsync(Message message, Channel<MessageContent> channel)
{
if (!await _transferEvidenceSemaphore.WaitAsync(TimeSpan.Zero, _cancellationToken))
{
Expand All @@ -158,7 +159,7 @@ private async Task TransferEvidenceAsync(Message message)
}

MessageContent response = new EvidenceMsg(ev.Serialize());
await Transport.ReplyMessageAsync(response, message.Identity, default);
await channel.Writer.WriteAsync(response).ConfigureAwait(false);
}
catch (KeyNotFoundException)
{
Expand Down
Loading
Loading