Skip to content

Commit

Permalink
incorporated PR MarcusWichelmann#26 from upstream repository
Browse files Browse the repository at this point in the history
  • Loading branch information
paviad committed Jul 29, 2024
1 parent ab74417 commit 0eeca47
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public sealed class RfbMessageReceiver : BackgroundThread, IRfbMessageReceiver
/// Initializes a new instance of the <see cref="RfbMessageReceiver" />.
/// </summary>
/// <param name="context">The connection context.</param>
public RfbMessageReceiver(RfbConnectionContext context) : base("RFB Message Receiver")
public RfbMessageReceiver(RfbConnectionContext context)
{
_context = context;
_state = context.GetState<ProtocolState>();
Expand All @@ -51,7 +51,7 @@ public Task StopReceiveLoopAsync()

// This method will not catch exceptions so the BackgroundThread base class will receive them,
// raise a "Failure" and trigger a reconnect.
protected override void ThreadWorker(CancellationToken cancellationToken)
protected override async Task ThreadWorker(CancellationToken cancellationToken)
{
// Get the transport stream so we don't have to call the getter every time
Debug.Assert(_context.Transport != null, "_context.Transport != null");
Expand All @@ -62,12 +62,12 @@ protected override void ThreadWorker(CancellationToken cancellationToken)
ImmutableDictionary<byte, IIncomingMessageType> incomingMessageLookup = _context.SupportedMessageTypes
.OfType<IIncomingMessageType>().ToImmutableDictionary(mt => mt.Id);

Span<byte> messageTypeBuffer = stackalloc byte[1];
var messageTypeBuffer = new byte[1];

while (!cancellationToken.IsCancellationRequested)
{
// Read message type
if (transportStream.Read(messageTypeBuffer) == 0)
if (await transportStream.ReadAsync(messageTypeBuffer.AsMemory(), cancellationToken) == 0)
{
throw new UnexpectedEndOfStreamException("Stream reached its end while reading next message type.");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using MarcusW.VncClient.Protocol.Implementation.MessageTypes.Outgoing;
using MarcusW.VncClient.Protocol.MessageTypes;
Expand All @@ -20,7 +20,7 @@ public class RfbMessageSender : BackgroundThread, IRfbMessageSender
private readonly RfbConnectionContext _context;
private readonly ILogger<RfbMessageSender> _logger;

private readonly BlockingCollection<QueueItem> _queue = new(new ConcurrentQueue<QueueItem>());
private readonly Channel<QueueItem> _queue = Channel.CreateUnbounded<QueueItem>();

private readonly ProtocolState _state;

Expand All @@ -30,7 +30,7 @@ public class RfbMessageSender : BackgroundThread, IRfbMessageSender
/// Initializes a new instance of the <see cref="RfbMessageSender" />.
/// </summary>
/// <param name="context">The connection context.</param>
public RfbMessageSender(RfbConnectionContext context) : base("RFB Message Sender")
public RfbMessageSender(RfbConnectionContext context)
{
_context = context;
_state = context.GetState<ProtocolState>();
Expand Down Expand Up @@ -79,7 +79,7 @@ public void EnqueueMessage<TMessageType>(IOutgoingMessage<TMessageType> message,
var messageType = GetAndCheckMessageType<TMessageType>();

// Add message to queue
_queue.Add(new(message, messageType), cancellationToken);
_queue.Writer.TryWrite(new QueueItem(message, messageType));
}

/// <inheritdoc />
Expand All @@ -104,7 +104,7 @@ public Task SendMessageAndWaitAsync<TMessageType>(IOutgoingMessage<TMessageType>
TaskCompletionSource<object?> completionSource = new(TaskCreationOptions.RunContinuationsAsynchronously);

// Add message to queue
_queue.Add(new(message, messageType, completionSource), cancellationToken);
_queue.Writer.TryWrite(new QueueItem(message, messageType, completionSource));

return completionSource.Task;
}
Expand All @@ -120,7 +120,7 @@ protected override void Dispose(bool disposing)
if (disposing)
{
SetQueueCancelled();
_queue.Dispose();
_queue.Writer.TryComplete();
}

_disposed = true;
Expand All @@ -130,16 +130,17 @@ protected override void Dispose(bool disposing)

// This method will not catch exceptions so the BackgroundThread base class will receive them,
// raise a "Failure" and trigger a reconnect.
protected override void ThreadWorker(CancellationToken cancellationToken)
protected override async Task ThreadWorker(CancellationToken cancellationToken)
{
try
{
Debug.Assert(_context.Transport != null, "_context.Transport != null");
ITransport transport = _context.Transport;

// Iterate over all queued items (will block if the queue is empty)
foreach (QueueItem queueItem in _queue.GetConsumingEnumerable(cancellationToken))
while (await _queue.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
{
QueueItem queueItem = await _queue.Reader.ReadAsync(cancellationToken).ConfigureAwait(false);
IOutgoingMessage<IOutgoingMessageType> message = queueItem.Message;
IOutgoingMessageType messageType = queueItem.MessageType;

Expand Down Expand Up @@ -192,8 +193,7 @@ private TMessageType GetAndCheckMessageType<TMessageType>() where TMessageType :

private void SetQueueCancelled()
{
_queue.CompleteAdding();
foreach (QueueItem queueItem in _queue)
while (_queue.Reader.TryRead(out QueueItem? queueItem))
queueItem.CompletionSource?.TrySetCanceled();
}

Expand Down
108 changes: 33 additions & 75 deletions src/MarcusW.VncClient/Utils/BackgroundThread.cs
Original file line number Diff line number Diff line change
@@ -1,36 +1,33 @@
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;

namespace MarcusW.VncClient.Utils;

/// <summary>
/// Base class for easier creation and clean cancellation of a background thread.
/// </summary>
[PublicAPI]
public abstract class BackgroundThread : IBackgroundThread
{
private readonly TaskCompletionSource<object?> _completedTcs = new();
private readonly object _startLock = new();

private readonly object _lock = new();
private readonly CancellationTokenSource _stopCts = new();
private readonly Thread _thread;

private volatile bool _disposed;

private bool _started;
private Task? _task;

/// <summary>
/// Initializes a new instance of the <see cref="BackgroundThread" />.
/// </summary>
/// <param name="name">The thread name.</param>
protected BackgroundThread(string name)
{
_thread = new(ThreadStart) {
Name = name,
IsBackground = true,
};
}
[Obsolete("The name field is no longer used")]
protected BackgroundThread(string name) : this() { }

/// <summary>
/// Initializes a new instance of the <see cref="BackgroundThread" />.
/// </summary>
protected BackgroundThread() { }

/// <inheritdoc />
public event EventHandler<BackgroundThreadFailedEventArgs>? Failed;
Expand All @@ -47,27 +44,7 @@ protected virtual void Dispose(bool disposing)

if (disposing)
{
try
{
// Ensure the thread is stopped
_stopCts.Cancel();
if (_thread.IsAlive)
{
// Block and wait for completion or hard-kill the thread after 1 second
if (!_thread.Join(TimeSpan.FromSeconds(1)))
{
// _thread.Abort(); -- This is obsolete and not supported
}
}
}
catch
{
// Ignore
}

// Just to be sure...
_completedTcs.TrySetResult(null);

_stopCts.Cancel();
_stopCts.Dispose();
}

Expand All @@ -84,15 +61,15 @@ protected void Start()
{
ObjectDisposedException.ThrowIf(_disposed, typeof(BackgroundThread));

lock (_startLock)
// Do your work...
try
{
if (_started)
{
throw new InvalidOperationException("Thread already started.");
}

_thread.Start(_stopCts.Token);
_started = true;
lock (_lock)
_task ??= ThreadWorker(_stopCts.Token);
}
catch (Exception exception) when (exception is not (OperationCanceledException or ThreadAbortException))
{
Failed?.Invoke(this, new BackgroundThreadFailedEventArgs(exception));
}
}

Expand All @@ -102,49 +79,30 @@ protected void Start()
/// <remarks>
/// It is safe to call this method multiple times.
/// </remarks>
protected Task StopAndWaitAsync()
protected async Task StopAndWaitAsync()
{
ObjectDisposedException.ThrowIf(_disposed, typeof(BackgroundThread));

lock (_startLock)
// Tell the thread to stop
await _stopCts.CancelAsync();

// Wait for completion
if (_task is not null)
{
if (!_started)
try
{
await _task.ConfigureAwait(false);
}
catch (Exception exception) when (exception is not (OperationCanceledException or ThreadAbortException))
{
throw new InvalidOperationException("Thread has not been started.");
Failed?.Invoke(this, new BackgroundThreadFailedEventArgs(exception));
}
}

// Tell the thread to stop
_stopCts.Cancel();

// Wait for completion
return _completedTcs.Task;
}

/// <summary>
/// Executes the work that should happen in the background.
/// </summary>
/// <param name="cancellationToken">The cancellation token that tells the method implementation when to complete.</param>
protected abstract void ThreadWorker(CancellationToken cancellationToken);

private void ThreadStart(object? parameter)
{
Debug.Assert(parameter != null, nameof(parameter) + " != null");
var cancellationToken = (CancellationToken)parameter;

try
{
// Do your work...
ThreadWorker(cancellationToken);
}
catch (Exception exception) when (!(exception is OperationCanceledException or ThreadAbortException))
{
Failed?.Invoke(this, new(exception));
}
finally
{
// Notify stop method that thread has completed
_completedTcs.TrySetResult(null);
}
}
protected abstract Task ThreadWorker(CancellationToken cancellationToken);
}
6 changes: 3 additions & 3 deletions tests/MarcusW.VncClient.Tests/Utils/BackgroundThreadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@ public void Starts_ThreadWorker()
mock.Protected().Verify("ThreadWorker", Times.Exactly(1), ItExpr.IsAny<CancellationToken>());
}

private class CancellableThread() : BackgroundThread("Cancellable Thread")
private class CancellableThread : BackgroundThread
{
public new void Start() => base.Start();

public new Task StopAndWaitAsync() => base.StopAndWaitAsync();

protected override void ThreadWorker(CancellationToken cancellationToken)
protected override async Task ThreadWorker(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
Thread.Sleep(10);
await Task.Delay(10);
}
}
}

0 comments on commit 0eeca47

Please sign in to comment.