diff --git a/Microsoft.DurableTask.sln b/Microsoft.DurableTask.sln index 2b8bab8a..168be614 100644 --- a/Microsoft.DurableTask.sln +++ b/Microsoft.DurableTask.sln @@ -71,6 +71,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Analyzers.Tests", "test\Ana EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AzureFunctionsApp.Tests", "samples\AzureFunctionsUnitTests\AzureFunctionsApp.Tests.csproj", "{FC2692E7-79AE-400E-A50F-8E0BCC8C9BD9}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Services", "Services", "{A9CA1883-133C-49BE-8FA1-B6D6E27110A8}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Sidecar", "src\Services\Sidecar\Sidecar.csproj", "{47ACE256-E8C8-4734-B1D6-B9B39EBF6990}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Sidecar.App", "src\Services\Sidecar.App\Sidecar.App.csproj", "{2F2A8D76-6294-420D-B308-DC2D087EE6B1}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -185,6 +191,14 @@ Global {FC2692E7-79AE-400E-A50F-8E0BCC8C9BD9}.Debug|Any CPU.Build.0 = Debug|Any CPU {FC2692E7-79AE-400E-A50F-8E0BCC8C9BD9}.Release|Any CPU.ActiveCfg = Release|Any CPU {FC2692E7-79AE-400E-A50F-8E0BCC8C9BD9}.Release|Any CPU.Build.0 = Release|Any CPU + {47ACE256-E8C8-4734-B1D6-B9B39EBF6990}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {47ACE256-E8C8-4734-B1D6-B9B39EBF6990}.Debug|Any CPU.Build.0 = Debug|Any CPU + {47ACE256-E8C8-4734-B1D6-B9B39EBF6990}.Release|Any CPU.ActiveCfg = Release|Any CPU + {47ACE256-E8C8-4734-B1D6-B9B39EBF6990}.Release|Any CPU.Build.0 = Release|Any CPU + {2F2A8D76-6294-420D-B308-DC2D087EE6B1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2F2A8D76-6294-420D-B308-DC2D087EE6B1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2F2A8D76-6294-420D-B308-DC2D087EE6B1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2F2A8D76-6294-420D-B308-DC2D087EE6B1}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -220,6 +234,9 @@ Global {998E9D97-BD36-4A9D-81FC-5DAC1CE40083} = {8AFC9781-F6F1-4696-BB4A-9ED7CA9D612B} {541FCCCE-1059-4691-B027-F761CD80DE92} = {E5637F81-2FB9-4CD7-900D-455363B142A7} {FC2692E7-79AE-400E-A50F-8E0BCC8C9BD9} = {EFF7632B-821E-4CFC-B4A0-ED4B24296B17} + {A9CA1883-133C-49BE-8FA1-B6D6E27110A8} = {8AFC9781-F6F1-4696-BB4A-9ED7CA9D612B} + {47ACE256-E8C8-4734-B1D6-B9B39EBF6990} = {A9CA1883-133C-49BE-8FA1-B6D6E27110A8} + {2F2A8D76-6294-420D-B308-DC2D087EE6B1} = {A9CA1883-133C-49BE-8FA1-B6D6E27110A8} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {AB41CB55-35EA-4986-A522-387AB3402E71} diff --git a/src/Services/Sidecar.App/BackendType.cs b/src/Services/Sidecar.App/BackendType.cs new file mode 100644 index 00000000..21eed22b --- /dev/null +++ b/src/Services/Sidecar.App/BackendType.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Sidecar.App; + +/// +/// Represents the supported Durable Task storage provider backends. +/// +enum BackendType +{ + AzureStorage, + MSSQL, + Netherite, + Emulator, +} diff --git a/src/Services/Sidecar.App/IInputReader.cs b/src/Services/Sidecar.App/IInputReader.cs new file mode 100644 index 00000000..852d3de9 --- /dev/null +++ b/src/Services/Sidecar.App/IInputReader.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Sidecar.App; + +/// +/// Abstraction for reading from standard input. This abstraction allows tests to mock stdin. +/// +interface IInputReader +{ + /// + /// Reads a single line from standard input. + /// + Task ReadLineAsync(); +} diff --git a/src/Services/Sidecar.App/Logs.cs b/src/Services/Sidecar.App/Logs.cs new file mode 100644 index 00000000..61050744 --- /dev/null +++ b/src/Services/Sidecar.App/Logs.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar.App +{ + static partial class Logs + { + [LoggerMessage( + EventId = 1, + Level = LogLevel.Information, + Message = "Initializing the Durable Task sidecar. Listen address = {address}, backend type = {backendType}.")] + public static partial void InitializingSidecar( + this ILogger logger, + string address, + string backendType); + + [LoggerMessage( + EventId = 2, + Level = LogLevel.Information, + Message = "Sidecar initialized successfully in {latencyMs}ms.")] + public static partial void SidecarInitialized( + this ILogger logger, + long latencyMs); + + [LoggerMessage( + EventId = 3, + Level = LogLevel.Error, + Message = "Sidecar listen port {port} is in use by another process!")] + public static partial void SidecarListenPortAlreadyInUse( + this ILogger logger, + int port); + + [LoggerMessage( + EventId = 4, + Level = LogLevel.Information, + Message = "The Durable Task sidecar is shutting down.")] + public static partial void SidecarShuttingDown(this ILogger logger); + } +} + diff --git a/src/Services/Sidecar.App/Program.cs b/src/Services/Sidecar.App/Program.cs new file mode 100644 index 00000000..f95369fc --- /dev/null +++ b/src/Services/Sidecar.App/Program.cs @@ -0,0 +1,254 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using CommandLine; +using DurableTask.AzureStorage; +using DurableTask.Core; +using DurableTask.SqlServer; +using Grpc.Net.Client; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.DurableTask.Protobuf; +using Microsoft.DurableTask.Sidecar.Grpc; +using static Microsoft.DurableTask.Protobuf.TaskHubSidecarService; + +namespace Microsoft.DurableTask.Sidecar.App; + +static class Program +{ + public static readonly InMemoryOrchestrationService SingletonLocalOrchestrationService = new(); + + // We allow stdin to be overwritten for in-process testing + public static IInputReader InputReader { get; set; } = new StandardInputReader(); + + // We allow an additional logger provider for in-process testing + public static ILoggerProvider? AdditionalLoggerProvider { get; set; } + + public static async Task Main(string[] args) => + await Parser.Default.ParseArguments(args).MapResult( + (StartOptions options) => OnStartCommand(options), + errors => Task.FromResult(1)); + + static async Task OnStartCommand(StartOptions options) + { + Stopwatch startupLatencyStopwatch = Stopwatch.StartNew(); + ILoggerFactory loggerFactory = GetLoggerFactory(); + ILogger log = loggerFactory.CreateLogger("Microsoft.DurableTask.Sidecar"); + + string listenAddress = $"http://0.0.0.0:{options.ListenPort}"; + log.InitializingSidecar(listenAddress, options.BackendType.ToString()); + + IOrchestrationService orchestrationService = GetOrchestrationService(options, loggerFactory); + await orchestrationService.CreateIfNotExistsAsync(); + + // TODO: Support clients that don't share the same runtime type as the service + IOrchestrationServiceClient orchestrationServiceClient = (IOrchestrationServiceClient)orchestrationService; + + IWebHost host; + try + { + host = new WebHostBuilder() + .UseKestrel(options => + { + // Need to force Http2 in Kestrel in unencrypted scenarios + // https://docs.microsoft.com/en-us/aspnet/core/grpc/troubleshoot?view=aspnetcore-3.0 + options.ConfigureEndpointDefaults(listenOptions => listenOptions.Protocols = HttpProtocols.Http2); + }) + .UseUrls(listenAddress) + .ConfigureServices(services => + { + services.AddGrpc(); + services.AddSingleton(loggerFactory); + services.AddSingleton(orchestrationService); + services.AddSingleton(orchestrationServiceClient); + services.AddSingleton(); + }) + .Configure(app => + { + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapGrpcService(); + }); + }) + .Build(); + await host.StartAsync(); + + log.SidecarInitialized(startupLatencyStopwatch.ElapsedMilliseconds); + } + catch (IOException e) when (e.InnerException is AddressInUseException) + { + log.SidecarListenPortAlreadyInUse(options.ListenPort); + return 1; + } + + if (options.Interactive) + { + Console.ForegroundColor = ConsoleColor.White; + Console.WriteLine("Interactive mode. Type the name of an orchestrator and press [ENTER] to submit. Type 'exit' to quit."); + Console.WriteLine(); + Console.Write("> "); + Console.ResetColor(); + + // Create a gRPC channel to talk to the management service endpoint that we just started. + // Alternatively, we could consider making direct calls using TaskHubClient. + string localListenAddress = $"http://localhost:{options.ListenPort}"; + GrpcChannel grpcChannel = GrpcChannel.ForAddress(localListenAddress, new GrpcChannelOptions + { + // NOTE: This is a localhost connection, so we can safely disable TLS. + UnsafeUseInsecureChannelCallCredentials = true, + }); + + var client = new TaskHubSidecarServiceClient(grpcChannel); + + try + { + while (true) + { + string? input = (await ReadLineAsync())?.Trim(); + if (string.IsNullOrEmpty(input) || string.Equals(input, "help", StringComparison.OrdinalIgnoreCase)) + { + Console.ForegroundColor = ConsoleColor.White; + Console.WriteLine("Usage: {orchestrator-name} [{orchestrator-input}]"); + Console.Write("> "); + Console.ResetColor(); + continue; + } + + if (string.Equals(input, "exit", StringComparison.OrdinalIgnoreCase)) + { + break; + } + + string[] parts = input.Split(' '); + string name = parts.First(); + + var request = new CreateInstanceRequest + { + Name = name, + InstanceId = $"dt-interactive-{Guid.NewGuid():N}", + }; + + if (parts.Length > 1) + { + request.Input = parts[1]; + } + + await client.StartInstanceAsync(request); + } + } + finally + { + await grpcChannel.ShutdownAsync(); + } + } + else + { + // TODO: Block until we receive a SIGTERM or SIGKILL + await Task.Delay(Timeout.Infinite); + } + + log.SidecarShuttingDown(); + await host.StopAsync(); + host.Dispose(); + + return 0; + } + + static ILoggerFactory GetLoggerFactory() => LoggerFactory.Create(builder => + { + builder.AddSimpleConsole(options => + { + options.SingleLine = true; + options.UseUtcTimestamp = true; + options.TimestampFormat = "yyyy-MM-ddThh:mm:ss.ffffffZ "; + }); + + // TODO: Support Application Insights URLs for sovereign clouds + string? appInsightsKey = Environment.GetEnvironmentVariable("APPINSIGHTS_INSTRUMENTATIONKEY"); + if (!string.IsNullOrEmpty(appInsightsKey)) + { + builder.AddApplicationInsights(appInsightsKey); + } + + // Support a statically configured logger provider for in-memory testing. + if (AdditionalLoggerProvider != null) + { + builder.AddProvider(AdditionalLoggerProvider); + } + + // Sidecar logging can be optionally configured using environment variables. + string? sidecarLogLevelString = Environment.GetEnvironmentVariable("DURABLETASK_SIDECAR_LOGLEVEL"); + if (!Enum.TryParse(sidecarLogLevelString, ignoreCase: true, out LogLevel sidecarLogLevel)) + { + sidecarLogLevel = LogLevel.Information; + } + + // Storage provider logs should be warning+ by default and + // core execution logs (DurableTask.Core) should be information+ by default + // to support basic tracking. + builder.AddFilter("DurableTask", LogLevel.Warning); + builder.AddFilter("DurableTask.Core", LogLevel.Information); + builder.AddFilter("Microsoft.DurableTask.Sidecar", sidecarLogLevel); + + // ASP.NET Core logs to warning since they can otherwise be noisy. + // This should be increased if it's necessary to debug gRPC request/response issues. + builder.AddFilter("Microsoft.AspNetCore", LogLevel.Warning); + }); + + static IOrchestrationService GetOrchestrationService(StartOptions options, ILoggerFactory loggerFactory) + { + switch (options.BackendType) + { + case BackendType.AzureStorage: + const string AzureStorageConnectionStringName = "DURABLETASK_AZURESTORAGE_CONNECTIONSTRING"; + string? storageConnectionString = Environment.GetEnvironmentVariable(AzureStorageConnectionStringName); + if (string.IsNullOrEmpty(storageConnectionString)) + { + // Local storage emulator: "UseDevelopmentStorage=true" + throw new InvalidOperationException($"The Azure Storage provider requires a {AzureStorageConnectionStringName} environment variable."); + } + + var azureStorageSettings = new AzureStorageOrchestrationServiceSettings + { + TaskHubName = "DurableServerTests", + StorageConnectionString = storageConnectionString, + MaxQueuePollingInterval = TimeSpan.FromSeconds(5), + LoggerFactory = loggerFactory, + }; + return new AzureStorageOrchestrationService(azureStorageSettings); + + case BackendType.Emulator: + return SingletonLocalOrchestrationService; + + case BackendType.MSSQL: + const string SqlConnectionStringName = "DURABLETASK_MSSQL_CONNECTIONSTRING"; + string? sqlConnectionString = Environment.GetEnvironmentVariable(SqlConnectionStringName); + if (string.IsNullOrEmpty(sqlConnectionString)) + { + // Local Windows install: "Server=localhost;Database=DurableDB;Trusted_Connection=True;" + throw new InvalidOperationException($"The MSSQL storage provider requires a {SqlConnectionStringName} environment variable."); + } + + var mssqlSettings = new SqlOrchestrationServiceSettings(sqlConnectionString) + { + LoggerFactory = loggerFactory, + }; + return new SqlOrchestrationService(mssqlSettings); + + case BackendType.Netherite: + throw new NotSupportedException("Netherite is not yet supported."); + + default: + throw new ArgumentException($"Unknown backend type: {options.BackendType}"); + } + } + + static Task ReadLineAsync() => InputReader.ReadLineAsync(); + + class StandardInputReader : IInputReader + { + public Task ReadLineAsync() => Console.In.ReadLineAsync(); + } +} diff --git a/src/Services/Sidecar.App/Properties/launchSettings.json b/src/Services/Sidecar.App/Properties/launchSettings.json new file mode 100644 index 00000000..1fe6f7c6 --- /dev/null +++ b/src/Services/Sidecar.App/Properties/launchSettings.json @@ -0,0 +1,8 @@ +{ + "profiles": { + "DurableTask.Sidecar.App": { + "commandName": "Project", + "commandLineArgs": "--backend Emulator" + } + } +} \ No newline at end of file diff --git a/src/Services/Sidecar.App/Sidecar.App.csproj b/src/Services/Sidecar.App/Sidecar.App.csproj new file mode 100644 index 00000000..16f9b523 --- /dev/null +++ b/src/Services/Sidecar.App/Sidecar.App.csproj @@ -0,0 +1,28 @@ + + + + + + + Exe + dt + false + + + + + + + + + + + + + + + + + + + diff --git a/src/Services/Sidecar.App/StartOptions.cs b/src/Services/Sidecar.App/StartOptions.cs new file mode 100644 index 00000000..884d442f --- /dev/null +++ b/src/Services/Sidecar.App/StartOptions.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using CommandLine; + +namespace Microsoft.DurableTask.Sidecar.App; + +[Verb("start", HelpText = "Start a Durable Task sidecar")] +class StartOptions +{ + [Option("interactive", HelpText = "Interactively start and manage orchestrations.")] + public bool Interactive { get; set; } + + [Option("listenPort", HelpText = "The inbound gRPC port used to handle client requests.")] + public int ListenPort { get; set; } = 4001; + + [Option("backend", HelpText = "Storage backend to use for the started sidecar (AzureStorage, MSSQL, Netherite, or Emulator).")] + public BackendType BackendType { get; set; } +} diff --git a/src/Services/Sidecar/AsyncManualResetEvent.cs b/src/Services/Sidecar/AsyncManualResetEvent.cs new file mode 100644 index 00000000..b7cf6dc4 --- /dev/null +++ b/src/Services/Sidecar/AsyncManualResetEvent.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Sidecar; + +class AsyncManualResetEvent +{ + readonly object mutex = new(); + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public AsyncManualResetEvent(bool isSignaled) + { + if (isSignaled) + { + this.tcs.TrySetCanceled(); + } + } + + public async Task WaitAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + Task delayTask = Task.Delay(timeout, cancellationToken); + Task waitTask = this.tcs.Task; + + Task winner = await Task.WhenAny(waitTask, delayTask); + + // Await ensures we get a TaskCancelledException if there was a cancellation. + await winner; + + return winner == waitTask; + } + + public bool IsSignaled => this.tcs.Task.IsCompleted; + + /// + /// Puts the event in the signaled state, unblocking any waiting threads. + /// + public bool Set() + { + lock (this.mutex) + { + return this.tcs.TrySetResult(); + } + } + + /// + /// Puts this event into the unsignaled state, causing threads to block. + /// + public void Reset() + { + lock (this.mutex) + { + if (this.tcs.Task.IsCompleted) + { + this.tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + } +} diff --git a/src/Services/Sidecar/Dispatcher/ITaskExecutor.cs b/src/Services/Sidecar/Dispatcher/ITaskExecutor.cs new file mode 100644 index 00000000..535b9502 --- /dev/null +++ b/src/Services/Sidecar/Dispatcher/ITaskExecutor.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using DurableTask.Core.History; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +interface ITaskExecutor +{ + /// + /// When implemented by a concrete type, executes an orchestrator and returns the next set of orchestrator actions. + /// + /// The instance ID information of the orchestrator to execute. + /// The history events for previous executions of this orchestration instance. + /// The history events that have not yet been executed by this orchestration instance. + /// + /// Returns a task containing the result of the orchestrator execution. These are effectively the side-effects of the + /// orchestrator code, such as calling activities, scheduling timers, etc. + /// + Task ExecuteOrchestrator( + OrchestrationInstance instance, + IEnumerable pastEvents, + IEnumerable newEvents); + + /// + /// When implemented by a concreate type, executes an activity task and returns its results. + /// + /// The instance ID information of the orchestration that scheduled this activity task. + /// The metadata of the activity task execution, including the activity name and input. + /// Returns a task that contains the execution result of the activity. + Task ExecuteActivity( + OrchestrationInstance instance, + TaskScheduledEvent activityEvent); +} diff --git a/src/Services/Sidecar/Dispatcher/ITrafficSignal.cs b/src/Services/Sidecar/Dispatcher/ITrafficSignal.cs new file mode 100644 index 00000000..60caafc1 --- /dev/null +++ b/src/Services/Sidecar/Dispatcher/ITrafficSignal.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +/// +/// A simple primitive that can be used to block logical threads until some condition occurs. +/// +interface ITrafficSignal +{ + /// + /// Provides a human-friendly reason for why the signal is in the "wait" state. + /// + string WaitReason { get; } + + /// + /// Blocks the caller for amount of time or until + /// the signal gets signaled (which is a detail of the implementation). + /// + /// The amount of time to wait until the signal is unblocked. + /// A cancellation token that can be used to interrupt a waiting caller. + /// + /// Returns true if the traffic signal is all-clear; false if we timed-out waiting for the signal to clear. + /// + /// + /// Thrown if is triggered while waiting. + /// + Task WaitAsync(TimeSpan waitTime, CancellationToken cancellationToken); +} + diff --git a/src/Services/Sidecar/Dispatcher/TaskActivityDispatcher.cs b/src/Services/Sidecar/Dispatcher/TaskActivityDispatcher.cs new file mode 100644 index 00000000..2a953951 --- /dev/null +++ b/src/Services/Sidecar/Dispatcher/TaskActivityDispatcher.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using DurableTask.Core.History; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +class TaskActivityDispatcher( + ILogger log, + ITrafficSignal trafficSignal, + IOrchestrationService service, + ITaskExecutor taskExecutor) : WorkItemDispatcher(log, trafficSignal) +{ + readonly IOrchestrationService service = service; + readonly ITaskExecutor taskExecutor = taskExecutor; + + public override int MaxWorkItems => this.service.MaxConcurrentTaskActivityWorkItems; + + public override Task AbandonWorkItemAsync(TaskActivityWorkItem workItem) + { + return this.service.AbandonTaskActivityWorkItemAsync(workItem); + } + + public override Task FetchWorkItemAsync( + TimeSpan timeout, + CancellationToken cancellationToken) + { + return this.service.LockNextTaskActivityWorkItem(timeout, cancellationToken); + } + + protected override async Task ExecuteWorkItemAsync(TaskActivityWorkItem workItem) + { + TaskScheduledEvent scheduledEvent = (TaskScheduledEvent)workItem.TaskMessage.Event; + + // TODO: Error handling for internal errors (user code exceptions are handled by the executor). + ActivityExecutionResult result = await this.taskExecutor.ExecuteActivity( + instance: workItem.TaskMessage.OrchestrationInstance, + activityEvent: scheduledEvent); + + TaskMessage responseMessage = new() + { + Event = result.ResponseEvent, + OrchestrationInstance = workItem.TaskMessage.OrchestrationInstance, + }; + + await this.service.CompleteTaskActivityWorkItemAsync(workItem, responseMessage); + } + + public override int GetDelayInSecondsOnFetchException(Exception ex) + { + return this.service.GetDelayInSecondsAfterOnFetchException(ex); + } + + public override string GetWorkItemId(TaskActivityWorkItem workItem) + { + return workItem.Id; + } + + // No-op + public override Task ReleaseWorkItemAsync(TaskActivityWorkItem workItem) + { + return Task.CompletedTask; + } + + public override Task RenewWorkItemAsync(TaskActivityWorkItem workItem) + { + return this.service.RenewTaskActivityWorkItemLockAsync(workItem); + } +} \ No newline at end of file diff --git a/src/Services/Sidecar/Dispatcher/TaskHubDispatcherHost.cs b/src/Services/Sidecar/Dispatcher/TaskHubDispatcherHost.cs new file mode 100644 index 00000000..6890eaee --- /dev/null +++ b/src/Services/Sidecar/Dispatcher/TaskHubDispatcherHost.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +class TaskHubDispatcherHost : IDisposable +{ + readonly TaskOrchestrationDispatcher orchestrationDispatcher; + readonly TaskActivityDispatcher activityDispatcher; + + readonly IOrchestrationService orchestrationService; + readonly ILogger log; + + public TaskHubDispatcherHost( + ILoggerFactory loggerFactory, + ITrafficSignal trafficSignal, + IOrchestrationService orchestrationService, + ITaskExecutor taskExecutor) + { + this.orchestrationService = orchestrationService ?? throw new ArgumentNullException(nameof(orchestrationService)); + this.log = loggerFactory.CreateLogger("Microsoft.DurableTask.Sidecar"); + + this.orchestrationDispatcher = new TaskOrchestrationDispatcher(this.log, trafficSignal, orchestrationService, taskExecutor); + this.activityDispatcher = new TaskActivityDispatcher(this.log, trafficSignal, orchestrationService, taskExecutor); + } + + public void Dispose() + { + this.orchestrationDispatcher.Dispose(); + this.activityDispatcher.Dispose(); + } + + public async Task StartAsync(CancellationToken cancellationToken) + { + // Start any background processing in the orchestration service + await this.orchestrationService.StartAsync(); + + // Start the dispatchers, which will allow orchestrations/activities to execute + await Task.WhenAll( + this.orchestrationDispatcher.StartAsync(cancellationToken), + this.activityDispatcher.StartAsync(cancellationToken)); + } + + public async Task StopAsync(CancellationToken cancellationToken) + { + // Stop the dispatchers from polling the orchestration service + await Task.WhenAll( + this.orchestrationDispatcher.StopAsync(cancellationToken), + this.activityDispatcher.StopAsync(cancellationToken)); + + // Tell the storage provider to stop doing any background work. + await this.orchestrationService.StopAsync(); + } +} diff --git a/src/Services/Sidecar/Dispatcher/TaskOrchestrationDispatcher.cs b/src/Services/Sidecar/Dispatcher/TaskOrchestrationDispatcher.cs new file mode 100644 index 00000000..3320aa36 --- /dev/null +++ b/src/Services/Sidecar/Dispatcher/TaskOrchestrationDispatcher.cs @@ -0,0 +1,435 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text; +using DurableTask.Core; +using DurableTask.Core.Command; +using DurableTask.Core.History; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +class TaskOrchestrationDispatcher( + ILogger log, + ITrafficSignal trafficSignal, + IOrchestrationService service, + ITaskExecutor taskExecutor) : WorkItemDispatcher(log, trafficSignal) +{ + readonly ILogger log = log; + readonly IOrchestrationService service = service; + readonly ITaskExecutor taskExecutor = taskExecutor; + + public override int MaxWorkItems => this.service.MaxConcurrentTaskOrchestrationWorkItems; + + public override Task AbandonWorkItemAsync(TaskOrchestrationWorkItem workItem) + { + return this.service.AbandonTaskOrchestrationWorkItemAsync(workItem); + } + + public override Task FetchWorkItemAsync( + TimeSpan timeout, + CancellationToken cancellationToken) + { + return this.service.LockNextTaskOrchestrationWorkItemAsync(timeout, cancellationToken); + } + + protected override async Task ExecuteWorkItemAsync(TaskOrchestrationWorkItem workItem) + { + // Convert the new messages into new history events + workItem.OrchestrationRuntimeState.AddEvent(new OrchestratorStartedEvent(-1)); + foreach (TaskMessage message in FilterAndSortMessages(workItem)) + { + workItem.OrchestrationRuntimeState.AddEvent(message.Event); + } + + OrchestrationInstance? instance = workItem.OrchestrationRuntimeState.OrchestrationInstance; + if (string.IsNullOrEmpty(instance?.InstanceId)) + { + throw new ArgumentException($"Could not find an orchestration instance ID in the work item's runtime state.", nameof(workItem)); + } + + // We loop for as long as the orchestrator does a ContinueAsNew + while (true) + { + if (this.log.IsEnabled(LogLevel.Debug)) + { + IList newEvents = workItem.OrchestrationRuntimeState.NewEvents; + string newEventSummary = GetEventSummaryForLogging(newEvents); + this.log.OrchestratorExecuting( + workItem.InstanceId, + workItem.OrchestrationRuntimeState.Name, + newEvents.Count, + newEventSummary); + } + + // Execute the orchestrator code and get back a set of new actions to take. + // IMPORTANT: This IEnumerable may be lazily evaluated and should only be enumerated once! + OrchestratorExecutionResult result = await this.taskExecutor.ExecuteOrchestrator( + instance, + workItem.OrchestrationRuntimeState.PastEvents, + workItem.OrchestrationRuntimeState.NewEvents); + + // Convert the actions into history events and messages. + // If the actions result in a continue-as-new state, + this.ApplyOrchestratorActions( + result, + ref workItem.OrchestrationRuntimeState, + out IList activityMessages, + out IList orchestratorMessages, + out IList timerMessages, + out OrchestrationState? updatedStatus, + out bool continueAsNew); + if (continueAsNew) + { + // Continue running the orchestration with a new history. + // Renew the lock if we're getting close to its expiration. + if (workItem.LockedUntilUtc != default && DateTime.UtcNow.AddMinutes(1) > workItem.LockedUntilUtc) + { + await this.service.RenewTaskOrchestrationWorkItemLockAsync(workItem); + } + + continue; + } + + // Commit the changes to the durable store + await this.service.CompleteTaskOrchestrationWorkItemAsync( + workItem, + workItem.OrchestrationRuntimeState, + activityMessages, + orchestratorMessages, + timerMessages, + continuedAsNewMessage: null /* not supported */, + updatedStatus); + + break; + } + } + + static string GetEventSummaryForLogging(IList actions) + { + if (actions.Count == 0) + { + return string.Empty; + } + else if (actions.Count == 1) + { + return actions[0].EventType.ToString(); + } + else + { + // Returns something like "TaskCompleted x5, TimerFired x1,..." + return string.Join(", ", actions + .GroupBy(a => a.EventType) + .Select(group => $"{group.Key} x{group.Count()}")); + } + } + + static IEnumerable FilterAndSortMessages(TaskOrchestrationWorkItem workItem) + { + // Group messages by their instance ID + static string GetGroupingKey(TaskMessage msg) => msg.OrchestrationInstance.InstanceId; + + // Within a group, put messages with a non-null execution ID first + static int GetSortOrderWithinGroup(TaskMessage msg) + { + if (msg.Event.EventType == EventType.ExecutionStarted) + { + // Prioritize ExecutionStarted messages + return 0; + } + else if (msg.OrchestrationInstance.ExecutionId != null) + { + // Prioritize messages with an execution ID + return 1; + } + else + { + return 2; + } + } + + string? executionId = workItem.OrchestrationRuntimeState?.OrchestrationInstance?.ExecutionId; + + foreach (var group in workItem.NewMessages.GroupBy(GetGroupingKey)) + { + // TODO: Filter out invalid messages (wrong execution ID, duplicate start/complete messages, etc.) + foreach (TaskMessage msg in group.OrderBy(GetSortOrderWithinGroup)) + { + yield return msg; + } + } + } + + void ApplyOrchestratorActions( + OrchestratorExecutionResult result, + ref OrchestrationRuntimeState runtimeState, + out IList activityMessages, + out IList orchestratorMessages, + out IList timerMessages, + out OrchestrationState? updatedStatus, + out bool continueAsNew) + { + if (string.IsNullOrEmpty(runtimeState.OrchestrationInstance?.InstanceId)) + { + throw new ArgumentException($"The provided {nameof(OrchestrationRuntimeState)} doesn't contain an instance ID!", nameof(runtimeState)); + } + + List? newActivityMessages = null; + List? newTimerMessages = null; + List? newOrchestratorMessages = null; + FailureDetails? failureDetails = null; + continueAsNew = false; + + runtimeState.Status = result.CustomStatus; + + foreach (OrchestratorAction action in result.Actions) + { + // TODO: Determine how to handle remaining actions if the instance completed with ContinueAsNew. + // TODO: Validate each of these actions to make sure they have the appropriate data. + if (action is ScheduleTaskOrchestratorAction scheduleTaskAction) + { + if (string.IsNullOrEmpty(scheduleTaskAction.Name)) + { + throw new ArgumentException($"The provided {nameof(ScheduleTaskOrchestratorAction)} has no Name property specified!", nameof(result)); + } + + TaskScheduledEvent scheduledEvent = new( + scheduleTaskAction.Id, + scheduleTaskAction.Name, + scheduleTaskAction.Version, + scheduleTaskAction.Input); + + newActivityMessages ??= new List(); + newActivityMessages.Add(new TaskMessage + { + Event = scheduledEvent, + OrchestrationInstance = runtimeState.OrchestrationInstance, + }); + + runtimeState.AddEvent(scheduledEvent); + } + else if (action is CreateTimerOrchestratorAction timerAction) + { + TimerCreatedEvent timerEvent = new(timerAction.Id, timerAction.FireAt); + + newTimerMessages ??= new List(); + newTimerMessages.Add(new TaskMessage + { + Event = new TimerFiredEvent(-1, timerAction.FireAt) + { + TimerId = timerAction.Id, + }, + OrchestrationInstance = runtimeState.OrchestrationInstance, + }); + + runtimeState.AddEvent(timerEvent); + } + else if (action is CreateSubOrchestrationAction subOrchestrationAction) + { + runtimeState.AddEvent(new SubOrchestrationInstanceCreatedEvent(subOrchestrationAction.Id) + { + Name = subOrchestrationAction.Name, + Version = subOrchestrationAction.Version, + InstanceId = subOrchestrationAction.InstanceId, + Input = subOrchestrationAction.Input, + }); + + ExecutionStartedEvent startedEvent = new(-1, subOrchestrationAction.Input) + { + Name = subOrchestrationAction.Name, + Version = subOrchestrationAction.Version, + OrchestrationInstance = new OrchestrationInstance + { + InstanceId = subOrchestrationAction.InstanceId, + ExecutionId = Guid.NewGuid().ToString("N"), + }, + ParentInstance = new ParentInstance + { + OrchestrationInstance = runtimeState.OrchestrationInstance, + Name = runtimeState.Name, + Version = runtimeState.Version, + TaskScheduleId = subOrchestrationAction.Id, + }, + Tags = subOrchestrationAction.Tags, + }; + + newOrchestratorMessages ??= new List(); + newOrchestratorMessages.Add(new TaskMessage + { + Event = startedEvent, + OrchestrationInstance = startedEvent.OrchestrationInstance, + }); + } + else if (action is SendEventOrchestratorAction sendEventAction) + { + if (string.IsNullOrEmpty(sendEventAction.Instance?.InstanceId)) + { + throw new ArgumentException($"The provided {nameof(SendEventOrchestratorAction)} doesn't contain an instance ID!"); + } + + EventSentEvent sendEvent = new(sendEventAction.Id) + { + InstanceId = sendEventAction.Instance.InstanceId, + Name = sendEventAction.EventName, + Input = sendEventAction.EventData, + }; + + runtimeState.AddEvent(sendEvent); + + newOrchestratorMessages ??= new List(); + newOrchestratorMessages.Add(new TaskMessage + { + Event = sendEvent, + OrchestrationInstance = runtimeState.OrchestrationInstance, + }); + } + else if (action is OrchestrationCompleteOrchestratorAction completeAction) + { + if (completeAction.OrchestrationStatus == OrchestrationStatus.ContinuedAsNew) + { + // Replace the existing runtime state with a complete new runtime state. + OrchestrationRuntimeState newRuntimeState = new(); + newRuntimeState.AddEvent(new OrchestratorStartedEvent(-1)); + newRuntimeState.AddEvent(new ExecutionStartedEvent(-1, completeAction.Result) + { + OrchestrationInstance = new OrchestrationInstance + { + InstanceId = runtimeState.OrchestrationInstance.InstanceId, + ExecutionId = Guid.NewGuid().ToString("N"), + }, + Tags = runtimeState.Tags, + ParentInstance = runtimeState.ParentInstance, + Name = runtimeState.Name, + Version = completeAction.NewVersion ?? runtimeState.Version + }); + newRuntimeState.Status = runtimeState.Status; + + // The orchestration may have completed with some pending events that need to be carried + // over to the new generation, such as unprocessed external event messages. + if (completeAction.CarryoverEvents != null) + { + foreach (HistoryEvent carryoverEvent in completeAction.CarryoverEvents) + { + newRuntimeState.AddEvent(carryoverEvent); + } + } + + runtimeState = newRuntimeState; + activityMessages = Array.Empty(); + orchestratorMessages = Array.Empty(); + timerMessages = Array.Empty(); + continueAsNew = true; + updatedStatus = null; + return; + } + else + { + this.log.OrchestratorCompleted( + runtimeState.OrchestrationInstance.InstanceId, + runtimeState.Name, + completeAction.OrchestrationStatus, + Encoding.UTF8.GetByteCount(completeAction.Result ?? string.Empty)); + } + + if (completeAction.OrchestrationStatus == OrchestrationStatus.Failed) + { + failureDetails = completeAction.FailureDetails; + } + + // NOTE: Failure details aren't being stored in the orchestration history, currently. + runtimeState.AddEvent(new ExecutionCompletedEvent( + completeAction.Id, + completeAction.Result, + completeAction.OrchestrationStatus)); + + // CONSIDER: Add support for fire-and-forget sub-orchestrations where + // we don't notify the parent that the orchestration completed. + if (runtimeState.ParentInstance != null) + { + HistoryEvent subOrchestratorCompletedEvent; + if (completeAction.OrchestrationStatus == OrchestrationStatus.Completed) + { + subOrchestratorCompletedEvent = new SubOrchestrationInstanceCompletedEvent( + eventId: -1, + runtimeState.ParentInstance.TaskScheduleId, + completeAction.Result); + } + else + { + subOrchestratorCompletedEvent = new SubOrchestrationInstanceFailedEvent( + eventId: -1, + runtimeState.ParentInstance.TaskScheduleId, + completeAction.Result, + completeAction.Details, + completeAction.FailureDetails); + } + + newOrchestratorMessages ??= new List(); + newOrchestratorMessages.Add(new TaskMessage + { + Event = subOrchestratorCompletedEvent, + OrchestrationInstance = runtimeState.ParentInstance.OrchestrationInstance, + }); + } + } + else + { + this.log.IgnoringUnknownOrchestratorAction( + runtimeState.OrchestrationInstance.InstanceId, + action.OrchestratorActionType); + } + } + + runtimeState.AddEvent(new OrchestratorCompletedEvent(-1)); + + activityMessages = (IList?)newActivityMessages ?? Array.Empty(); + timerMessages = (IList?)newTimerMessages ?? Array.Empty(); + orchestratorMessages = (IList?)newOrchestratorMessages ?? Array.Empty(); + + updatedStatus = new OrchestrationState + { + OrchestrationInstance = runtimeState.OrchestrationInstance, + ParentInstance = runtimeState.ParentInstance, + Name = runtimeState.Name, + Version = runtimeState.Version, + Status = runtimeState.Status, + Tags = runtimeState.Tags, + OrchestrationStatus = runtimeState.OrchestrationStatus, + CreatedTime = runtimeState.CreatedTime, + CompletedTime = runtimeState.CompletedTime, + LastUpdatedTime = DateTime.UtcNow, + Size = runtimeState.Size, + CompressedSize = runtimeState.CompressedSize, + Input = runtimeState.Input, + Output = runtimeState.Output, + ScheduledStartTime = runtimeState.ExecutionStartedEvent?.ScheduledStartTime, + FailureDetails = failureDetails, + }; + } + + static string GetShortHistoryEventDescription(HistoryEvent e) + { + if (Utils.TryGetTaskScheduledId(e, out int taskScheduledId)) + { + return $"{e.EventType}#{taskScheduledId}"; + } + else + { + return e.EventType.ToString(); + } + } + + public override int GetDelayInSecondsOnFetchException(Exception ex) => + this.service.GetDelayInSecondsAfterOnFetchException(ex); + + public override string GetWorkItemId(TaskOrchestrationWorkItem workItem) => workItem.InstanceId; + + public override Task ReleaseWorkItemAsync(TaskOrchestrationWorkItem workItem) => + this.service.ReleaseTaskOrchestrationWorkItemAsync(workItem); + + public override async Task RenewWorkItemAsync(TaskOrchestrationWorkItem workItem) + { + await this.service.RenewTaskOrchestrationWorkItemLockAsync(workItem); + return workItem; + } +} diff --git a/src/Services/Sidecar/Dispatcher/WorkItemDispatcher.cs b/src/Services/Sidecar/Dispatcher/WorkItemDispatcher.cs new file mode 100644 index 00000000..071f6085 --- /dev/null +++ b/src/Services/Sidecar/Dispatcher/WorkItemDispatcher.cs @@ -0,0 +1,268 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +abstract class WorkItemDispatcher : IDisposable where T : class +{ + static int nextDispatcherId; + + readonly string name; + readonly ILogger log; + readonly ITrafficSignal trafficSignal; + + CancellationTokenSource? shutdownTcs; + Task? workItemExecuteLoop; + int currentWorkItems; + + public WorkItemDispatcher(ILogger log, ITrafficSignal trafficSignal) + { + this.log = log ?? throw new ArgumentNullException(nameof(log)); + this.trafficSignal = trafficSignal; + + this.name = $"{this.GetType().Name}-{Interlocked.Increment(ref nextDispatcherId)}"; + } + + public virtual int MaxWorkItems => 100; + + public virtual void Dispose() + { + this.shutdownTcs?.Dispose(); + } + + public abstract Task FetchWorkItemAsync(TimeSpan timeout, CancellationToken cancellationToken); + + protected abstract Task ExecuteWorkItemAsync(T workItem); + + public abstract Task ReleaseWorkItemAsync(T workItem); + + public abstract Task AbandonWorkItemAsync(T workItem); + + public abstract Task RenewWorkItemAsync(T workItem); + + public abstract string GetWorkItemId(T workItem); + + public abstract int GetDelayInSecondsOnFetchException(Exception ex); + + public virtual Task StartAsync(CancellationToken cancellationToken) + { + // Dispatchers can be stopped and started back up again + this.shutdownTcs?.Dispose(); + this.shutdownTcs = new CancellationTokenSource(); + + this.workItemExecuteLoop = Task.Run( + () => this.FetchAndExecuteLoop(this.shutdownTcs.Token), + CancellationToken.None); + + return Task.CompletedTask; + } + + public virtual async Task StopAsync(CancellationToken cancellationToken) + { + // Trigger the cancellation tokens being used for background processing. + this.shutdownTcs?.Cancel(); + + // Wait for the execution loop to complete to ensure we're not scheduling any new work + Task? executeLoop = this.workItemExecuteLoop; + if (executeLoop != null) + { + await executeLoop.WaitAsync(cancellationToken); + } + + // Wait for all outstanding work-item processing to complete for a fully graceful shutdown + await this.WaitForOutstandingWorkItems(cancellationToken); + } + + async Task WaitForAllClear(CancellationToken cancellationToken) + { + TimeSpan logInterval = TimeSpan.FromMinutes(1); + + // IMPORTANT: This logic assumes only a single logical "thread" is executing the receive loop, + // and that there's no possible race condition when comparing work-item counts. + DateTime nextLogTime = DateTime.MinValue; + while (this.currentWorkItems >= this.MaxWorkItems) + { + // Periodically log that we're waiting for available concurrency. + // No need to use UTC for this. Local time is a bit easier to debug. + DateTime now = DateTime.Now; + if (now >= nextLogTime) + { + this.log.FetchingThrottled( + dispatcher: this.name, + details: "The current active work-item count has reached the allowed maximum.", + this.currentWorkItems, + this.MaxWorkItems); + nextLogTime = now.Add(logInterval); + } + + // CONSIDER: Use a notification instead of polling. + await Task.Delay(TimeSpan.FromMilliseconds(500), cancellationToken); + } + + // The dispatcher can also be paused by external signals. + while (!await this.trafficSignal.WaitAsync(logInterval, cancellationToken)) + { + this.log.FetchingThrottled( + dispatcher: this.name, + details: this.trafficSignal.WaitReason, + this.currentWorkItems, + this.MaxWorkItems); + } + } + + async Task WaitForOutstandingWorkItems(CancellationToken cancellationToken) + { + DateTime nextLogTime = DateTime.MinValue; + while (this.currentWorkItems > 0) + { + // Periodically log that we're waiting for outstanding work items to complete. + // No need to use UTC for this. Local time is a bit easier to debug. + DateTime now = DateTime.Now; + if (now >= nextLogTime) + { + this.log.DispatcherStopping(this.name, this.currentWorkItems); + nextLogTime = now.AddMinutes(1); + } + + // CONSIDER: Use a notification instead of polling. + await Task.Delay(TimeSpan.FromMilliseconds(200), cancellationToken); + } + } + + // This method does not throw + async Task DelayOnException( + Exception exception, + string workItemId, + Func delayInSecondsPolicy, + CancellationToken cancellationToken) + { + try + { + int delaySeconds = delayInSecondsPolicy(exception); + if (delaySeconds > 0) + { + await Task.Delay(delaySeconds, cancellationToken); + } + } + catch (OperationCanceledException) + { + // Shutting down, do nothing + } + catch (Exception ex) + { + this.log.DispatchWorkItemFailure( + dispatcher: this.name, + action: "delay-on-exception", + workItemId, + details: ex.ToString()); + try + { + await Task.Delay(TimeSpan.FromSeconds(1), cancellationToken); + } + catch (OperationCanceledException) + { + // shutting down + } + } + } + + async Task FetchAndExecuteLoop(CancellationToken cancellationToken) + { + try + { + // The work-item receive loop feeds the execution loop + while (true) + { + T? workItem = null; + try + { + await this.WaitForAllClear(cancellationToken); + + this.log.FetchWorkItemStarting(this.name, this.currentWorkItems, this.MaxWorkItems); + Stopwatch sw = Stopwatch.StartNew(); + + workItem = await this.FetchWorkItemAsync(Timeout.InfiniteTimeSpan, cancellationToken); + + if (workItem != null) + { + this.currentWorkItems++; + this.log.FetchWorkItemCompleted( + this.name, + this.GetWorkItemId(workItem), + sw.ElapsedMilliseconds, + this.currentWorkItems, + this.MaxWorkItems); + + // Run the execution on a background thread, which must never be canceled. + _ = Task.Run(() => this.ExecuteWorkItem(workItem), CancellationToken.None); + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // shutting down + break; + } + catch (Exception ex) + { + string unknownWorkItemId = "(unknown)"; + this.log.DispatchWorkItemFailure( + dispatcher: this.name, + action: "fetchWorkItem", + workItemId: unknownWorkItemId, + details: ex.ToString()); + await this.DelayOnException( + ex, + unknownWorkItemId, + this.GetDelayInSecondsOnFetchException, + cancellationToken); + continue; + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // graceful shutdown + } + } + + async Task ExecuteWorkItem(T workItem) + { + try + { + // Execute the work item and wait for it to complete + await this.ExecuteWorkItemAsync(workItem); + } + catch (Exception ex) + { + this.log.DispatchWorkItemFailure( + dispatcher: this.name, + action: "execute", + workItemId: this.GetWorkItemId(workItem), + details: ex.ToString()); + + await this.AbandonWorkItemAsync(workItem); + } + finally + { + try + { + await this.ReleaseWorkItemAsync(workItem); + } + catch (Exception ex) + { + // Best effort + this.log.DispatchWorkItemFailure( + dispatcher: this.name, + action: "release", + workItemId: this.GetWorkItemId(workItem), + details: ex.ToString()); + } + + this.currentWorkItems--; + } + } +} + diff --git a/src/Services/Sidecar/Grpc/ProtobufUtils.cs b/src/Services/Sidecar/Grpc/ProtobufUtils.cs new file mode 100644 index 00000000..ab816b99 --- /dev/null +++ b/src/Services/Sidecar/Grpc/ProtobufUtils.cs @@ -0,0 +1,397 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using DurableTask.Core.Command; +using DurableTask.Core.History; +using DurableTask.Core.Query; +using Google.Protobuf.WellKnownTypes; +using Proto = Microsoft.DurableTask.Protobuf; + +namespace Microsoft.DurableTask.Sidecar.Grpc; + +static class ProtobufUtils +{ + internal static Proto.HistoryEvent ToHistoryEventProto(HistoryEvent e) + { + var payload = new Proto.HistoryEvent() + { + EventId = e.EventId, + Timestamp = Timestamp.FromDateTime(e.Timestamp), + }; + + switch (e.EventType) + { + case EventType.ContinueAsNew: + var continueAsNew = (ContinueAsNewEvent)e; + payload.ContinueAsNew = new Proto.ContinueAsNewEvent + { + Input = continueAsNew.Result, + }; + break; + case EventType.EventRaised: + var eventRaised = (EventRaisedEvent)e; + payload.EventRaised = new Proto.EventRaisedEvent + { + Name = eventRaised.Name, + Input = eventRaised.Input, + }; + break; + case EventType.EventSent: + var eventSent = (EventSentEvent)e; + payload.EventSent = new Proto.EventSentEvent + { + Name = eventSent.Name, + Input = eventSent.Input, + InstanceId = eventSent.InstanceId, + }; + break; + case EventType.ExecutionCompleted: + var completedEvent = (ExecutionCompletedEvent)e; + payload.ExecutionCompleted = new Proto.ExecutionCompletedEvent + { + OrchestrationStatus = Proto.OrchestrationStatus.Completed, + Result = completedEvent.Result, + }; + break; + case EventType.ExecutionFailed: + var failedEvent = (ExecutionCompletedEvent)e; + payload.ExecutionCompleted = new Proto.ExecutionCompletedEvent + { + OrchestrationStatus = Proto.OrchestrationStatus.Failed, + Result = failedEvent.Result, + }; + break; + case EventType.ExecutionStarted: + // Start of a new orchestration instance + var startedEvent = (ExecutionStartedEvent)e; + payload.ExecutionStarted = new Proto.ExecutionStartedEvent + { + Name = startedEvent.Name, + Version = startedEvent.Version, + Input = startedEvent.Input, + OrchestrationInstance = new Proto.OrchestrationInstance + { + InstanceId = startedEvent.OrchestrationInstance.InstanceId, + ExecutionId = startedEvent.OrchestrationInstance.ExecutionId, + }, + ParentInstance = startedEvent.ParentInstance == null ? null : new Proto.ParentInstanceInfo + { + Name = startedEvent.ParentInstance.Name, + Version = startedEvent.ParentInstance.Version, + TaskScheduledId = startedEvent.ParentInstance.TaskScheduleId, + OrchestrationInstance = new Proto.OrchestrationInstance + { + InstanceId = startedEvent.ParentInstance.OrchestrationInstance.InstanceId, + ExecutionId = startedEvent.ParentInstance.OrchestrationInstance.ExecutionId, + }, + }, + ScheduledStartTimestamp = startedEvent.ScheduledStartTime == null ? null : Timestamp.FromDateTime(startedEvent.ScheduledStartTime.Value), + ParentTraceContext = startedEvent.ParentTraceContext is null ? null : new Proto.TraceContext + { + TraceParent = startedEvent.ParentTraceContext.TraceParent, + TraceState = startedEvent.ParentTraceContext.TraceState, + }, + }; + break; + case EventType.ExecutionTerminated: + var terminatedEvent = (ExecutionTerminatedEvent)e; + payload.ExecutionTerminated = new Proto.ExecutionTerminatedEvent + { + Input = terminatedEvent.Input, + }; + break; + case EventType.TaskScheduled: + var taskScheduledEvent = (TaskScheduledEvent)e; + payload.TaskScheduled = new Proto.TaskScheduledEvent + { + Name = taskScheduledEvent.Name, + Version = taskScheduledEvent.Version, + Input = taskScheduledEvent.Input, + ParentTraceContext = taskScheduledEvent.ParentTraceContext is null ? null : new Proto.TraceContext + { + TraceParent = taskScheduledEvent.ParentTraceContext.TraceParent, + TraceState = taskScheduledEvent.ParentTraceContext.TraceState, + }, + }; + break; + case EventType.TaskCompleted: + var taskCompletedEvent = (TaskCompletedEvent)e; + payload.TaskCompleted = new Proto.TaskCompletedEvent + { + Result = taskCompletedEvent.Result, + TaskScheduledId = taskCompletedEvent.TaskScheduledId, + }; + break; + case EventType.TaskFailed: + var taskFailedEvent = (TaskFailedEvent)e; + payload.TaskFailed = new Proto.TaskFailedEvent + { + FailureDetails = GetFailureDetails(taskFailedEvent.FailureDetails), + TaskScheduledId = taskFailedEvent.TaskScheduledId, + }; + break; + case EventType.SubOrchestrationInstanceCreated: + var subOrchestrationCreated = (SubOrchestrationInstanceCreatedEvent)e; + payload.SubOrchestrationInstanceCreated = new Proto.SubOrchestrationInstanceCreatedEvent + { + Input = subOrchestrationCreated.Input, + InstanceId = subOrchestrationCreated.InstanceId, + Name = subOrchestrationCreated.Name, + Version = subOrchestrationCreated.Version, + }; + break; + case EventType.SubOrchestrationInstanceCompleted: + var subOrchestrationCompleted = (SubOrchestrationInstanceCompletedEvent)e; + payload.SubOrchestrationInstanceCompleted = new Proto.SubOrchestrationInstanceCompletedEvent + { + Result = subOrchestrationCompleted.Result, + TaskScheduledId = subOrchestrationCompleted.TaskScheduledId, + }; + break; + case EventType.SubOrchestrationInstanceFailed: + var subOrchestrationFailed = (SubOrchestrationInstanceFailedEvent)e; + payload.SubOrchestrationInstanceFailed = new Proto.SubOrchestrationInstanceFailedEvent + { + FailureDetails = GetFailureDetails(subOrchestrationFailed.FailureDetails), + TaskScheduledId = subOrchestrationFailed.TaskScheduledId, + }; + break; + case EventType.TimerCreated: + var timerCreatedEvent = (TimerCreatedEvent)e; + payload.TimerCreated = new Proto.TimerCreatedEvent + { + FireAt = Timestamp.FromDateTime(timerCreatedEvent.FireAt), + }; + break; + case EventType.TimerFired: + var timerFiredEvent = (TimerFiredEvent)e; + payload.TimerFired = new Proto.TimerFiredEvent + { + FireAt = Timestamp.FromDateTime(timerFiredEvent.FireAt), + TimerId = timerFiredEvent.TimerId, + }; + break; + case EventType.OrchestratorStarted: + // This event has no data + payload.OrchestratorStarted = new Proto.OrchestratorStartedEvent(); + break; + case EventType.OrchestratorCompleted: + // This event has no data + payload.OrchestratorCompleted = new Proto.OrchestratorCompletedEvent(); + break; + case EventType.GenericEvent: + var genericEvent = (GenericEvent)e; + payload.GenericEvent = new Proto.GenericEvent + { + Data = genericEvent.Data, + }; + break; + case EventType.HistoryState: + var historyStateEvent = (HistoryStateEvent)e; + payload.HistoryState = new Proto.HistoryStateEvent + { + OrchestrationState = new Proto.OrchestrationState + { + InstanceId = historyStateEvent.State.OrchestrationInstance.InstanceId, + Name = historyStateEvent.State.Name, + Version = historyStateEvent.State.Version, + Input = historyStateEvent.State.Input, + Output = historyStateEvent.State.Output, + ScheduledStartTimestamp = historyStateEvent.State.ScheduledStartTime == null ? null : Timestamp.FromDateTime(historyStateEvent.State.ScheduledStartTime.Value), + CreatedTimestamp = Timestamp.FromDateTime(historyStateEvent.State.CreatedTime), + LastUpdatedTimestamp = Timestamp.FromDateTime(historyStateEvent.State.LastUpdatedTime), + OrchestrationStatus = (Proto.OrchestrationStatus)historyStateEvent.State.OrchestrationStatus, + CustomStatus = historyStateEvent.State.Status, + }, + }; + break; + case EventType.ExecutionSuspended: + var suspendedEvent = (ExecutionSuspendedEvent)e; + payload.ExecutionSuspended = new Proto.ExecutionSuspendedEvent + { + Input = suspendedEvent.Reason, + }; + break; + case EventType.ExecutionResumed: + var resumedEvent = (ExecutionResumedEvent)e; + payload.ExecutionResumed = new Proto.ExecutionResumedEvent + { + Input = resumedEvent.Reason, + }; + break; + default: + throw new NotSupportedException($"Found unsupported history event '{e.EventType}'."); + } + + return payload; + } + + internal static OrchestratorAction ToOrchestratorAction(Proto.OrchestratorAction a) + { + switch (a.OrchestratorActionTypeCase) + { + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.ScheduleTask: + return new ScheduleTaskOrchestratorAction + { + Id = a.Id, + Input = a.ScheduleTask.Input, + Name = a.ScheduleTask.Name, + Version = a.ScheduleTask.Version, + }; + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.CreateSubOrchestration: + return new CreateSubOrchestrationAction + { + Id = a.Id, + Input = a.CreateSubOrchestration.Input, + Name = a.CreateSubOrchestration.Name, + InstanceId = a.CreateSubOrchestration.InstanceId, + Tags = null, // TODO + Version = a.CreateSubOrchestration.Version, + }; + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.CreateTimer: + return new CreateTimerOrchestratorAction + { + Id = a.Id, + FireAt = a.CreateTimer.FireAt.ToDateTime(), + }; + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.SendEvent: + return new SendEventOrchestratorAction + { + Id = a.Id, + Instance = new OrchestrationInstance + { + InstanceId = a.SendEvent.Instance.InstanceId, + ExecutionId = a.SendEvent.Instance.ExecutionId, + }, + EventName = a.SendEvent.Name, + EventData = a.SendEvent.Data, + }; + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.CompleteOrchestration: + var completedAction = a.CompleteOrchestration; + var action = new OrchestrationCompleteOrchestratorAction + { + Id = a.Id, + OrchestrationStatus = (OrchestrationStatus)completedAction.OrchestrationStatus, + Result = completedAction.Result, + Details = completedAction.Details, + FailureDetails = GetFailureDetails(completedAction.FailureDetails), + NewVersion = completedAction.NewVersion, + }; + + if (completedAction.CarryoverEvents?.Count > 0) + { + foreach (var e in completedAction.CarryoverEvents) + { + // Only raised events are supported for carryover + if (e.EventRaised is Proto.EventRaisedEvent eventRaised) + { + action.CarryoverEvents.Add(new EventRaisedEvent(e.EventId, eventRaised.Input) + { + Name = eventRaised.Name, + }); + } + + } + } + + return action; + default: + throw new NotSupportedException($"Received unsupported action type '{a.OrchestratorActionTypeCase}'."); + } + } + + internal static FailureDetails? GetFailureDetails(Proto.TaskFailureDetails? failureDetails) + { + if (failureDetails == null) + { + return null; + } + + return new FailureDetails( + failureDetails.ErrorType, + failureDetails.ErrorMessage, + failureDetails.StackTrace, + GetFailureDetails(failureDetails.InnerFailure), + failureDetails.IsNonRetriable); + } + + internal static Proto.TaskFailureDetails? GetFailureDetails(FailureDetails? failureDetails) + { + if (failureDetails == null) + { + return null; + } + + return new Proto.TaskFailureDetails + { + ErrorType = failureDetails.ErrorType, + ErrorMessage = failureDetails.ErrorMessage, + StackTrace = failureDetails.StackTrace, + InnerFailure = GetFailureDetails(failureDetails.InnerFailure), + IsNonRetriable = failureDetails.IsNonRetriable, + }; + } + + internal static OrchestrationQuery ToOrchestrationQuery(Proto.QueryInstancesRequest request) + { + var query = new OrchestrationQuery() + { + RuntimeStatus = request.Query.RuntimeStatus?.Select(status => (OrchestrationStatus)status).ToList(), + CreatedTimeFrom = request.Query.CreatedTimeFrom?.ToDateTime(), + CreatedTimeTo = request.Query.CreatedTimeTo?.ToDateTime(), + TaskHubNames = request.Query.TaskHubNames, + PageSize = request.Query.MaxInstanceCount, + ContinuationToken = request.Query.ContinuationToken, + InstanceIdPrefix = request.Query.InstanceIdPrefix, + FetchInputsAndOutputs = request.Query.FetchInputsAndOutputs, + }; + + return query; + } + + internal static Proto.QueryInstancesResponse CreateQueryInstancesResponse(OrchestrationQueryResult result, Proto.QueryInstancesRequest request) + { + Proto.QueryInstancesResponse response = new Proto.QueryInstancesResponse + { + ContinuationToken = result.ContinuationToken + }; + foreach (OrchestrationState state in result.OrchestrationState) + { + var orchestrationState = new Proto.OrchestrationState + { + InstanceId = state.OrchestrationInstance.InstanceId, + Name = state.Name, + Version = state.Version, + Input = state.Input, + Output = state.Output, + ScheduledStartTimestamp = state.ScheduledStartTime == null ? null : Timestamp.FromDateTime(state.ScheduledStartTime.Value), + CreatedTimestamp = Timestamp.FromDateTime(state.CreatedTime), + LastUpdatedTimestamp = Timestamp.FromDateTime(state.LastUpdatedTime), + OrchestrationStatus = (Proto.OrchestrationStatus)state.OrchestrationStatus, + CustomStatus = state.Status, + }; + response.OrchestrationState.Add(orchestrationState); + } + return response; + } + + internal static PurgeInstanceFilter ToPurgeInstanceFilter(Proto.PurgeInstancesRequest request) + { + var purgeInstanceFilter = new PurgeInstanceFilter( + request.PurgeInstanceFilter.CreatedTimeFrom.ToDateTime(), + request.PurgeInstanceFilter.CreatedTimeTo?.ToDateTime(), + request.PurgeInstanceFilter.RuntimeStatus?.Select(status => (OrchestrationStatus)status).ToList() + ); + return purgeInstanceFilter; + } + + internal static Proto.PurgeInstancesResponse CreatePurgeInstancesResponse(PurgeResult result) + { + Proto.PurgeInstancesResponse response = new Proto.PurgeInstancesResponse + { + DeletedInstanceCount = result.DeletedInstanceCount + }; + return response; + } +} diff --git a/src/Services/Sidecar/Grpc/TaskHubGrpcServer.cs b/src/Services/Sidecar/Grpc/TaskHubGrpcServer.cs new file mode 100644 index 00000000..0e31a5b6 --- /dev/null +++ b/src/Services/Sidecar/Grpc/TaskHubGrpcServer.cs @@ -0,0 +1,637 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Concurrent; +using System.Diagnostics; +using DurableTask.Core; +using DurableTask.Core.History; +using DurableTask.Core.Query; +using Microsoft.DurableTask.Sidecar.Dispatcher; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using P = Microsoft.DurableTask.Protobuf; +using System.Globalization; + +namespace Microsoft.DurableTask.Sidecar.Grpc; + +/// +/// A gRPC server that implements the service contract. +/// +public class TaskHubGrpcServer : P.TaskHubSidecarService.TaskHubSidecarServiceBase, ITaskExecutor, IDisposable +{ + static readonly Task EmptyCompleteTaskResponse = Task.FromResult(new P.CompleteTaskResponse()); + + readonly ConcurrentDictionary> pendingOrchestratorTasks = new(StringComparer.OrdinalIgnoreCase); + readonly ConcurrentDictionary> pendingActivityTasks = new(StringComparer.OrdinalIgnoreCase); + + readonly ILogger log; + readonly IOrchestrationService service; + readonly IOrchestrationServiceClient client; + readonly IHostApplicationLifetime hostLifetime; + readonly IOptions options; + readonly TaskHubDispatcherHost dispatcherHost; + readonly IsConnectedSignal isConnectedSignal = new(); + readonly SemaphoreSlim sendWorkItemLock = new(initialCount: 1); + + // Initialized when a client connects to this service to receive work-item commands. + IServerStreamWriter? workerToClientStream; + + /// + /// Initializes a new instance of the class. + /// + public TaskHubGrpcServer( + IHostApplicationLifetime hostApplicationLifetime, + ILoggerFactory loggerFactory, + IOrchestrationService service, + IOrchestrationServiceClient client, + IOptions options) + { + ArgumentNullException.ThrowIfNull(hostApplicationLifetime, nameof(hostApplicationLifetime)); + ArgumentNullException.ThrowIfNull(loggerFactory, nameof(loggerFactory)); + ArgumentNullException.ThrowIfNull(service, nameof(service)); + ArgumentNullException.ThrowIfNull(client, nameof(client)); + ArgumentNullException.ThrowIfNull(options, nameof(options)); + + this.service = service; + this.client = client; + this.log = loggerFactory.CreateLogger("Microsoft.DurableTask.Sidecar"); + this.dispatcherHost = new TaskHubDispatcherHost( + loggerFactory, + trafficSignal: this.isConnectedSignal, + orchestrationService: service, + taskExecutor: this); + + this.hostLifetime = hostApplicationLifetime; + this.options = options; + this.hostLifetime.ApplicationStarted.Register(this.OnApplicationStarted); + this.hostLifetime.ApplicationStopping.Register(this.OnApplicationStopping); + } + + /// + /// Disposes of the resources used by this instance. + /// + public void Dispose() + { + this.sendWorkItemLock.Dispose(); + GC.SuppressFinalize(this); + } + + async void OnApplicationStarted() + { + if (this.options.Value.Mode == TaskHubGrpcServerMode.ApiServerAndDispatcher) + { + // Wait for a client connection to be established before starting the dispatcher host. + // This ensures we don't do any wasteful polling of resources if no clients are available to process events. + await this.WaitForWorkItemClientConnection(); + await this.dispatcherHost.StartAsync(this.hostLifetime.ApplicationStopping); + } + } + + async void OnApplicationStopping() + { + if (this.options.Value.Mode == TaskHubGrpcServerMode.ApiServerAndDispatcher) + { + // Give a maximum of 60 minutes for outstanding tasks to complete. + // REVIEW: Is this enough? What if there's an activity job that takes 4 hours to complete? Should this be configurable? + using CancellationTokenSource timeoutCts = new(TimeSpan.FromMinutes(60)); + await this.dispatcherHost.StopAsync(timeoutCts.Token); + } + } + + /// + /// Blocks until a remote client calls the operation to start fetching work items. + /// + /// Returns a task that completes once a work-item client is connected. + async Task WaitForWorkItemClientConnection() + { + Stopwatch waitTimeStopwatch = Stopwatch.StartNew(); + TimeSpan logInterval = TimeSpan.FromMinutes(1); + + try + { + while (!await this.isConnectedSignal.WaitAsync(logInterval, this.hostLifetime.ApplicationStopping)) + { + this.log.WaitingForClientConnection(waitTimeStopwatch.Elapsed); + } + } + catch (OperationCanceledException) + { + // shutting down + } + } + + /// + public override Task Hello(Empty request, ServerCallContext context) => Task.FromResult(new Empty()); + + /// + public override Task CreateTaskHub(P.CreateTaskHubRequest request, ServerCallContext context) + { + this.service.CreateAsync(request.RecreateIfExists); + return Task.FromResult(new P.CreateTaskHubResponse()); + } + + /// + public override Task DeleteTaskHub(P.DeleteTaskHubRequest request, ServerCallContext context) + { + this.service.DeleteAsync(); + return Task.FromResult(new P.DeleteTaskHubResponse()); + } + + /// + public override async Task StartInstance(P.CreateInstanceRequest request, ServerCallContext context) + { + var instance = new OrchestrationInstance + { + InstanceId = request.InstanceId ?? Guid.NewGuid().ToString("N"), + ExecutionId = Guid.NewGuid().ToString(), + }; + + this.log.CreatingNewInstance(instance.InstanceId); + + await this.client.CreateTaskOrchestrationAsync( + new TaskMessage + { + Event = new ExecutionStartedEvent(-1, request.Input) + { + Name = request.Name, + Version = request.Version, + OrchestrationInstance = instance, + }, + OrchestrationInstance = instance, + }); + + return new P.CreateInstanceResponse + { + InstanceId = instance.InstanceId, + }; + } + + /// + public override async Task RaiseEvent(P.RaiseEventRequest request, ServerCallContext context) + { + this.log.RaisingEvent(request.InstanceId, request.Name); + + await this.client.SendTaskOrchestrationMessageAsync( + new TaskMessage + { + Event = new EventRaisedEvent(-1, request.Input) + { + Name = request.Name, + }, + OrchestrationInstance = new OrchestrationInstance + { + InstanceId = request.InstanceId, + }, + }); + + // No fields in the response + return new P.RaiseEventResponse(); + } + + /// + public override async Task TerminateInstance( + P.TerminateRequest request, + ServerCallContext context) + { + this.log.TerminatingInstance(request.InstanceId); + + await this.client.ForceTerminateTaskOrchestrationAsync( + request.InstanceId, + request.Output); + + // No fields in the response + return new P.TerminateResponse(); + } + + /// + public override async Task GetInstance( + P.GetInstanceRequest request, + ServerCallContext context) + { + OrchestrationState state = await this.client.GetOrchestrationStateAsync(request.InstanceId, executionId: null); + if (state == null) + { + return new P.GetInstanceResponse() { Exists = false }; + } + + return CreateGetInstanceResponse(state, request); + } + + /// + public override async Task QueryInstances(P.QueryInstancesRequest request, ServerCallContext context) + { + if (this.client is IOrchestrationServiceQueryClient queryClient) + { + OrchestrationQuery query = ProtobufUtils.ToOrchestrationQuery(request); + OrchestrationQueryResult result = await queryClient.GetOrchestrationWithQueryAsync(query, context.CancellationToken); + return ProtobufUtils.CreateQueryInstancesResponse(result, request); + } + else + { + throw new NotSupportedException($"{this.client.GetType().Name} doesn't support query operations."); + } + } + + /// + public override async Task PurgeInstances(P.PurgeInstancesRequest request, ServerCallContext context) + { + if (this.client is IOrchestrationServicePurgeClient purgeClient) + { + PurgeResult result; + switch (request.RequestCase) + { + case P.PurgeInstancesRequest.RequestOneofCase.InstanceId: + result = await purgeClient.PurgeInstanceStateAsync(request.InstanceId); + break; + + case P.PurgeInstancesRequest.RequestOneofCase.PurgeInstanceFilter: + PurgeInstanceFilter purgeInstanceFilter = ProtobufUtils.ToPurgeInstanceFilter(request); + result = await purgeClient.PurgeInstanceStateAsync(purgeInstanceFilter); + break; + + default: + throw new ArgumentException($"Unknown purge request type '{request.RequestCase}'."); + } + return ProtobufUtils.CreatePurgeInstancesResponse(result); + } + else + { + throw new NotSupportedException($"{this.client.GetType().Name} doesn't support purge operations."); + } + } + + /// + public override async Task WaitForInstanceStart(P.GetInstanceRequest request, ServerCallContext context) + { + while (true) + { + // Keep fetching the status until we get one of the states we care about + OrchestrationState state = await this.client.GetOrchestrationStateAsync(request.InstanceId, executionId: null); + if (state != null && state.OrchestrationStatus != OrchestrationStatus.Pending) + { + return CreateGetInstanceResponse(state, request); + } + + // TODO: Backoff strategy if we're delaying for a long time. + // The cancellation token is what will break us out of this loop if the orchestration + // never leaves the "Pending" state. + await Task.Delay(TimeSpan.FromMilliseconds(500), context.CancellationToken); + } + } + + /// + public override async Task WaitForInstanceCompletion(P.GetInstanceRequest request, ServerCallContext context) + { + OrchestrationState state = await this.client.WaitForOrchestrationAsync( + request.InstanceId, + executionId: null, + timeout: Timeout.InfiniteTimeSpan, + context.CancellationToken); + + return CreateGetInstanceResponse(state, request); + } + + static P.GetInstanceResponse CreateGetInstanceResponse(OrchestrationState state, P.GetInstanceRequest request) + { + return new P.GetInstanceResponse + { + Exists = true, + OrchestrationState = new P.OrchestrationState + { + InstanceId = state.OrchestrationInstance.InstanceId, + Name = state.Name, + OrchestrationStatus = (P.OrchestrationStatus)state.OrchestrationStatus, + CreatedTimestamp = Timestamp.FromDateTime(state.CreatedTime), + LastUpdatedTimestamp = Timestamp.FromDateTime(state.LastUpdatedTime), + Input = request.GetInputsAndOutputs ? state.Input : null, + Output = request.GetInputsAndOutputs ? state.Output : null, + CustomStatus = request.GetInputsAndOutputs ? state.Status : null, + FailureDetails = request.GetInputsAndOutputs ? GetFailureDetails(state.FailureDetails) : null, + }, + }; + } + + /// + public override async Task SuspendInstance(P.SuspendRequest request, ServerCallContext context) + { + TaskMessage taskMessage = new() + { + OrchestrationInstance = new OrchestrationInstance { InstanceId = request.InstanceId }, + Event = new ExecutionSuspendedEvent(-1, request.Reason), + }; + + await this.client.SendTaskOrchestrationMessageAsync(taskMessage); + return new P.SuspendResponse(); + } + + /// + public override async Task ResumeInstance(P.ResumeRequest request, ServerCallContext context) + { + TaskMessage taskMessage = new() + { + OrchestrationInstance = new OrchestrationInstance { InstanceId = request.InstanceId }, + Event = new ExecutionResumedEvent(-1, request.Reason), + }; + + await this.client.SendTaskOrchestrationMessageAsync(taskMessage); + return new P.ResumeResponse(); + } + + static P.TaskFailureDetails? GetFailureDetails(FailureDetails? failureDetails) + { + if (failureDetails == null) + { + return null; + } + + return new P.TaskFailureDetails + { + ErrorType = failureDetails.ErrorType, + ErrorMessage = failureDetails.ErrorMessage, + StackTrace = failureDetails.StackTrace, + InnerFailure = GetFailureDetails(failureDetails.InnerFailure), + }; + } + + /// + /// Invoked by the remote SDK over gRPC when an orchestrator task (episode) is completed. + /// + /// Details about the orchestration execution, including the list of orchestrator actions. + /// Context for the server-side gRPC call. + /// Returns an empty ack back to the remote SDK that we've received the completion. + public override Task CompleteOrchestratorTask(P.OrchestratorResponse request, ServerCallContext context) + { + if (!this.pendingOrchestratorTasks.TryRemove( + request.InstanceId, + out TaskCompletionSource? tcs)) + { + // TODO: Log? + throw new RpcException(new Status(StatusCode.NotFound, $"Orchestration not found")); + } + + OrchestratorExecutionResult result = new() + { + Actions = request.Actions.Select(ProtobufUtils.ToOrchestratorAction), + CustomStatus = request.CustomStatus, + }; + + tcs.TrySetResult(result); + + return EmptyCompleteTaskResponse; + } + + /// + /// Invoked by the remote SDK over gRPC when an activity task (episode) is completed. + /// + /// Details about the completed activity task, including the output. + /// Context for the server-side gRPC call. + /// Returns an empty ack back to the remote SDK that we've received the completion. + public override Task CompleteActivityTask( + P.ActivityResponse request, + ServerCallContext context) + { + string taskIdKey = GetTaskIdKey(request.InstanceId, request.TaskId); + if (!this.pendingActivityTasks.TryRemove(taskIdKey, out TaskCompletionSource? tcs)) + { + // TODO: Log? + throw new RpcException(new Status(StatusCode.NotFound, $"Activity not found")); + } + + HistoryEvent resultEvent; + if (request.FailureDetails == null) + { + resultEvent = new TaskCompletedEvent(-1, request.TaskId, request.Result); + } + else + { + resultEvent = new TaskFailedEvent( + eventId: -1, + taskScheduledId: request.TaskId, + reason: null, + details: null, + failureDetails: ProtobufUtils.GetFailureDetails(request.FailureDetails)); + } + + tcs.TrySetResult(new ActivityExecutionResult { ResponseEvent = resultEvent }); + return EmptyCompleteTaskResponse; + } + + /// + public override async Task GetWorkItems(P.GetWorkItemsRequest request, IServerStreamWriter responseStream, ServerCallContext context) + { + // Use a lock to mitigate the race condition where we signal the dispatch host to start but haven't + // yet saved a reference to the client response stream. + lock (this.isConnectedSignal) + { + int retryCount = 0; + while (!this.isConnectedSignal.Set()) + { + // Retries are needed when a client (like a test suite) connects and disconnects rapidly, causing a race + // condition where we don't reset the signal quickly enough to avoid ResourceExausted errors. + if (retryCount <= 5) + { + Thread.Sleep(10); // Can't use await inside the body of a lock statement so we have to block the thread + } + else + { + throw new RpcException(new Status(StatusCode.ResourceExhausted, "Another client is already connected")); + } + } + + this.log.ClientConnected(context.Peer, context.Deadline); + this.workerToClientStream = responseStream; + } + + try + { + await Task.Delay(Timeout.InfiniteTimeSpan, context.CancellationToken); + } + catch (OperationCanceledException) + { + this.log.ClientDisconnected(context.Peer); + } + finally + { + // Resetting this signal causes the dispatchers to stop fetching new work. + this.isConnectedSignal.Reset(); + + // Transition back to the "waiting for connection" state. + // This background task is just to log "waiting for connection" messages. + _ = Task.Run(this.WaitForWorkItemClientConnection); + } + } + + /// + /// Invoked by the when a work item is available, proxies the call to execute an orchestrator over a gRPC channel. + /// + /// + async Task ITaskExecutor.ExecuteOrchestrator( + OrchestrationInstance instance, + IEnumerable pastEvents, + IEnumerable newEvents) + { + // Create a task completion source that represents the async completion of the orchestrator execution. + // This must be done before we start the orchestrator execution. + TaskCompletionSource tcs = + this.CreateTaskCompletionSourceForOrchestrator(instance.InstanceId); + + try + { + await this.SendWorkItemToClientAsync(new P.WorkItem + { + OrchestratorRequest = new P.OrchestratorRequest + { + InstanceId = instance.InstanceId, + ExecutionId = instance.ExecutionId, + NewEvents = { newEvents.Select(ProtobufUtils.ToHistoryEventProto) }, + PastEvents = { pastEvents.Select(ProtobufUtils.ToHistoryEventProto) }, + } + }); + } + catch + { + // Remove the TaskCompletionSource that we just created + this.RemoveOrchestratorTaskCompletionSource(instance.InstanceId); + throw; + } + + // The TCS will be completed on the message stream handler when it gets a response back from the remote process + // TODO: How should we handle timeouts if the remote process never sends a response? + // Probably need to have a static timeout (e.g. 5 minutes). + return await tcs.Task; + } + + async Task ITaskExecutor.ExecuteActivity(OrchestrationInstance instance, TaskScheduledEvent activityEvent) + { + // Create a task completion source that represents the async completion of the activity. + // This must be done before we start the activity execution. + TaskCompletionSource tcs = this.CreateTaskCompletionSourceForActivity( + instance.InstanceId, + activityEvent.EventId); + + try + { + await this.SendWorkItemToClientAsync(new P.WorkItem + { + ActivityRequest = new P.ActivityRequest + { + Name = activityEvent.Name, + Version = activityEvent.Version, + Input = activityEvent.Input, + TaskId = activityEvent.EventId, + OrchestrationInstance = new P.OrchestrationInstance + { + InstanceId = instance.InstanceId, + ExecutionId = instance.ExecutionId, + }, + } + }); + } + catch + { + // Remove the TaskCompletionSource that we just created + this.RemoveActivityTaskCompletionSource(instance.InstanceId, activityEvent.EventId); + throw; + } + + // The TCS will be completed on the message stream handler when it gets a response back from the remote process. + // TODO: How should we handle timeouts if the remote process never sends a response? + // Probably need a timeout feature for activities and/or a heartbeat API that activities + // can use to signal that they're still running. + return await tcs.Task; + } + + async Task SendWorkItemToClientAsync(P.WorkItem workItem) + { + IServerStreamWriter outputStream; + + // Use a lock to mitigate the race condition where we signal the dispatch host to start but haven't + // yet saved a reference to the client response stream. + lock (this.isConnectedSignal) + { + outputStream = this.workerToClientStream ?? + throw new InvalidOperationException( + "No client is connected! Need to wait until a client connects before executing!"); + } + + // The gRPC channel can only handle one message at a time, so we need to serialize access to it. + await this.sendWorkItemLock.WaitAsync(); + try + { + await outputStream.WriteAsync(workItem); + } + finally + { + this.sendWorkItemLock.Release(); + } + } + + TaskCompletionSource CreateTaskCompletionSourceForOrchestrator(string instanceId) + { + TaskCompletionSource tcs = new(); + this.pendingOrchestratorTasks.TryAdd(instanceId, tcs); + return tcs; + } + + void RemoveOrchestratorTaskCompletionSource(string instanceId) + { + this.pendingOrchestratorTasks.TryRemove(instanceId, out _); + } + + TaskCompletionSource CreateTaskCompletionSourceForActivity(string instanceId, int taskId) + { + string taskIdKey = GetTaskIdKey(instanceId, taskId); + TaskCompletionSource tcs = new(); + this.pendingActivityTasks.TryAdd(taskIdKey, tcs); + return tcs; + } + + void RemoveActivityTaskCompletionSource(string instanceId, int taskId) + { + string taskIdKey = GetTaskIdKey(instanceId, taskId); + this.pendingActivityTasks.TryRemove(taskIdKey, out _); + } + + static string GetTaskIdKey(string instanceId, int taskId) + { + return string.Concat(instanceId, "__", taskId.ToString(CultureInfo.InvariantCulture)); + } + + /// + /// A implementation that is used to control whether the task hub + /// dispatcher can fetch new work-items, based on whether a client is currently connected. + /// + class IsConnectedSignal : ITrafficSignal + { + readonly AsyncManualResetEvent isConnectedEvent = new(isSignaled: false); + + /// + public string WaitReason => "Waiting for a client to connect"; + + /// + /// Blocks the caller until the method is called, which means a client is connected. + /// + /// + public Task WaitAsync(TimeSpan waitTime, CancellationToken cancellationToken) + { + return this.isConnectedEvent.WaitAsync(waitTime, cancellationToken); + } + + /// + /// Signals the dispatchers to start fetching new work-items. + /// + /// + /// Returns true if the current thread transitioned the event to the "signaled" state; + /// otherwise false, meaning some other thread already called on this signal. + /// + public bool Set() => this.isConnectedEvent.Set(); + + /// + /// Causes the dispatchers to stop fetching new work-items. + /// + public void Reset() => this.isConnectedEvent.Reset(); + } +} diff --git a/src/Services/Sidecar/Grpc/TaskHubGrpcServerOptions.cs b/src/Services/Sidecar/Grpc/TaskHubGrpcServerOptions.cs new file mode 100644 index 00000000..ffeeacd5 --- /dev/null +++ b/src/Services/Sidecar/Grpc/TaskHubGrpcServerOptions.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Sidecar.Grpc; + +/// +/// Options for configuring the task hub gRPC server. +/// +public class TaskHubGrpcServerOptions +{ + /// + /// The high-level mode of operation for the gRPC server. + /// + public TaskHubGrpcServerMode Mode { get; set; } +} + +/// +/// A set of options that determine what capabilities are enabled for the gRPC server. +/// +public enum TaskHubGrpcServerMode +{ + /// + /// The gRPC server handles both orchestration dispatching and management API requests. + /// + ApiServerAndDispatcher, + + /// + /// The gRPC server handles management API requests but not orchestration dispatching. + /// + ApiServerOnly, +} \ No newline at end of file diff --git a/src/Services/Sidecar/InMemoryOrchestrationService.cs b/src/Services/Sidecar/InMemoryOrchestrationService.cs new file mode 100644 index 00000000..555f6e3c --- /dev/null +++ b/src/Services/Sidecar/InMemoryOrchestrationService.cs @@ -0,0 +1,781 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Text.Json.Nodes; +using System.Threading.Channels; +using DurableTask.Core; +using DurableTask.Core.Exceptions; +using DurableTask.Core.History; +using DurableTask.Core.Query; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.DurableTask.Sidecar; + +/// +/// In-memory implementation of the Durable Task backend storage provider. +/// +public class InMemoryOrchestrationService : + IOrchestrationService, + IOrchestrationServiceClient, + IOrchestrationServiceQueryClient, + IOrchestrationServicePurgeClient +{ + readonly InMemoryQueue activityQueue = new(); + readonly InMemoryInstanceStore instanceStore; + readonly ILogger logger; + + int IOrchestrationService.TaskOrchestrationDispatcherCount => 1; + + int IOrchestrationService.TaskActivityDispatcherCount => 1; + + int IOrchestrationService.MaxConcurrentTaskOrchestrationWorkItems => Environment.ProcessorCount; + + int IOrchestrationService.MaxConcurrentTaskActivityWorkItems => Environment.ProcessorCount; + + BehaviorOnContinueAsNew IOrchestrationService.EventBehaviourForContinueAsNew => BehaviorOnContinueAsNew.Carryover; + + /// + /// Initializes a new instance of the class. + /// + /// The logger factory to use for logging. + public InMemoryOrchestrationService(ILoggerFactory? loggerFactory = null) + { + this.logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger("Microsoft.DurableTask.Sidecar.InMemoryStorageProvider"); + this.instanceStore = new InMemoryInstanceStore(this.logger); + } + + /// + public Task AbandonTaskActivityWorkItemAsync(TaskActivityWorkItem workItem) + { + this.logger.AbandoningTaskActivityWorkItem(workItem.Id); + this.activityQueue.Enqueue(workItem.TaskMessage); + return Task.CompletedTask; + } + + /// + public Task AbandonTaskOrchestrationWorkItemAsync(TaskOrchestrationWorkItem workItem) + { + this.instanceStore.AbandonInstance(workItem.NewMessages); + return Task.CompletedTask; + } + + /// + public Task CompleteTaskActivityWorkItemAsync(TaskActivityWorkItem workItem, TaskMessage responseMessage) + { + this.instanceStore.AddMessage(responseMessage); + return Task.CompletedTask; + } + + /// + public Task CompleteTaskOrchestrationWorkItemAsync( + TaskOrchestrationWorkItem workItem, + OrchestrationRuntimeState newOrchestrationRuntimeState, + IList outboundMessages, + IList orchestratorMessages, + IList timerMessages, + TaskMessage continuedAsNewMessage, + OrchestrationState orchestrationState) + { + this.instanceStore.SaveState( + runtimeState: newOrchestrationRuntimeState, + statusRecord: orchestrationState, + newMessages: orchestratorMessages.Union(timerMessages).Append(continuedAsNewMessage).Where(msg => msg != null)); + + this.activityQueue.Enqueue(outboundMessages); + return Task.CompletedTask; + } + + /// + public Task CreateAsync() => Task.CompletedTask; + + /// + public Task CreateAsync(bool recreateInstanceStore) + { + if (recreateInstanceStore) + { + this.instanceStore.Reset(); + } + return Task.CompletedTask; + } + + /// + public Task CreateIfNotExistsAsync() => Task.CompletedTask; + + /// + public Task CreateTaskOrchestrationAsync(TaskMessage creationMessage) + { + return this.CreateTaskOrchestrationAsync( + creationMessage, + [OrchestrationStatus.Pending, OrchestrationStatus.Running]); + } + + /// + public Task CreateTaskOrchestrationAsync(TaskMessage creationMessage, OrchestrationStatus[]? dedupeStatuses) + { + // Lock the instance store to prevent multiple "create" threads from racing with each other. + lock (this.instanceStore) + { + string instanceId = creationMessage.OrchestrationInstance.InstanceId; + if (this.instanceStore.TryGetState(instanceId, out OrchestrationState? statusRecord) && + dedupeStatuses != null && + dedupeStatuses.Contains(statusRecord.OrchestrationStatus)) + { + throw new OrchestrationAlreadyExistsException($"An orchestration with id '{instanceId}' already exists. It's in the {statusRecord.OrchestrationStatus} state."); + } + + this.instanceStore.AddMessage(creationMessage); + } + + return Task.CompletedTask; + } + + /// + public Task DeleteAsync() => this.DeleteAsync(true); + + /// + public Task DeleteAsync(bool deleteInstanceStore) + { + if (deleteInstanceStore) + { + this.instanceStore.Reset(); + } + return Task.CompletedTask; + } + + /// + public Task ForceTerminateTaskOrchestrationAsync(string instanceId, string reason) + { + var taskMessage = new TaskMessage + { + OrchestrationInstance = new OrchestrationInstance { InstanceId = instanceId }, + Event = new ExecutionTerminatedEvent(-1, reason), + }; + + return this.SendTaskOrchestrationMessageAsync(taskMessage); + } + + int IOrchestrationService.GetDelayInSecondsAfterOnFetchException(Exception exception) + { + return exception is OperationCanceledException ? 0 : 1; + } + + int IOrchestrationService.GetDelayInSecondsAfterOnProcessException(Exception exception) + { + return exception is OperationCanceledException ? 0 : 1; + } + + /// + public Task GetOrchestrationHistoryAsync(string instanceId, string executionId) + { + // Also not supported in the emulator + throw new NotImplementedException(); + } + + /// + public async Task> GetOrchestrationStateAsync(string instanceId, bool allExecutions) + { + OrchestrationState state = await this.GetOrchestrationStateAsync(instanceId, executionId: null); + return [state]; + } + + /// + public Task GetOrchestrationStateAsync(string instanceId, string? executionId) + { + if (this.instanceStore.TryGetState(instanceId, out OrchestrationState? statusRecord)) + { + if (executionId == null || executionId == statusRecord.OrchestrationInstance.ExecutionId) + { + return Task.FromResult(statusRecord); + } + } + + return Task.FromResult(null!); + } + + bool IOrchestrationService.IsMaxMessageCountExceeded(int currentMessageCount, OrchestrationRuntimeState runtimeState) => false; + + /// + public async Task LockNextTaskActivityWorkItem( + TimeSpan receiveTimeout, + CancellationToken cancellationToken) + { + TaskMessage message = await this.activityQueue.DequeueAsync(cancellationToken); + return new TaskActivityWorkItem + { + Id = message.SequenceNumber.ToString(CultureInfo.InvariantCulture), + LockedUntilUtc = DateTime.MaxValue, + TaskMessage = message, + }; + } + + /// + public async Task LockNextTaskOrchestrationWorkItemAsync( + TimeSpan receiveTimeout, + CancellationToken cancellationToken) + { + var (instanceId, runtimeState, messages) = await this.instanceStore.GetNextReadyToRunInstanceAsync( + cancellationToken); + + return new TaskOrchestrationWorkItem + { + InstanceId = instanceId, + OrchestrationRuntimeState = runtimeState, + NewMessages = messages, + LockedUntilUtc = DateTime.MaxValue, + }; + } + + /// + public Task PurgeOrchestrationHistoryAsync( + DateTime thresholdDateTimeUtc, + OrchestrationStateTimeRangeFilterType timeRangeFilterType) + { + throw new NotSupportedException(); + } + + /// + public Task ReleaseTaskOrchestrationWorkItemAsync(TaskOrchestrationWorkItem workItem) + { + this.instanceStore.ReleaseLock(workItem.InstanceId); + return Task.CompletedTask; + } + + /// + public Task RenewTaskActivityWorkItemLockAsync(TaskActivityWorkItem workItem) + { + return Task.FromResult(workItem); // PeekLock isn't supported + } + + /// + public Task RenewTaskOrchestrationWorkItemLockAsync(TaskOrchestrationWorkItem workItem) + { + return Task.CompletedTask; // PeekLock isn't supported + } + + /// + public Task SendTaskOrchestrationMessageAsync(TaskMessage message) + { + this.instanceStore.AddMessage(message); + return Task.CompletedTask; + } + + /// + public Task SendTaskOrchestrationMessageBatchAsync(params TaskMessage[] messages) + { + // NOTE: This is not transactionally consistent - some messages may get processed earlier than others. + foreach (TaskMessage message in messages) + { + this.instanceStore.AddMessage(message); + } + + return Task.CompletedTask; + } + + /// + public Task StartAsync() => Task.CompletedTask; + + /// + public Task StopAsync() => Task.CompletedTask; + + /// + public Task StopAsync(bool isForced) => Task.CompletedTask; + + /// + public async Task WaitForOrchestrationAsync( + string instanceId, + string executionId, + TimeSpan timeout, + CancellationToken cancellationToken) + { + if (timeout <= TimeSpan.Zero) + { + return await this.instanceStore.WaitForInstanceAsync(instanceId, cancellationToken); + } + else + { + using CancellationTokenSource timeoutCancellationSource = new(timeout); + using CancellationTokenSource linkedCancellationSource = CancellationTokenSource.CreateLinkedTokenSource( + cancellationToken, + timeoutCancellationSource.Token); + return await this.instanceStore.WaitForInstanceAsync(instanceId, linkedCancellationSource.Token); + } + } + + static bool TryGetScheduledTime(TaskMessage message, out TimeSpan delay) + { + DateTime scheduledTime = default; + if (message.Event is ExecutionStartedEvent startEvent) + { + scheduledTime = startEvent.ScheduledStartTime ?? default; + } + else if (message.Event is TimerFiredEvent timerEvent) + { + scheduledTime = timerEvent.FireAt; + } + + DateTime now = DateTime.UtcNow; + if (scheduledTime > now) + { + delay = scheduledTime - now; + return true; + } + else + { + delay = default; + return false; + } + } + + /// + public Task GetOrchestrationWithQueryAsync( + OrchestrationQuery query, + CancellationToken cancellationToken) + { + return Task.FromResult(this.instanceStore.GetOrchestrationWithQuery(query)); + } + + /// + public Task PurgeInstanceStateAsync(string instanceId) + { + return Task.FromResult(this.instanceStore.PurgeInstanceState(instanceId)); + } + + /// + public Task PurgeInstanceStateAsync(PurgeInstanceFilter purgeInstanceFilter) + { + return Task.FromResult(this.instanceStore.PurgeInstanceState(purgeInstanceFilter)); + } + + class InMemoryQueue + { + readonly Channel innerQueue = Channel.CreateUnbounded(); + + public void Enqueue(TaskMessage taskMessage) + { + if (TryGetScheduledTime(taskMessage, out TimeSpan delay)) + { + _ = Task.Delay(delay).ContinueWith(t => this.innerQueue.Writer.TryWrite(taskMessage)); + } + else + { + this.innerQueue.Writer.TryWrite(taskMessage); + } + } + + public void Enqueue(IEnumerable messages) + { + foreach (TaskMessage msg in messages) + { + this.Enqueue(msg); + } + } + + public async Task DequeueAsync(CancellationToken cancellationToken) + { + return await this.innerQueue.Reader.ReadAsync(cancellationToken); + } + } + + class InMemoryInstanceStore(ILogger logger) + { + readonly ConcurrentDictionary store = new(StringComparer.OrdinalIgnoreCase); + readonly ConcurrentDictionary> waiters = new(StringComparer.OrdinalIgnoreCase); + readonly ReadyToRunQueue readyToRunQueue = new(); + + readonly ILogger logger = logger; + + public void Reset() + { + this.store.Clear(); + this.waiters.Clear(); + this.readyToRunQueue.Reset(); + } + + public async Task<(string, OrchestrationRuntimeState, List)> GetNextReadyToRunInstanceAsync(CancellationToken cancellationToken) + { + SerializedInstanceState state = await this.readyToRunQueue.TakeNextAsync(cancellationToken); + lock (state) + { + List history = state.HistoryEventsJson.Select(e => e!.GetValue()).ToList(); + OrchestrationRuntimeState runtimeState = new(history); + + List messages = state.MessagesJson.Select(node => node!.GetValue()).ToList(); + if (messages == null) + { + throw new InvalidOperationException("Should never load state with zero messages."); + } + + state.IsLoaded = true; + + // There is no "peek-lock" semantic. All dequeued messages are immediately deleted. + state.MessagesJson.Clear(); + + return (state.InstanceId, runtimeState, messages); + } + } + + public bool TryGetState(string instanceId, [NotNullWhen(true)] out OrchestrationState? statusRecord) + { + if (!this.store.TryGetValue(instanceId, out SerializedInstanceState? state)) + { + statusRecord = null; + return false; + } + + statusRecord = state.StatusRecordJson?.GetValue(); + return statusRecord != null; + } + + public bool TryGetHistory(string instanceId, [NotNullWhen(true)] out List? history) + { + if (!this.store.TryGetValue(instanceId, out SerializedInstanceState? state)) + { + history = null; + return false; + } + + lock (state) + { + history = state.HistoryEventsJson.Select(e => e!.GetValue()).ToList(); + return true; + } + } + + public void SaveState( + OrchestrationRuntimeState runtimeState, + OrchestrationState statusRecord, + IEnumerable newMessages) + { + static bool IsCompleted(OrchestrationRuntimeState runtimeState) => + runtimeState.OrchestrationStatus == OrchestrationStatus.Completed || + runtimeState.OrchestrationStatus == OrchestrationStatus.Failed || + runtimeState.OrchestrationStatus == OrchestrationStatus.Terminated || + runtimeState.OrchestrationStatus == OrchestrationStatus.Canceled; + + if (string.IsNullOrEmpty(runtimeState.OrchestrationInstance?.InstanceId)) + { + throw new ArgumentException("The provided runtime state doesn't contain instance ID information.", nameof(runtimeState)); + } + + string instanceId = runtimeState.OrchestrationInstance.InstanceId; + string executionId = runtimeState.OrchestrationInstance.ExecutionId; + SerializedInstanceState state = this.store.GetOrAdd( + instanceId, + _ => new SerializedInstanceState(instanceId, executionId)); + lock (state) + { + if (state.ExecutionId == null) + { + // This orchestration was started by a message without an execution ID. + state.ExecutionId = executionId; + } + else if (state.ExecutionId != executionId) + { + // This is a new generation (ContinueAsNew). Erase the old history. + state.HistoryEventsJson.Clear(); + } + + foreach (TaskMessage msg in newMessages) + { + this.AddMessage(msg); + } + + // Append to the orchestration history + foreach (HistoryEvent e in runtimeState.NewEvents) + { + state.HistoryEventsJson.Add(e); + } + + state.StatusRecordJson = JsonValue.Create(statusRecord); + state.IsCompleted = IsCompleted(runtimeState); + } + + // Notify any waiters of the orchestration completion + if (IsCompleted(runtimeState) && + this.waiters.TryRemove(statusRecord.OrchestrationInstance.InstanceId, out TaskCompletionSource? waiter)) + { + waiter.TrySetResult(statusRecord); + } + } + + public void AddMessage(TaskMessage message) + { + string instanceId = message.OrchestrationInstance.InstanceId; + string? executionId = message.OrchestrationInstance.ExecutionId; + + SerializedInstanceState state = this.store.GetOrAdd(instanceId, id => new SerializedInstanceState(id, executionId)); + lock (state) + { + if (message.Event is ExecutionStartedEvent startEvent) + { + OrchestrationState newStatusRecord = new() + { + OrchestrationInstance = startEvent.OrchestrationInstance, + CreatedTime = DateTime.UtcNow, + LastUpdatedTime = DateTime.UtcNow, + OrchestrationStatus = OrchestrationStatus.Pending, + Version = startEvent.Version, + Name = startEvent.Name, + Input = startEvent.Input, + ScheduledStartTime = startEvent.ScheduledStartTime, + }; + + state.StatusRecordJson = JsonValue.Create(newStatusRecord); + state.HistoryEventsJson.Clear(); + state.IsCompleted = false; + } + else if (state.IsCompleted) + { + // Drop the message since we're completed + this.logger.DroppedMessageForCompletedOrchestration( + instanceId, + message.Event.EventType); + return; + } + + if (TryGetScheduledTime(message, out TimeSpan delay)) + { + // Not ready for this message yet - delay the enqueue + _ = Task.Delay(delay).ContinueWith(t => this.AddMessage(message)); + return; + } + + state.MessagesJson.Add(message); + + if (!state.IsLoaded) + { + // The orchestration isn't running, so schedule it to run now. + // If it is running, it will be scheduled again automatically when it's released. + this.readyToRunQueue.Schedule(state); + } + } + } + + public void AbandonInstance(IEnumerable messagesToReturn) + { + foreach (TaskMessage message in messagesToReturn) + { + this.AddMessage(message); + } + } + + public void ReleaseLock(string instanceId) + { + if (!this.store.TryGetValue(instanceId, out SerializedInstanceState? state) || !state.IsLoaded) + { + throw new InvalidOperationException($"Instance {instanceId} is not in the store or is not loaded!"); + } + + lock (state) + { + state.IsLoaded = false; + if (state.MessagesJson.Count > 0) + { + // More messages came in while we were running. Or, messages were abandoned. + // Put this back into the read-to-run queue! + this.readyToRunQueue.Schedule(state); + } + } + } + + public Task WaitForInstanceAsync(string instanceId, CancellationToken cancellationToken) + { + if (this.store.TryGetValue(instanceId, out SerializedInstanceState? state)) + { + lock (state) + { + OrchestrationState? statusRecord = state.StatusRecordJson?.GetValue(); + if (statusRecord != null) + { + if (statusRecord.OrchestrationStatus == OrchestrationStatus.Completed || + statusRecord.OrchestrationStatus == OrchestrationStatus.Failed || + statusRecord.OrchestrationStatus == OrchestrationStatus.Terminated) + { + // orchestration has already completed + return Task.FromResult(statusRecord); + } + } + + } + } + + // Caller will be notified when the instance completes. + // The ContinueWith is just to enable cancellation: https://stackoverflow.com/a/25652873/2069 + var tcs = this.waiters.GetOrAdd(instanceId, _ => new TaskCompletionSource()); + return tcs.Task.ContinueWith(t => t.GetAwaiter().GetResult(), cancellationToken); + } + + public OrchestrationQueryResult GetOrchestrationWithQuery(OrchestrationQuery query) + { + int startIndex = 0; + int counter = 0; + string? continuationToken = query.ContinuationToken; + if (continuationToken != null) + { + if (!Int32.TryParse(continuationToken, out startIndex)) + { + throw new InvalidOperationException($"{continuationToken} cannot be parsed to Int32"); + } + } + + counter = startIndex; + + List results = this.store + .Skip(startIndex) + .Where(item => + { + counter++; + OrchestrationState? statusRecord = item.Value.StatusRecordJson?.GetValue(); + if (statusRecord == null) + { + return false; + } + + if (query.CreatedTimeFrom != null && query.CreatedTimeFrom > statusRecord.CreatedTime) + { + return false; + } + + if (query.CreatedTimeTo != null && query.CreatedTimeTo < statusRecord.CreatedTime) + { + return false; + } + + if (query.RuntimeStatus != null && + query.RuntimeStatus.Count > 0 && + !query.RuntimeStatus.Contains(statusRecord.OrchestrationStatus)) + { + return false; + } + + if (query.InstanceIdPrefix != null && + !statusRecord.OrchestrationInstance.InstanceId.StartsWith( + query.InstanceIdPrefix, + StringComparison.Ordinal)) + { + return false; + } + + return true; + }) + .Take(query.PageSize) + .Select(item => item.Value.StatusRecordJson!.GetValue()) + .ToList(); + + string? token = null; + if (results.Count == query.PageSize) + { + token = counter.ToString(CultureInfo.InvariantCulture); + } + + return new OrchestrationQueryResult(results, token); + } + + public PurgeResult PurgeInstanceState(string instanceId) + { + if (instanceId != null && this.store.TryGetValue(instanceId, out SerializedInstanceState? state) && state.IsCompleted) + { + this.store.TryRemove(instanceId, out SerializedInstanceState? removedState); + if (removedState != null) + { + return new PurgeResult(1); + } + } + return new PurgeResult(0); + } + + public PurgeResult PurgeInstanceState(PurgeInstanceFilter purgeInstanceFilter) + { + int counter = 0; + + List filteredInstanceIds = this.store + .Where(item => + { + OrchestrationState? statusRecord = item.Value.StatusRecordJson?.GetValue(); + if (statusRecord == null) return false; + if (purgeInstanceFilter.CreatedTimeFrom > statusRecord.CreatedTime) return false; + if (purgeInstanceFilter.CreatedTimeTo != null && purgeInstanceFilter.CreatedTimeTo < statusRecord.CreatedTime) return false; + if (purgeInstanceFilter.RuntimeStatus != null && purgeInstanceFilter.RuntimeStatus.Any() && !purgeInstanceFilter.RuntimeStatus.Contains(statusRecord.OrchestrationStatus)) return false; + return true; + }) + .Select(item => item.Key) + .ToList(); + + foreach (string instanceId in filteredInstanceIds) + { + this.store.TryRemove(instanceId, out SerializedInstanceState? removedState); + if (removedState != null) + { + counter++; + } + } + + return new PurgeResult(counter); + } + + class ReadyToRunQueue + { + readonly Channel readyToRunQueue = Channel.CreateUnbounded(); + readonly Dictionary readyInstances = new(StringComparer.OrdinalIgnoreCase); + + public void Reset() + { + this.readyInstances.Clear(); + } + + public async ValueTask TakeNextAsync(CancellationToken ct) + { + while (true) + { + SerializedInstanceState state = await this.readyToRunQueue.Reader.ReadAsync(ct); + lock (state) + { + if (this.readyInstances.Remove(state.InstanceId)) + { + if (state.IsLoaded) + { + throw new InvalidOperationException("Should never load state that is already loaded."); + } + + state.IsLoaded = true; + return state; + } + } + } + } + + public void Schedule(SerializedInstanceState state) + { + // TODO: There is a race condition here. If another thread is calling TakeNextAsync + // and removed the queue item before updating the dictionary, then we'll fail + // to update the readyToRunQueue and the orchestration will get stuck. + if (this.readyInstances.TryAdd(state.InstanceId, state)) + { + this.readyToRunQueue.Writer.TryWrite(state); + } + } + } + + class SerializedInstanceState + { + public SerializedInstanceState(string instanceId, string? executionId) + { + this.InstanceId = instanceId; + this.ExecutionId = executionId; + } + + public string InstanceId { get; } + public string? ExecutionId { get; internal set; } + public JsonValue? StatusRecordJson { get; set; } + public JsonArray HistoryEventsJson { get; } = new JsonArray(); + public JsonArray MessagesJson { get; } = new JsonArray(); + + internal bool IsLoaded { get; set; } + internal bool IsCompleted { get; set; } + } + } +} diff --git a/src/Services/Sidecar/Logs.cs b/src/Services/Sidecar/Logs.cs new file mode 100644 index 00000000..e8238251 --- /dev/null +++ b/src/Services/Sidecar/Logs.cs @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using DurableTask.Core.Command; +using DurableTask.Core.History; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar +{ + static partial class Logs + { + [LoggerMessage( + EventId = 5, + Level = LogLevel.Information, + Message = "Waiting for a remote client to connect to this server. Total wait time: {totalWaitTime:c}")] + public static partial void WaitingForClientConnection( + this ILogger logger, + TimeSpan totalWaitTime); + + [LoggerMessage( + EventId = 6, + Level = LogLevel.Information, + Message = "Received work-item connection from {address}. Client connection deadline = {deadline:s}.")] + public static partial void ClientConnected( + this ILogger logger, + string address, + DateTime deadline); + + [LoggerMessage( + EventId = 7, + Level = LogLevel.Information, + Message = "Client at {address} has disconnected. No further work-items will be processed until a new connection is established.")] + public static partial void ClientDisconnected( + this ILogger logger, + string address); + + [LoggerMessage( + EventId = 22, + Level = LogLevel.Information, + Message = "{dispatcher}: Shutting down, waiting for {currentWorkItemCount} active work-items to complete.")] + public static partial void DispatcherStopping( + this ILogger logger, + string dispatcher, + int currentWorkItemCount); + + [LoggerMessage( + EventId = 23, + Level = LogLevel.Trace, + Message = "{dispatcher}: Fetching next work item. Current active work-items: {currentWorkItemCount}/{maxWorkItemCount}.")] + public static partial void FetchWorkItemStarting( + this ILogger logger, + string dispatcher, + int currentWorkItemCount, + int maxWorkItemCount); + + [LoggerMessage( + EventId = 24, + Level = LogLevel.Trace, + Message = "{dispatcher}: Fetched next work item '{workItemId}' after {latencyMs}ms. Current active work-items: {currentWorkItemCount}/{maxWorkItemCount}.")] + public static partial void FetchWorkItemCompleted( + this ILogger logger, + string dispatcher, + string workItemId, + long latencyMs, + int currentWorkItemCount, + int maxWorkItemCount); + + [LoggerMessage( + EventId = 25, + Level = LogLevel.Error, + Message = "{dispatcher}: Unexpected {action} failure for work-item '{workItemId}': {details}")] + public static partial void DispatchWorkItemFailure( + this ILogger logger, + string dispatcher, + string action, + string workItemId, + string details); + + [LoggerMessage( + EventId = 26, + Level = LogLevel.Information, + Message = "{dispatcher}: Work-item fetching is paused: {details}. Current active work-item count: {currentWorkItemCount}/{maxWorkItemCount}.")] + public static partial void FetchingThrottled( + this ILogger logger, + string dispatcher, + string details, + int currentWorkItemCount, + int maxWorkItemCount); + + [LoggerMessage( + EventId = 49, + Level = LogLevel.Information, + Message = "{instanceId}: Orchestrator '{name}' completed with a {runtimeStatus} status and {sizeInBytes} bytes of output.")] + public static partial void OrchestratorCompleted( + this ILogger logger, + string instanceId, + string name, + OrchestrationStatus runtimeStatus, + int sizeInBytes); + + [LoggerMessage( + EventId = 51, + Level = LogLevel.Debug, + Message = "{instanceId}: Preparing to execute orchestrator '{name}' with {eventCount} new events: {newEvents}")] + public static partial void OrchestratorExecuting( + this ILogger logger, + string instanceId, + string name, + int eventCount, + string newEvents); + + [LoggerMessage( + EventId = 55, + Level = LogLevel.Warning, + Message = "{instanceId}: Ignoring unknown orchestrator action '{action}'.")] + public static partial void IgnoringUnknownOrchestratorAction( + this ILogger logger, + string instanceId, + OrchestratorActionType action); + + [LoggerMessage( + EventId = 56, + Level = LogLevel.Warning, + Message = "{instanceId}: Dropped {eventType} message because the orchestration has already completed.")] + public static partial void DroppedMessageForCompletedOrchestration( + this ILogger logger, + string instanceId, + EventType eventType); + + [LoggerMessage( + EventId = 57, + Level = LogLevel.Warning, + Message = "Abandoning task activity work item {id}")] + public static partial void AbandoningTaskActivityWorkItem( + this ILogger logger, + string id); + + [LoggerMessage( + EventId = 100, + Level = LogLevel.Information, + Message = "Received request to create a new instance with ID = '{instanceId}'")] + public static partial void CreatingNewInstance( + this ILogger logger, + string instanceId); + + [LoggerMessage( + EventId = 101, + Level = LogLevel.Information, + Message = "Received request to raise event '{eventName}' to instance '{instanceId}'")] + public static partial void RaisingEvent( + this ILogger logger, + string instanceId, + string eventName); + + [LoggerMessage( + EventId = 102, + Level = LogLevel.Information, + Message = "Received request to terminating instance '{instanceId}'")] + public static partial void TerminatingInstance( + this ILogger logger, + string instanceId); + } +} \ No newline at end of file diff --git a/src/Services/Sidecar/Sidecar.csproj b/src/Services/Sidecar/Sidecar.csproj new file mode 100644 index 00000000..90c2aee8 --- /dev/null +++ b/src/Services/Sidecar/Sidecar.csproj @@ -0,0 +1,28 @@ + + + + + + + + Microsoft.DurableTask.Sidecar + Durable Task Sidecar + gRPC sidecar implementation for the Durable Task Framework. + true + + + + + + + + + + + + + + + + + diff --git a/src/Services/Sidecar/Utils.cs b/src/Services/Sidecar/Utils.cs new file mode 100644 index 00000000..d0a56995 --- /dev/null +++ b/src/Services/Sidecar/Utils.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core.History; + +namespace Microsoft.DurableTask.Sidecar; + +static class Utils +{ + public static bool TryGetTaskScheduledId(HistoryEvent historyEvent, out int taskScheduledId) + { + switch (historyEvent.EventType) + { + case EventType.TaskCompleted: + taskScheduledId = ((TaskCompletedEvent)historyEvent).TaskScheduledId; + return true; + case EventType.TaskFailed: + taskScheduledId = ((TaskFailedEvent)historyEvent).TaskScheduledId; + return true; + case EventType.SubOrchestrationInstanceCompleted: + taskScheduledId = ((SubOrchestrationInstanceCompletedEvent)historyEvent).TaskScheduledId; + return true; + case EventType.SubOrchestrationInstanceFailed: + taskScheduledId = ((SubOrchestrationInstanceFailedEvent)historyEvent).TaskScheduledId; + return true; + case EventType.TimerFired: + taskScheduledId = ((TimerFiredEvent)historyEvent).TimerId; + return true; + case EventType.ExecutionStarted: + var parentInstance = ((ExecutionStartedEvent)historyEvent).ParentInstance; + if (parentInstance != null) + { + // taskId that scheduled a sub-orchestration + taskScheduledId = parentInstance.TaskScheduleId; + return true; + } + else + { + taskScheduledId = -1; + return false; + } + default: + taskScheduledId = -1; + return false; + } + } +} diff --git a/src/Services/common.props b/src/Services/common.props new file mode 100644 index 00000000..5c96155b --- /dev/null +++ b/src/Services/common.props @@ -0,0 +1,43 @@ + + + + + + net6.0 + enable + enable + true + ../key.snk + embedded + Microsoft Corporation + Durable Task Sidecar + true + true + https://github.com/microsoft/durabletask-dotnet + + + + + 1.1.1 + + + $(VersionPrefix).$(FileVersionRevision) + + + + + Microsoft + © Microsoft Corporation. All rights reserved. + MIT + $(RepositoryUrl) + $(RepositoryUrl)/releases/ + + + + + + true + content/SBOM + + + \ No newline at end of file diff --git a/src/Services/key.snk b/src/Services/key.snk new file mode 100644 index 00000000..a5ed5b37 Binary files /dev/null and b/src/Services/key.snk differ diff --git a/src/Shared/Shared.csproj b/src/Shared/Shared.csproj index ae3d8491..b227c9f8 100644 --- a/src/Shared/Shared.csproj +++ b/src/Shared/Shared.csproj @@ -17,7 +17,6 @@ - diff --git a/test/Grpc.IntegrationTests/Grpc.IntegrationTests.csproj b/test/Grpc.IntegrationTests/Grpc.IntegrationTests.csproj index 2762089b..98e39ddc 100644 --- a/test/Grpc.IntegrationTests/Grpc.IntegrationTests.csproj +++ b/test/Grpc.IntegrationTests/Grpc.IntegrationTests.csproj @@ -3,11 +3,11 @@ + -