diff --git a/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs b/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs deleted file mode 100644 index 126c4b9bc..000000000 --- a/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs +++ /dev/null @@ -1,227 +0,0 @@ -// ---------------------------------------------------------------------------------- -// Copyright Microsoft Corporation -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ---------------------------------------------------------------------------------- - -namespace DurableTask.AzureStorage.Tests -{ - using System; - using System.Collections.Generic; - using System.Diagnostics; - using System.Linq; - using System.Reflection; - using System.Threading; - using System.Threading.Tasks; - using DurableTask.AzureStorage.Messaging; - using DurableTask.AzureStorage.Monitoring; - using DurableTask.AzureStorage.Tracking; - using Microsoft.VisualStudio.TestTools.UnitTesting; - using Moq; - - /// - /// Tests for shutdown cancellation behavior with extended sessions. - /// - [TestClass] - public class OrchestrationSessionTests - { - /// - /// Verifies that - /// exits immediately when the cancellation token is cancelled. - /// - [TestMethod] - public async Task WaitAsync_CancellationToken_ExitsImmediately() - { - var resetEvent = new AsyncAutoResetEvent(signaled: false); - using var cts = new CancellationTokenSource(); - - TimeSpan longTimeout = TimeSpan.FromSeconds(30); - Task waitTask = resetEvent.WaitAsync(longTimeout, cts.Token); - - Assert.IsFalse(waitTask.IsCompleted, "Wait should not complete immediately"); - - var stopwatch = Stopwatch.StartNew(); - cts.Cancel(); - - bool result = await waitTask; - stopwatch.Stop(); - - Assert.IsFalse(result, "Cancellation should return false (no signal received)"); - Assert.IsTrue( - stopwatch.ElapsedMilliseconds < 5000, - $"Cancellation should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms"); - } - - /// - /// Verifies that signaling still returns true when a cancellation token is provided. - /// - [TestMethod] - public async Task WaitAsync_WithCancellationToken_SignalStillWorks() - { - var resetEvent = new AsyncAutoResetEvent(signaled: false); - using var cts = new CancellationTokenSource(); - - Task waitTask = resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token); - Assert.IsFalse(waitTask.IsCompleted); - - resetEvent.Set(); - - Task winner = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(5))); - Assert.IsTrue(winner == waitTask, "Signal should wake the waiter"); - Assert.IsTrue(waitTask.Result, "Wait result should be true when signaled"); - } - - /// - /// Verifies that the wait returns false on timeout when a cancellation token is provided but not cancelled. - /// - [TestMethod] - public async Task WaitAsync_WithCancellationToken_TimeoutStillWorks() - { - var resetEvent = new AsyncAutoResetEvent(signaled: false); - using var cts = new CancellationTokenSource(); - - bool result = await resetEvent.WaitAsync(TimeSpan.FromMilliseconds(100), cts.Token); - - Assert.IsFalse(result, "Wait should return false on timeout"); - } - - /// - /// Verifies that all queued waiters return false when the token is cancelled. - /// - [TestMethod] - public async Task WaitAsync_CancellationToken_MultipleWaiters() - { - var resetEvent = new AsyncAutoResetEvent(signaled: false); - using var cts = new CancellationTokenSource(); - - var waiters = new List>(); - for (int i = 0; i < 5; i++) - { - waiters.Add(resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token)); - } - - foreach (var waiter in waiters) - { - Assert.IsFalse(waiter.IsCompleted); - } - - var stopwatch = Stopwatch.StartNew(); - cts.Cancel(); - - // All waiters should return false (cancelled = not signaled) - await Task.WhenAll( - waiters.Select( - async waiter => - { - bool result = await waiter; - Assert.IsFalse(result, "Cancelled waiter should return false"); - })); - - stopwatch.Stop(); - - Assert.IsTrue( - stopwatch.ElapsedMilliseconds < 5000, - $"All waiters should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms"); - } - - /// - /// Verifies that a pre-cancelled token causes WaitAsync to return false immediately. - /// - [TestMethod] - public async Task WaitAsync_AlreadyCancelledToken_ReturnsFalseImmediately() - { - var resetEvent = new AsyncAutoResetEvent(signaled: false); - using var cts = new CancellationTokenSource(); - cts.Cancel(); // Pre-cancel - - var stopwatch = Stopwatch.StartNew(); - bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token); - stopwatch.Stop(); - - Assert.IsFalse(result, "Pre-cancelled token should cause immediate false return"); - Assert.IsTrue( - stopwatch.ElapsedMilliseconds < 5000, - $"Should complete immediately, but took {stopwatch.ElapsedMilliseconds}ms"); - } - - /// - /// Verifies that a pre-cancelled token still returns true if the event is already signaled. - /// - [TestMethod] - public async Task WaitAsync_AlreadySignaledAndCancelled_ReturnsTrue() - { - var resetEvent = new AsyncAutoResetEvent(signaled: true); - using var cts = new CancellationTokenSource(); - cts.Cancel(); - - bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token); - Assert.IsTrue(result, "Already signaled event should return true even with cancelled token"); - } - - /// - /// Verifies that clears all active sessions. - /// - [TestMethod] - public void AbortAllSessions_ClearsActiveSessions() - { - var settings = new AzureStorageOrchestrationServiceSettings(); - var stats = new AzureStorageOrchestrationServiceStats(); - var trackingStore = new Mock(); - - using var manager = new OrchestrationSessionManager( - "testaccount", - settings, - stats, - trackingStore.Object); - - // Use reflection to access the internal sessions dictionary. - var sessionsField = typeof(OrchestrationSessionManager) - .GetField("activeOrchestrationSessions", BindingFlags.NonPublic | BindingFlags.Instance); - var sessions = (Dictionary)sessionsField.GetValue(manager); - - manager.GetStats(out _, out _, out int initialCount); - Assert.AreEqual(0, initialCount, "Should start with no active sessions"); - - sessions["instance1"] = null; - sessions["instance2"] = null; - sessions["instance3"] = null; - - manager.GetStats(out _, out _, out int activeCount); - Assert.AreEqual(3, activeCount, "Should have 3 active sessions"); - - manager.AbortAllSessions(); - - manager.GetStats(out _, out _, out int afterAbortCount); - Assert.AreEqual(0, afterAbortCount, "AbortAllSessions should clear all active sessions"); - } - - /// - /// Verifies that is safe to call with no active sessions. - /// - [TestMethod] - public void AbortAllSessions_NoSessions_DoesNotThrow() - { - var settings = new AzureStorageOrchestrationServiceSettings(); - var stats = new AzureStorageOrchestrationServiceStats(); - var trackingStore = new Mock(); - - using var manager = new OrchestrationSessionManager( - "testaccount", - settings, - stats, - trackingStore.Object); - - manager.AbortAllSessions(); - - manager.GetStats(out _, out _, out int count); - Assert.AreEqual(0, count, "Should still have no active sessions"); - } - } -} diff --git a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs index 74798cb45..cfe64a7d4 100644 --- a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs +++ b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs @@ -553,7 +553,7 @@ internal async Task OnIntentLeaseAquiredAsync(BlobPartitionLease lease) { var controlQueue = new ControlQueue(this.azureStorageClient, lease.PartitionId, this.messageManager); await controlQueue.CreateIfNotExistsAsync(); - this.orchestrationSessionManager.ResumeListeningIfOwnQueue(lease.PartitionId, controlQueue, this.shutdownSource.Token); + await this.orchestrationSessionManager.ResumeListeningIfOwnQueue(lease.PartitionId, controlQueue, this.shutdownSource.Token); } internal Task OnIntentLeaseReleasedAsync(BlobPartitionLease lease, CloseReason reason) @@ -572,21 +572,20 @@ internal async Task OnOwnershipLeaseAquiredAsync(BlobPartitionLease lease) this.allControlQueues[lease.PartitionId] = controlQueue; } - internal void DropLostControlQueue(TablePartitionLease partition) + internal async Task DropLostControlQueue(TablePartitionLease partition) { // If lease is lost but we're still dequeuing messages, remove the queue if (this.allControlQueues.TryGetValue(partition.RowKey, out ControlQueue controlQueue) && this.OwnedControlQueues.Contains(controlQueue) && partition.CurrentOwner != this.settings.WorkerId) { - this.orchestrationSessionManager.RemoveQueue(partition.RowKey, CloseReason.LeaseLost, nameof(DropLostControlQueue)); + await this.orchestrationSessionManager.RemoveQueue(partition.RowKey, CloseReason.LeaseLost, nameof(DropLostControlQueue)); } } internal Task OnOwnershipLeaseReleasedAsync(BlobPartitionLease lease, CloseReason reason) { - this.orchestrationSessionManager.RemoveQueue(lease.PartitionId, reason, "Ownership LeaseCollectionBalancer"); - return Utils.CompletedTask; + return this.orchestrationSessionManager.RemoveQueue(lease.PartitionId, reason, "Ownership LeaseCollectionBalancer"); } internal async Task OnTableLeaseAcquiredAsync(TablePartitionLease lease) @@ -1249,7 +1248,13 @@ await this.CommitOutboundQueueMessages( { var messages = session.DeferredMessages.ToList(); session.DeferredMessages.Clear(); - this.orchestrationSessionManager.AddMessageToPendingOrchestration(session.ControlQueue, messages, session.TraceActivityId, CancellationToken.None); + IReadOnlyList messagesToAbandon = this.orchestrationSessionManager.AddMessageToPendingOrchestration( + session.ControlQueue, + messages, + session.TraceActivityId, + CancellationToken.None); + + await this.orchestrationSessionManager.AbandonMessagesForDrainAsync(session.ControlQueue, messagesToAbandon); } } // Handle the case where the 'ETag' has changed, which implies another worker has taken over this work item while @@ -1524,8 +1529,11 @@ async Task ReleaseSessionAsync(string instanceId) if (this.orchestrationSessionManager.TryReleaseSession( instanceId, this.shutdownSource.Token, - out OrchestrationSession session)) + out OrchestrationSession session, + out IReadOnlyList messagesToAbandon)) { + await this.orchestrationSessionManager.AbandonMessagesForDrainAsync(session.ControlQueue, messagesToAbandon); + // Some messages may need to be discarded await session.DiscardedMessages.ParallelForEachAsync( this.settings.MaxStorageOperationConcurrency, diff --git a/src/DurableTask.AzureStorage/Messaging/ControlQueue.cs b/src/DurableTask.AzureStorage/Messaging/ControlQueue.cs index 9f1c0d2ba..2d27fa3ce 100644 --- a/src/DurableTask.AzureStorage/Messaging/ControlQueue.cs +++ b/src/DurableTask.AzureStorage/Messaging/ControlQueue.cs @@ -23,6 +23,7 @@ namespace DurableTask.AzureStorage.Messaging using DurableTask.AzureStorage.Monitoring; using DurableTask.AzureStorage.Partitioning; using DurableTask.AzureStorage.Storage; + using DurableTask.Core; class ControlQueue : TaskHubQueue, IDisposable { @@ -209,6 +210,50 @@ public override Task AbandonMessageAsync(MessageData message, SessionBase? sessi return base.AbandonMessageAsync(message, session); } + /// + /// Abandons a message with zero visibility timeout so it becomes immediately visible + /// for another partition owner to pick up. This is used during drain to avoid stranding + /// messages that were dequeued but not yet promoted to active sessions. + /// + public async Task AbandonMessageForDrainAsync(MessageData message) + { + this.stats.PendingOrchestratorMessages.TryRemove(message.OriginalQueueMessage.MessageId, out _); + + QueueMessage queueMessage = message.OriginalQueueMessage; + TaskMessage taskMessage = message.TaskMessage; + OrchestrationInstance instance = taskMessage.OrchestrationInstance; + + this.settings.Logger.AbandoningMessage( + this.storageAccountName, + this.settings.TaskHubName, + taskMessage.Event.EventType.ToString(), + Utils.GetTaskEventId(taskMessage.Event), + queueMessage.MessageId, + instance.InstanceId, + instance.ExecutionId, + this.storageQueue.Name, + message.SequenceNumber, + queueMessage.PopReceipt, + visibilityTimeoutSeconds: 0); + + try + { + await this.storageQueue.UpdateMessageAsync( + queueMessage, + TimeSpan.Zero, + clientRequestId: null); + } + catch (DurableTaskStorageException e) + { + this.settings.Logger.PartitionManagerWarning( + this.storageAccountName, + this.settings.TaskHubName, + this.settings.WorkerId, + this.Name, + $"Failed to abandon message {queueMessage.MessageId} during drain: {e.Message}"); + } + } + public override Task DeleteMessageAsync(MessageData message, SessionBase? session = null) { this.stats.PendingOrchestratorMessages.TryRemove(message.OriginalQueueMessage.MessageId, out _); diff --git a/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs b/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs index abf7a58b2..a9d27aeb4 100644 --- a/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs +++ b/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs @@ -30,8 +30,11 @@ namespace DurableTask.AzureStorage class OrchestrationSessionManager : IDisposable { + static readonly IReadOnlyList EmptyMessageDataList = Array.Empty(); + readonly Dictionary activeOrchestrationSessions = new Dictionary(StringComparer.OrdinalIgnoreCase); readonly ConcurrentDictionary ownedControlQueues = new ConcurrentDictionary(); + readonly ConcurrentDictionary dequeueLoopTasks = new ConcurrentDictionary(); readonly LinkedList pendingOrchestrationMessageBatches = new LinkedList(); readonly AsyncQueue> orchestrationsReadyForProcessingQueue = new AsyncQueue>(); readonly AsyncQueue> entitiesReadyForProcessingQueue = new AsyncQueue>(); @@ -67,7 +70,8 @@ public void AddQueue(string partitionId, ControlQueue controlQueue, Cancellation if (this.ownedControlQueues.TryAdd(partitionId, controlQueue)) { - _ = Task.Run(() => this.DequeueLoop(partitionId, controlQueue, cancellationToken)); + Task dequeueLoopTask = Task.Run(async () => await this.DequeueLoop(partitionId, controlQueue, cancellationToken)); + this.dequeueLoopTasks[partitionId] = dequeueLoopTask; } else { @@ -80,11 +84,13 @@ public void AddQueue(string partitionId, ControlQueue controlQueue, Cancellation } } - public void RemoveQueue(string partitionId, CloseReason? reason, string caller) + public async Task RemoveQueue(string partitionId, CloseReason? reason, string caller) { if (this.ownedControlQueues.TryRemove(partitionId, out ControlQueue controlQueue)) { controlQueue.Release(reason, caller); + this.dequeueLoopTasks.TryRemove(partitionId, out _); + await this.AbandonPendingMessagesAsync(partitionId, controlQueue); } } @@ -97,15 +103,15 @@ public void ReleaseQueue(string partitionId, CloseReason? reason, string caller) } } - public bool ResumeListeningIfOwnQueue(string partitionId, ControlQueue controlQueue, CancellationToken shutdownToken) + public async Task ResumeListeningIfOwnQueue(string partitionId, ControlQueue controlQueue, CancellationToken cancellationToken) { if (this.ownedControlQueues.TryGetValue(partitionId, out ControlQueue ownedControlQueue)) { if (ownedControlQueue.IsReleased) { // The easiest way to resume listening is to re-add a new queue that has not been released. - this.RemoveQueue(partitionId, null, "OrchestrationSessionManager ResumeListeningIfOwnQueue"); - this.AddQueue(partitionId, controlQueue, shutdownToken); + await this.RemoveQueue(partitionId, null, "OrchestrationSessionManager ResumeListeningIfOwnQueue"); + this.AddQueue(partitionId, controlQueue, cancellationToken); } } @@ -154,7 +160,13 @@ async Task DequeueLoop(string partitionId, ControlQueue controlQueue, Cancellati traceActivityId, cancellationToken); - this.AddMessageToPendingOrchestration(controlQueue, filteredMessages, traceActivityId, cancellationToken); + IReadOnlyList messagesToAbandon = this.AddMessageToPendingOrchestration( + controlQueue, + filteredMessages, + traceActivityId, + cancellationToken); + + await this.AbandonMessagesForDrainAsync(controlQueue, messagesToAbandon); } } catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) @@ -198,6 +210,8 @@ public async Task DrainAsync(string partitionId, CloseReason reason, Cancellatio this.ReleaseQueue(partitionId, reason, caller); try { + await this.WaitForDequeueLoopToStopAsync(partitionId, cancellationToken); + // Wait until all messages from this queue have been processed. while (!cancellationToken.IsCancellationRequested && this.IsControlQueueProcessingMessages(partitionId)) { @@ -216,8 +230,150 @@ public async Task DrainAsync(string partitionId, CloseReason reason, Cancellatio } finally { - // Remove the partition from memory - this.RemoveQueue(partitionId, reason, caller); + try + { + // Make dequeued-but-undispatched messages visible before dropping the partition. + await this.AbandonPendingMessagesAsync(partitionId); + } + finally + { + await this.RemoveQueue(partitionId, reason, caller); + } + } + } + + async Task WaitForDequeueLoopToStopAsync(string partitionId, CancellationToken cancellationToken) + { + if (!this.dequeueLoopTasks.TryGetValue(partitionId, out Task dequeueLoopTask)) + { + return; + } + + try + { + bool completed = await WaitForTaskAsync(dequeueLoopTask, cancellationToken); + if (!completed) + { + this.settings.Logger.PartitionManagerWarning( + this.storageAccountName, + this.settings.TaskHubName, + this.settings.WorkerId, + partitionId, + $"Timed-out waiting for the dequeue loop to stop during drain."); + } + } + catch (OperationCanceledException e) + { + this.settings.Logger.PartitionManagerWarning( + this.storageAccountName, + this.settings.TaskHubName, + this.settings.WorkerId, + partitionId, + $"Canceled while waiting for the dequeue loop to stop during drain. Exception: {e}"); + } + catch (AggregateException e) when (e.InnerExceptions.All(ex => ex is OperationCanceledException)) + { + this.settings.Logger.PartitionManagerWarning( + this.storageAccountName, + this.settings.TaskHubName, + this.settings.WorkerId, + partitionId, + $"Canceled while waiting for the dequeue loop to stop during drain. Exception: {e}"); + } + } + + static async Task WaitForTaskAsync(Task task, CancellationToken cancellationToken) + { + if (task.IsCompleted) + { + await task; + return true; + } + + var cancellationCompletion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using (cancellationToken.Register(state => ((TaskCompletionSource)state).TrySetResult(true), cancellationCompletion)) + { + Task completedTask = await Task.WhenAny(task, cancellationCompletion.Task); + if (completedTask != task) + { + return false; + } + } + + await task; + return true; + } + + /// + /// Abandons all pending (dequeued but not yet dispatched) messages for the specified partition, + /// making them immediately visible in the Azure Storage queue for the new partition owner. + /// This prevents a throughput gap equal to the visibility timeout duration when a partition + /// is released during drain. + /// + async Task AbandonPendingMessagesAsync(string partitionId, ControlQueue? removedControlQueue = null) + { + var messagesToAbandon = new List<(ControlQueue Queue, MessageData Message)>(); + + lock (this.messageAndSessionLock) + { + var node = this.pendingOrchestrationMessageBatches.First; + while (node != null) + { + LinkedListNode? next = node.Next; + PendingMessageBatch batch = node.Value; + + if (string.Equals(batch.ControlQueue.Name, partitionId, StringComparison.OrdinalIgnoreCase)) + { + foreach (MessageData message in batch.Messages) + { + messagesToAbandon.Add((batch.ControlQueue, message)); + } + + this.pendingOrchestrationMessageBatches.Remove(node); + } + + node = next; + } + } + + await this.AbandonMessagesForDrainAsync(partitionId, messagesToAbandon); + + if (removedControlQueue != null) + { + this.SignalSessionWaitersIfNoQueuesRemain(removedControlQueue); + } + } + + internal Task AbandonMessagesForDrainAsync(ControlQueue controlQueue, IReadOnlyList messages) + { + if (messages.Count == 0) + { + return Utils.CompletedTask; + } + + var messagesToAbandon = messages + .Select(message => (controlQueue, message)) + .ToList(); + + return this.AbandonMessagesForDrainAsync(controlQueue.Name, messagesToAbandon); + } + + async Task AbandonMessagesForDrainAsync( + string partitionId, + IList<(ControlQueue Queue, MessageData Message)> messagesToAbandon) + { + if (messagesToAbandon.Count > 0) + { + this.settings.Logger.PartitionManagerInfo( + this.storageAccountName, + this.settings.TaskHubName, + this.settings.WorkerId, + partitionId, + $"Abandoning {messagesToAbandon.Count} pending message(s) during drain to make them immediately visible for the new partition owner."); + + await messagesToAbandon.ParallelForEachAsync( + this.settings.MaxStorageOperationConcurrency, + item => item.Queue.AbandonMessageForDrainAsync(item.Message)); } } @@ -402,7 +558,7 @@ bool IsScheduledAfterInstanceUpdate(MessageData msg, OrchestrationState? remoteI /// New messages to assign to orchestrators /// The "related" ActivityId of this operation. /// Cancellation token in case the orchestration is terminated. - internal void AddMessageToPendingOrchestration( + internal IReadOnlyList AddMessageToPendingOrchestration( ControlQueue controlQueue, IEnumerable queueMessages, Guid traceActivityId, @@ -414,6 +570,11 @@ internal void AddMessageToPendingOrchestration( // 3. Do we need to add messages to a currently executing orchestration? lock (this.messageAndSessionLock) { + if (controlQueue.IsReleased) + { + return queueMessages.ToList(); + } + var existingSessionMessages = new Dictionary>(); foreach (MessageData data in queueMessages) @@ -509,6 +670,8 @@ internal void AddMessageToPendingOrchestration( session.AddOrReplaceMessages(newMessages); } } + + return EmptyMessageDataList; } // This method runs on a background task thread @@ -517,6 +680,11 @@ async Task ScheduleOrchestrationStatePrefetch( Guid traceActivityId, CancellationToken cancellationToken) { + if (!this.IsPendingBatchActive(node)) + { + return; + } + PendingMessageBatch batch = node.Value; AnalyticsEventSource.SetLogicalTraceActivityId(traceActivityId); @@ -530,6 +698,11 @@ async Task ScheduleOrchestrationStatePrefetch( batch.OrchestrationExecutionId, cancellationToken); + if (!this.IsPendingBatchActive(node)) + { + return; + } + batch.OrchestrationState = new OrchestrationRuntimeState(history.Events); batch.ETags.HistoryETag = history.ETag; batch.LastCheckpointTime = history.LastCheckpointTime; @@ -541,20 +714,34 @@ async Task ScheduleOrchestrationStatePrefetch( InstanceStatus? instanceStatus = await this.trackingStore.FetchInstanceStatusAsync( batch.OrchestrationInstanceId, cancellationToken); + + if (!this.IsPendingBatchActive(node)) + { + return; + } + // The instance could not exist in the case that these messages are for the first execution of a suborchestration, // or an entity-started orchestration, for example batch.ETags.InstanceETag = instanceStatus?.ETag; } } - if (this.settings.UseSeparateQueueForEntityWorkItems - && DurableTask.Core.Common.Entities.IsEntityInstance(batch.OrchestrationInstanceId)) - { - this.entitiesReadyForProcessingQueue.Enqueue(node); - } - else + lock (this.messageAndSessionLock) { - this.orchestrationsReadyForProcessingQueue.Enqueue(node); + if (!this.IsPendingBatchActiveLocked(node)) + { + return; + } + + if (this.settings.UseSeparateQueueForEntityWorkItems + && DurableTask.Core.Common.Entities.IsEntityInstance(batch.OrchestrationInstanceId)) + { + this.entitiesReadyForProcessingQueue.Enqueue(node); + } + else + { + this.orchestrationsReadyForProcessingQueue.Enqueue(node); + } } } catch (OperationCanceledException) @@ -571,14 +758,37 @@ async Task ScheduleOrchestrationStatePrefetch( e.ToString()); // Sleep briefly to avoid a tight failure loop. - await Task.Delay(TimeSpan.FromSeconds(5)); + try + { + await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); + } + catch (OperationCanceledException) + { + return; + } // This is a background operation so failure is not an option. All exceptions must be handled. // To avoid starvation, we need to re-enqueue this async operation instead of retrying in a loop. - await Task.Run(() => this.ScheduleOrchestrationStatePrefetch(node, traceActivityId, cancellationToken)); + if (this.IsPendingBatchActive(node)) + { + await Task.Run(async () => await this.ScheduleOrchestrationStatePrefetch(node, traceActivityId, cancellationToken)); + } } } + bool IsPendingBatchActive(LinkedListNode node) + { + lock (this.messageAndSessionLock) + { + return this.IsPendingBatchActiveLocked(node); + } + } + + bool IsPendingBatchActiveLocked(LinkedListNode node) + { + return node.List == this.pendingOrchestrationMessageBatches && !node.Value.ControlQueue.IsReleased; + } + public async Task GetNextSessionAsync(bool entitiesOnly, CancellationToken cancellationToken) { var readyForProcessingQueue = entitiesOnly? this.entitiesReadyForProcessingQueue : this.orchestrationsReadyForProcessingQueue; @@ -589,78 +799,166 @@ async Task ScheduleOrchestrationStatePrefetch( // 1) a batch of messages has been received for a particular instance and // 2) the history for that instance has been fetched LinkedListNode node = await readyForProcessingQueue.DequeueAsync(cancellationToken); + ControlQueue? queueToAbandon = null; + IReadOnlyList messagesToAbandon = EmptyMessageDataList; + bool shouldStopWaitingForSessions = false; lock (this.messageAndSessionLock) { - PendingMessageBatch nextBatch = node.Value; - this.pendingOrchestrationMessageBatches.Remove(node); - - if (!this.activeOrchestrationSessions.TryGetValue(nextBatch.OrchestrationInstanceId, out var existingSession)) + // Drain may have removed this batch after it was queued for dispatch. + if (node.List != this.pendingOrchestrationMessageBatches) { - OrchestrationInstance instance = nextBatch.OrchestrationState?.OrchestrationInstance ?? - new OrchestrationInstance - { - InstanceId = nextBatch.OrchestrationInstanceId, - ExecutionId = nextBatch.OrchestrationExecutionId, - }; - - Guid traceActivityId = AzureStorageOrchestrationService.StartNewLogicalTraceScope(useExisting: true); - - OrchestrationSession session = new OrchestrationSession( - this.settings, - this.storageAccountName, - instance, - nextBatch.ControlQueue, - nextBatch.Messages, - nextBatch.OrchestrationState, - nextBatch.ETags, - nextBatch.LastCheckpointTime, - nextBatch.TrackingStoreContext, - this.settings.ExtendedSessionIdleTimeout, - this.shutdownToken, - traceActivityId); - - this.activeOrchestrationSessions.Add(instance.InstanceId, session); - - return session; - } - else if (nextBatch.OrchestrationExecutionId == existingSession.Instance?.ExecutionId) - { - // there is already an active session with the same execution id. - // The session might be waiting for more messages. If it is, signal them. - existingSession.AddOrReplaceMessages(node.Value.Messages); + shouldStopWaitingForSessions = this.ShouldStopWaitingForSessions(readyForProcessingQueue); } else { - // A message arrived for a different generation of an existing orchestration instance. - // Put it back into the ready queue so that it can be processed once the current generation - // is done executing. - if (readyForProcessingQueue.Count == 0) + PendingMessageBatch nextBatch = node.Value; + this.pendingOrchestrationMessageBatches.Remove(node); + + if (nextBatch.ControlQueue.IsReleased) { - // To avoid a tight dequeue loop, delay for a bit before putting this node back into the queue. - // This is only necessary when the queue is empty. The main dequeue thread must not be blocked - // by this delay, which is why we use Task.Delay(...).ContinueWith(...) instead of await. - Task.Delay(millisecondsDelay: 200).ContinueWith(_ => - { - lock (this.messageAndSessionLock) + queueToAbandon = nextBatch.ControlQueue; + messagesToAbandon = nextBatch.Messages.ToList(); + shouldStopWaitingForSessions = this.ShouldStopWaitingForSessions(readyForProcessingQueue); + } + else if (!this.activeOrchestrationSessions.TryGetValue(nextBatch.OrchestrationInstanceId, out var existingSession)) + { + OrchestrationInstance instance = nextBatch.OrchestrationState?.OrchestrationInstance ?? + new OrchestrationInstance { - this.pendingOrchestrationMessageBatches.AddLast(node); - readyForProcessingQueue.Enqueue(node); - } - }); + InstanceId = nextBatch.OrchestrationInstanceId, + ExecutionId = nextBatch.OrchestrationExecutionId, + }; + + Guid traceActivityId = AzureStorageOrchestrationService.StartNewLogicalTraceScope(useExisting: true); + + OrchestrationSession session = new OrchestrationSession( + this.settings, + this.storageAccountName, + instance, + nextBatch.ControlQueue, + nextBatch.Messages, + nextBatch.OrchestrationState, + nextBatch.ETags, + nextBatch.LastCheckpointTime, + nextBatch.TrackingStoreContext, + this.settings.ExtendedSessionIdleTimeout, + this.shutdownToken, + traceActivityId); + + this.activeOrchestrationSessions.Add(instance.InstanceId, session); + + return session; + } + else if (nextBatch.OrchestrationExecutionId == existingSession.Instance?.ExecutionId) + { + // there is already an active session with the same execution id. + // The session might be waiting for more messages. If it is, signal them. + existingSession.AddOrReplaceMessages(node.Value.Messages); } else { - this.pendingOrchestrationMessageBatches.AddLast(node); - readyForProcessingQueue.Enqueue(node); + // A message arrived for a different generation of an existing orchestration instance. + // Put it back into the ready queue so that it can be processed once the current generation + // is done executing. + if (readyForProcessingQueue.Count == 0) + { + // To avoid a tight dequeue loop, delay for a bit before putting this node back into the queue. + // This is only necessary when the queue is empty. The main dequeue thread must not be blocked + // by this delay, which is why it is scheduled without awaiting here. + _ = this.RequeuePendingBatchAfterDelayAsync(node, readyForProcessingQueue); + } + else + { + this.RequeueOrAbandonPendingBatchLocked(node, readyForProcessingQueue, out queueToAbandon, out messagesToAbandon); + } } } } + + if (queueToAbandon != null) + { + await this.AbandonMessagesForDrainAsync(queueToAbandon, messagesToAbandon); + } + + if (shouldStopWaitingForSessions) + { + return null; + } } return null; } + async Task RequeuePendingBatchAfterDelayAsync( + LinkedListNode node, + AsyncQueue> readyForProcessingQueue) + { + await Task.Delay(millisecondsDelay: 200); + + ControlQueue? queueToAbandon; + IReadOnlyList messagesToAbandon; + bool shouldStopWaitingForSessions; + + lock (this.messageAndSessionLock) + { + this.RequeueOrAbandonPendingBatchLocked(node, readyForProcessingQueue, out queueToAbandon, out messagesToAbandon); + shouldStopWaitingForSessions = queueToAbandon != null && this.ShouldStopWaitingForSessions(readyForProcessingQueue); + } + + if (queueToAbandon != null) + { + await this.AbandonMessagesForDrainAsync(queueToAbandon, messagesToAbandon); + if (shouldStopWaitingForSessions) + { + readyForProcessingQueue.Enqueue(node); + } + } + } + + void RequeueOrAbandonPendingBatchLocked( + LinkedListNode node, + AsyncQueue> readyForProcessingQueue, + out ControlQueue? queueToAbandon, + out IReadOnlyList messagesToAbandon) + { + if (node.Value.ControlQueue.IsReleased) + { + queueToAbandon = node.Value.ControlQueue; + messagesToAbandon = node.Value.Messages.ToList(); + return; + } + + this.pendingOrchestrationMessageBatches.AddLast(node); + readyForProcessingQueue.Enqueue(node); + queueToAbandon = null; + messagesToAbandon = EmptyMessageDataList; + } + + bool ShouldStopWaitingForSessions(AsyncQueue> readyForProcessingQueue) + { + return readyForProcessingQueue.Count == 0 && + !this.ownedControlQueues.Values.Any(queue => !queue.IsReleased); + } + + void SignalSessionWaitersIfNoQueuesRemain(ControlQueue releasedControlQueue) + { + lock (this.messageAndSessionLock) + { + if (this.ownedControlQueues.Values.Any(queue => !queue.IsReleased)) + { + return; + } + + var orchestratorSentinel = new LinkedListNode( + new PendingMessageBatch(releasedControlQueue, string.Empty, executionId: null)); + var entitySentinel = new LinkedListNode( + new PendingMessageBatch(releasedControlQueue, string.Empty, executionId: null)); + this.orchestrationsReadyForProcessingQueue.Enqueue(orchestratorSentinel); + this.entitiesReadyForProcessingQueue.Enqueue(entitySentinel); + } + } + /// /// Immediately removes all active sessions, causing /// to return false for all partitions. This unblocks so that @@ -682,8 +980,14 @@ public bool TryGetExistingSession(string instanceId, out OrchestrationSession se } } - public bool TryReleaseSession(string instanceId, CancellationToken cancellationToken, out OrchestrationSession session) + public bool TryReleaseSession( + string instanceId, + CancellationToken cancellationToken, + out OrchestrationSession session, + out IReadOnlyList messagesToAbandon) { + messagesToAbandon = EmptyMessageDataList; + // Taking this lock ensures we don't add new messages to a session we're about to release. lock (this.messageAndSessionLock) { @@ -693,7 +997,7 @@ public bool TryReleaseSession(string instanceId, CancellationToken cancellationT this.activeOrchestrationSessions.Remove(instanceId)) { // Put any unprocessed messages back into the pending buffer. - this.AddMessageToPendingOrchestration( + messagesToAbandon = this.AddMessageToPendingOrchestration( session.ControlQueue, session.PendingMessages.Concat(session.DeferredMessages), session.TraceActivityId, diff --git a/src/DurableTask.AzureStorage/Partitioning/TablePartitionManager.cs b/src/DurableTask.AzureStorage/Partitioning/TablePartitionManager.cs index 3fad8c98d..a55685ca4 100644 --- a/src/DurableTask.AzureStorage/Partitioning/TablePartitionManager.cs +++ b/src/DurableTask.AzureStorage/Partitioning/TablePartitionManager.cs @@ -362,7 +362,7 @@ public async Task ReadAndWriteTableAsync(bool isShuttingDown, // In a worker becomes unhealthy, it may lose a lease without realizing it and continue listening // for messages. We check for that case here and stop dequeuing messages if we discover that // another worker currently owns the lease. - this.service.DropLostControlQueue(partition); + await this.service.DropLostControlQueue(partition); bool claimedLease = false; bool stoleLease = false; diff --git a/test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs b/test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs new file mode 100644 index 000000000..bb6dafdb2 --- /dev/null +++ b/test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs @@ -0,0 +1,783 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.AzureStorage.Tests +{ + using System; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.Diagnostics; + using System.Linq; + using System.Reflection; + using System.Threading; + using System.Threading.Tasks; + using Azure; + using Azure.Storage.Queues; + using Azure.Storage.Queues.Models; + using DurableTask.AzureStorage.Messaging; + using DurableTask.AzureStorage.Monitoring; + using DurableTask.AzureStorage.Partitioning; + using DurableTask.AzureStorage.Storage; + using DurableTask.AzureStorage.Tracking; + using DurableTask.Core; + using DurableTask.Core.History; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Moq; + + /// + /// Tests for shutdown cancellation behavior with extended sessions. + /// + [TestClass] + public class OrchestrationSessionTests + { + /// + /// Verifies that + /// exits immediately when the cancellation token is cancelled. + /// + [TestMethod] + public async Task WaitAsync_CancellationToken_ExitsImmediately() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + + TimeSpan longTimeout = TimeSpan.FromSeconds(30); + Task waitTask = resetEvent.WaitAsync(longTimeout, cts.Token); + + Assert.IsFalse(waitTask.IsCompleted, "Wait should not complete immediately"); + + var stopwatch = Stopwatch.StartNew(); + cts.Cancel(); + + bool result = await waitTask; + stopwatch.Stop(); + + Assert.IsFalse(result, "Cancellation should return false (no signal received)"); + Assert.IsTrue( + stopwatch.ElapsedMilliseconds < 5000, + $"Cancellation should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms"); + } + + /// + /// Verifies that signaling still returns true when a cancellation token is provided. + /// + [TestMethod] + public async Task WaitAsync_WithCancellationToken_SignalStillWorks() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + + Task waitTask = resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token); + Assert.IsFalse(waitTask.IsCompleted); + + resetEvent.Set(); + + Task winner = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(5))); + Assert.IsTrue(winner == waitTask, "Signal should wake the waiter"); + Assert.IsTrue(waitTask.Result, "Wait result should be true when signaled"); + } + + /// + /// Verifies that the wait returns false on timeout when a cancellation token is provided but not cancelled. + /// + [TestMethod] + public async Task WaitAsync_WithCancellationToken_TimeoutStillWorks() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + + bool result = await resetEvent.WaitAsync(TimeSpan.FromMilliseconds(100), cts.Token); + + Assert.IsFalse(result, "Wait should return false on timeout"); + } + + /// + /// Verifies that all queued waiters return false when the token is cancelled. + /// + [TestMethod] + public async Task WaitAsync_CancellationToken_MultipleWaiters() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + + var waiters = new List>(); + for (int i = 0; i < 5; i++) + { + waiters.Add(resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token)); + } + + foreach (var waiter in waiters) + { + Assert.IsFalse(waiter.IsCompleted); + } + + var stopwatch = Stopwatch.StartNew(); + cts.Cancel(); + + // All waiters should return false (cancelled = not signaled) + await Task.WhenAll( + waiters.Select( + async waiter => + { + bool result = await waiter; + Assert.IsFalse(result, "Cancelled waiter should return false"); + })); + + stopwatch.Stop(); + + Assert.IsTrue( + stopwatch.ElapsedMilliseconds < 5000, + $"All waiters should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms"); + } + + /// + /// Verifies that a pre-cancelled token causes WaitAsync to return false immediately. + /// + [TestMethod] + public async Task WaitAsync_AlreadyCancelledToken_ReturnsFalseImmediately() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + cts.Cancel(); // Pre-cancel + + var stopwatch = Stopwatch.StartNew(); + bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token); + stopwatch.Stop(); + + Assert.IsFalse(result, "Pre-cancelled token should cause immediate false return"); + Assert.IsTrue( + stopwatch.ElapsedMilliseconds < 5000, + $"Should complete immediately, but took {stopwatch.ElapsedMilliseconds}ms"); + } + + /// + /// Verifies that a pre-cancelled token still returns true if the event is already signaled. + /// + [TestMethod] + public async Task WaitAsync_AlreadySignaledAndCancelled_ReturnsTrue() + { + var resetEvent = new AsyncAutoResetEvent(signaled: true); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token); + Assert.IsTrue(result, "Already signaled event should return true even with cancelled token"); + } + + /// + /// Verifies that clears all active sessions. + /// + [TestMethod] + public void AbortAllSessions_ClearsActiveSessions() + { + var settings = new AzureStorageOrchestrationServiceSettings(); + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + // Use reflection to access the internal sessions dictionary. + var sessionsField = typeof(OrchestrationSessionManager) + .GetField("activeOrchestrationSessions", BindingFlags.NonPublic | BindingFlags.Instance); + var sessions = (Dictionary)sessionsField.GetValue(manager); + + manager.GetStats(out _, out _, out int initialCount); + Assert.AreEqual(0, initialCount, "Should start with no active sessions"); + + sessions["instance1"] = null; + sessions["instance2"] = null; + sessions["instance3"] = null; + + manager.GetStats(out _, out _, out int activeCount); + Assert.AreEqual(3, activeCount, "Should have 3 active sessions"); + + manager.AbortAllSessions(); + + manager.GetStats(out _, out _, out int afterAbortCount); + Assert.AreEqual(0, afterAbortCount, "AbortAllSessions should clear all active sessions"); + } + + /// + /// Verifies that is safe to call with no active sessions. + /// + [TestMethod] + public void AbortAllSessions_NoSessions_DoesNotThrow() + { + var settings = new AzureStorageOrchestrationServiceSettings(); + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + manager.AbortAllSessions(); + + manager.GetStats(out _, out _, out int count); + Assert.AreEqual(0, count, "Should still have no active sessions"); + } + + [TestMethod] + public async Task GetNextSessionAsync_DrainedReadyQueueNode_ReturnsNullWhenNoQueuesRemain() + { + var settings = new AzureStorageOrchestrationServiceSettings + { + StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"), + }; + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + var storageClient = new AzureStorageClient(settings); + var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName); + var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager); + + object pendingBatch = CreatePendingBatch(controlQueue); + object node = AddPendingBatchNode(manager, pendingBatch); + RemovePendingBatchNode(manager, node); + EnqueueReadyForProcessingNode(manager, node); + + using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(100)); + OrchestrationSession session = await manager.GetNextSessionAsync(entitiesOnly: false, cts.Token); + + Assert.IsNull(session, "Detached ready nodes should not block dispatch when no queues remain."); + } + + [TestMethod] + public async Task ScheduleOrchestrationStatePrefetch_DetachedNode_DoesNotFetchHistory() + { + var settings = new AzureStorageOrchestrationServiceSettings + { + StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"), + }; + var stats = new AzureStorageOrchestrationServiceStats(); + int fetchCount = 0; + var trackingStore = new Mock(); + trackingStore + .Setup(t => t.GetHistoryEventsAsync("instance1", "execution1", It.IsAny())) + .Callback(() => fetchCount++) + .ThrowsAsync(new OperationCanceledException()); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + var storageClient = new AzureStorageClient(settings); + var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName); + var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager); + + object pendingBatch = CreatePendingBatch(controlQueue); + object node = AddPendingBatchNode(manager, pendingBatch); + RemovePendingBatchNode(manager, node); + + await InvokeScheduleOrchestrationStatePrefetch(manager, node, CancellationToken.None); + + Assert.AreEqual(0, fetchCount, "Detached pending batches should not fetch orchestration history."); + } + + [TestMethod] + public void AddMessageToPendingOrchestration_ReleasedControlQueue_ReturnsMessagesToAbandon() + { + var settings = new AzureStorageOrchestrationServiceSettings + { + StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"), + }; + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + var storageClient = new AzureStorageClient(settings); + var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName); + var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager); + controlQueue.Release(null, "test"); + + MessageData message = CreateMessageData(); + MethodInfo addMessage = typeof(OrchestrationSessionManager).GetMethod( + "AddMessageToPendingOrchestration", + BindingFlags.Instance | BindingFlags.NonPublic); + + var messagesToAbandon = (IReadOnlyList)addMessage.Invoke( + manager, + new object[] { controlQueue, new[] { message }, Guid.NewGuid(), CancellationToken.None }); + + Assert.IsNotNull(messagesToAbandon, "Released queue messages should be returned for immediate abandon."); + Assert.AreEqual(1, messagesToAbandon.Count); + Assert.AreSame(message, messagesToAbandon[0]); + + manager.GetStats(out int pendingOrchestratorInstances, out _, out _); + Assert.AreEqual(0, pendingOrchestratorInstances, "Released queue messages should not be added to pending batches."); + } + + [TestMethod] + public async Task RemoveQueue_PendingBatch_AbandonsMessages() + { + var settings = new AzureStorageOrchestrationServiceSettings + { + StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"), + }; + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + var storageClient = new AzureStorageClient(settings); + var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName); + var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager); + AddOwnedControlQueue(manager, "partition-0", controlQueue); + + MessageData message = CreateMessageData(); + int abandonCount = 0; + var queueClient = new Mock(); + queueClient.SetupGet(q => q.Name).Returns("partition-0"); + queueClient + .Setup( + q => q.UpdateMessageAsync( + message.OriginalQueueMessage.MessageId, + message.OriginalQueueMessage.PopReceipt, + It.IsAny(), + TimeSpan.Zero, + It.IsAny())) + .Callback(() => abandonCount++) + .ReturnsAsync(Response.FromValue( + QueuesModelFactory.UpdateReceipt("newPopReceipt", DateTimeOffset.UtcNow), + Mock.Of())); + SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object); + + object pendingBatch = CreatePendingBatch(controlQueue); + AddMessageToPendingBatch(pendingBatch, message); + AddPendingBatchNode(manager, pendingBatch); + + await InvokeRemoveQueue(manager, "partition-0"); + + Assert.AreEqual(0, GetPendingBatchCount(manager), "RemoveQueue should remove pending batches for the released queue."); + Assert.AreEqual(1, abandonCount, "RemoveQueue should immediately abandon pending messages for the released queue."); + } + + [TestMethod] + public async Task RemoveQueue_PendingBatch_ReturnsNullToWaitingDispatcher() + { + var settings = new AzureStorageOrchestrationServiceSettings + { + StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"), + }; + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + var storageClient = new AzureStorageClient(settings); + var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName); + var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager); + AddOwnedControlQueue(manager, "partition-0", controlQueue); + + MessageData message = CreateMessageData(); + var queueClient = new Mock(); + queueClient.SetupGet(q => q.Name).Returns("partition-0"); + queueClient + .Setup( + q => q.UpdateMessageAsync( + message.OriginalQueueMessage.MessageId, + message.OriginalQueueMessage.PopReceipt, + It.IsAny(), + TimeSpan.Zero, + It.IsAny())) + .ReturnsAsync(Response.FromValue( + QueuesModelFactory.UpdateReceipt("newPopReceipt", DateTimeOffset.UtcNow), + Mock.Of())); + SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object); + + object pendingBatch = CreatePendingBatch(controlQueue); + AddMessageToPendingBatch(pendingBatch, message); + AddPendingBatchNode(manager, pendingBatch); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1)); + Task getNextTask = manager.GetNextSessionAsync(entitiesOnly: false, cts.Token); + + await InvokeRemoveQueue(manager, "partition-0"); + OrchestrationSession session = await getNextTask; + + Assert.IsNull(session, "Removing the last queue should unblock dispatchers waiting for a session."); + } + + [TestMethod] + public async Task WaitForDequeueLoopToStopAsync_FaultedDequeueLoop_PropagatesUnexpectedException() + { + var settings = new AzureStorageOrchestrationServiceSettings(); + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + var dequeueLoopTasks = (ConcurrentDictionary)GetPrivateField(manager, "dequeueLoopTasks"); + var expected = new InvalidOperationException("unexpected dequeue loop failure"); + dequeueLoopTasks["partition-0"] = Task.FromException(expected); + + MethodInfo wait = typeof(OrchestrationSessionManager).GetMethod( + "WaitForDequeueLoopToStopAsync", + BindingFlags.NonPublic | BindingFlags.Instance); + + Task waitTask = (Task)wait.Invoke(manager, new object[] { "partition-0", CancellationToken.None }); + + InvalidOperationException actual = await Assert.ThrowsExceptionAsync( + () => waitTask); + + Assert.AreSame(expected, actual); + } + + [TestMethod] + public async Task AbandonMessageForDrainAsync_DurableTaskStorageException_DoesNotThrow() + { + var settings = new AzureStorageOrchestrationServiceSettings + { + StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"), + }; + var storageClient = new AzureStorageClient(settings); + var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName); + var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager); + + var queueClient = new Mock(); + queueClient.SetupGet(q => q.Name).Returns("partition-0"); + queueClient + .Setup( + q => q.UpdateMessageAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new RequestFailedException(404, "queue update failed")); + SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object); + + await controlQueue.AbandonMessageForDrainAsync(CreateMessageData()); + } + + [TestMethod] + public async Task GetNextSessionAsync_ReleasedDelayedRequeueNode_AbandonsMessagesAndReturnsNullWhenNoQueuesRemain() + { + var settings = new AzureStorageOrchestrationServiceSettings + { + StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"), + }; + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + var storageClient = new AzureStorageClient(settings); + var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName); + var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager); + MessageData message = CreateMessageData(); + int abandonCount = 0; + var queueClient = new Mock(); + queueClient.SetupGet(q => q.Name).Returns("partition-0"); + queueClient + .Setup( + q => q.UpdateMessageAsync( + message.OriginalQueueMessage.MessageId, + message.OriginalQueueMessage.PopReceipt, + It.IsAny(), + TimeSpan.Zero, + It.IsAny())) + .Callback(() => abandonCount++) + .ReturnsAsync(Response.FromValue( + QueuesModelFactory.UpdateReceipt("newPopReceipt", DateTimeOffset.UtcNow), + Mock.Of())); + SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object); + + AddActiveSession(manager, settings, controlQueue, "instance1", "activeExecution"); + object pendingBatch = CreatePendingBatch(controlQueue); + AddMessageToPendingBatch(pendingBatch, message); + object node = AddPendingBatchNode(manager, pendingBatch); + EnqueueReadyForProcessingNode(manager, node); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2)); + Task getNextTask = manager.GetNextSessionAsync(entitiesOnly: false, cts.Token); + + await WaitUntilAsync(() => IsNodeDetached(node), TimeSpan.FromSeconds(2)); + controlQueue.Release(null, "test"); + OrchestrationSession session = await getNextTask; + + Assert.IsNull(session, "Released delayed requeue nodes should not block dispatch when no queues remain."); + await WaitUntilAsync(() => abandonCount == 1, TimeSpan.FromSeconds(2)); + + Assert.AreEqual(0, GetPendingBatchCount(manager), "Released queue nodes should not be requeued after a delay."); + Assert.AreEqual(0, GetReadyQueueCount(manager), "Released queue nodes should not be made ready for dispatch."); + Assert.AreEqual(1, abandonCount, "Messages from a released delayed requeue node should be immediately abandoned."); + } + + [TestMethod] + public async Task GetNextSessionAsync_ReleasedReadyQueueNode_AbandonsMessagesAndReturnsNullWhenNoQueuesRemain() + { + var settings = new AzureStorageOrchestrationServiceSettings + { + StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"), + }; + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + var storageClient = new AzureStorageClient(settings); + var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName); + var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager); + MessageData message = CreateMessageData(); + int abandonCount = 0; + var queueClient = new Mock(); + queueClient.SetupGet(q => q.Name).Returns("partition-0"); + queueClient + .Setup( + q => q.UpdateMessageAsync( + message.OriginalQueueMessage.MessageId, + message.OriginalQueueMessage.PopReceipt, + It.IsAny(), + TimeSpan.Zero, + It.IsAny())) + .Callback(() => abandonCount++) + .ReturnsAsync(Response.FromValue( + QueuesModelFactory.UpdateReceipt("newPopReceipt", DateTimeOffset.UtcNow), + Mock.Of())); + SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object); + + object pendingBatch = CreatePendingBatch(controlQueue); + AddMessageToPendingBatch(pendingBatch, message); + object node = AddPendingBatchNode(manager, pendingBatch); + EnqueueReadyForProcessingNode(manager, node); + controlQueue.Release(null, "test"); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1)); + OrchestrationSession session = await manager.GetNextSessionAsync(entitiesOnly: false, cts.Token); + + Assert.IsNull(session, "Released queue nodes should not block dispatch when no queues remain."); + Assert.AreEqual(0, GetPendingBatchCount(manager), "Released ready nodes should be removed from pending batches."); + Assert.AreEqual(1, abandonCount, "Messages from released ready nodes should be immediately abandoned."); + } + + static object CreatePendingBatch(ControlQueue controlQueue) + { + Type pendingBatchType = typeof(OrchestrationSessionManager) + .GetNestedType("PendingMessageBatch", BindingFlags.NonPublic); + + return Activator.CreateInstance( + pendingBatchType, + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, + binder: null, + args: new object[] { controlQueue, "instance1", "execution1" }, + culture: null); + } + + static object AddPendingBatchNode(OrchestrationSessionManager manager, object pendingBatch) + { + object pendingBatches = GetPrivateField(manager, "pendingOrchestrationMessageBatches"); + MethodInfo addLast = pendingBatches.GetType().GetMethod("AddLast", new[] { pendingBatch.GetType() }); + return addLast.Invoke(pendingBatches, new[] { pendingBatch }); + } + + static void RemovePendingBatchNode(OrchestrationSessionManager manager, object node) + { + object pendingBatches = GetPrivateField(manager, "pendingOrchestrationMessageBatches"); + MethodInfo remove = pendingBatches.GetType().GetMethod("Remove", new[] { node.GetType() }); + remove.Invoke(pendingBatches, new[] { node }); + } + + static void EnqueueReadyForProcessingNode(OrchestrationSessionManager manager, object node) + { + object readyQueue = GetPrivateField(manager, "orchestrationsReadyForProcessingQueue"); + MethodInfo enqueue = readyQueue.GetType().GetMethod("Enqueue"); + enqueue.Invoke(readyQueue, new[] { node }); + } + + static void AddOwnedControlQueue(OrchestrationSessionManager manager, string partitionId, ControlQueue controlQueue) + { + var ownedQueues = (ConcurrentDictionary)GetPrivateField(manager, "ownedControlQueues"); + ownedQueues[partitionId] = controlQueue; + } + + static async Task InvokeRemoveQueue(OrchestrationSessionManager manager, string partitionId) + { + MethodInfo removeQueue = typeof(OrchestrationSessionManager).GetMethod("RemoveQueue"); + object result = removeQueue.Invoke(manager, new object[] { partitionId, CloseReason.LeaseLost, "test" }); + if (result is Task task) + { + await task; + } + } + + static void AddMessageToPendingBatch(object pendingBatch, MessageData message) + { + var messages = (ICollection)pendingBatch.GetType().GetProperty("Messages").GetValue(pendingBatch); + messages.Add(message); + } + + static Task InvokeScheduleOrchestrationStatePrefetch( + OrchestrationSessionManager manager, + object node, + CancellationToken cancellationToken) + { + MethodInfo schedule = typeof(OrchestrationSessionManager).GetMethod( + "ScheduleOrchestrationStatePrefetch", + BindingFlags.NonPublic | BindingFlags.Instance); + + return (Task)schedule.Invoke(manager, new[] { node, Guid.NewGuid(), cancellationToken }); + } + + static void AddActiveSession( + OrchestrationSessionManager manager, + AzureStorageOrchestrationServiceSettings settings, + ControlQueue controlQueue, + string instanceId, + string executionId) + { + var sessions = (Dictionary)GetPrivateField(manager, "activeOrchestrationSessions"); + var instance = new OrchestrationInstance + { + InstanceId = instanceId, + ExecutionId = executionId, + }; + var runtimeState = new OrchestrationRuntimeState(); + runtimeState.AddEvent(new ExecutionStartedEvent(-1, string.Empty) + { + OrchestrationInstance = instance, + }); + + sessions[instanceId] = new OrchestrationSession( + settings, + "testaccount", + instance, + controlQueue, + new List(), + runtimeState, + eTags: null, + DateTime.UtcNow, + trackingStoreContext: null, + TimeSpan.FromSeconds(30), + CancellationToken.None, + Guid.NewGuid()); + } + + static bool IsNodeDetached(object node) + { + object list = node.GetType().GetProperty("List").GetValue(node); + return list == null; + } + + static int GetPendingBatchCount(OrchestrationSessionManager manager) + { + object pendingBatches = GetPrivateField(manager, "pendingOrchestrationMessageBatches"); + return (int)pendingBatches.GetType().GetProperty("Count").GetValue(pendingBatches); + } + + static int GetReadyQueueCount(OrchestrationSessionManager manager) + { + object readyQueue = GetPrivateField(manager, "orchestrationsReadyForProcessingQueue"); + return (int)readyQueue.GetType().GetProperty("Count").GetValue(readyQueue); + } + + static async Task WaitUntilAsync(Func condition, TimeSpan timeout) + { + var stopwatch = Stopwatch.StartNew(); + while (!condition()) + { + if (stopwatch.Elapsed > timeout) + { + Assert.Fail("Condition was not reached before timeout."); + } + + await Task.Delay(TimeSpan.FromMilliseconds(10)); + } + } + + static MessageData CreateMessageData() + { + var instance = new OrchestrationInstance + { + InstanceId = "instance1", + ExecutionId = "execution1", + }; + + var taskMessage = new TaskMessage + { + OrchestrationInstance = instance, + Event = new TimerFiredEvent(0), + }; + + var message = new MessageData( + taskMessage, + Guid.NewGuid(), + "partition-0", + orchestrationEpisode: null, + sender: instance); + + message.OriginalQueueMessage = QueuesModelFactory.QueueMessage( + Guid.NewGuid().ToString("N"), + Guid.NewGuid().ToString("N"), + string.Empty, + 1, + DateTimeOffset.UtcNow, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddMinutes(5)); + + return message; + } + + static object GetPrivateField(object target, string fieldName) + { + FieldInfo field = target.GetType().GetField(fieldName, BindingFlags.NonPublic | BindingFlags.Instance); + Assert.IsNotNull(field); + return field.GetValue(target); + } + + static void SetPrivateField(object target, string fieldName, object value) + { + FieldInfo field = target.GetType().GetField(fieldName, BindingFlags.NonPublic | BindingFlags.Instance); + Assert.IsNotNull(field); + field.SetValue(target, value); + } + } +}