diff --git a/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs b/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs
deleted file mode 100644
index 126c4b9bc..000000000
--- a/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs
+++ /dev/null
@@ -1,227 +0,0 @@
-// ----------------------------------------------------------------------------------
-// Copyright Microsoft Corporation
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-// http://www.apache.org/licenses/LICENSE-2.0
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// ----------------------------------------------------------------------------------
-
-namespace DurableTask.AzureStorage.Tests
-{
- using System;
- using System.Collections.Generic;
- using System.Diagnostics;
- using System.Linq;
- using System.Reflection;
- using System.Threading;
- using System.Threading.Tasks;
- using DurableTask.AzureStorage.Messaging;
- using DurableTask.AzureStorage.Monitoring;
- using DurableTask.AzureStorage.Tracking;
- using Microsoft.VisualStudio.TestTools.UnitTesting;
- using Moq;
-
- ///
- /// Tests for shutdown cancellation behavior with extended sessions.
- ///
- [TestClass]
- public class OrchestrationSessionTests
- {
- ///
- /// Verifies that
- /// exits immediately when the cancellation token is cancelled.
- ///
- [TestMethod]
- public async Task WaitAsync_CancellationToken_ExitsImmediately()
- {
- var resetEvent = new AsyncAutoResetEvent(signaled: false);
- using var cts = new CancellationTokenSource();
-
- TimeSpan longTimeout = TimeSpan.FromSeconds(30);
- Task waitTask = resetEvent.WaitAsync(longTimeout, cts.Token);
-
- Assert.IsFalse(waitTask.IsCompleted, "Wait should not complete immediately");
-
- var stopwatch = Stopwatch.StartNew();
- cts.Cancel();
-
- bool result = await waitTask;
- stopwatch.Stop();
-
- Assert.IsFalse(result, "Cancellation should return false (no signal received)");
- Assert.IsTrue(
- stopwatch.ElapsedMilliseconds < 5000,
- $"Cancellation should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms");
- }
-
- ///
- /// Verifies that signaling still returns true when a cancellation token is provided.
- ///
- [TestMethod]
- public async Task WaitAsync_WithCancellationToken_SignalStillWorks()
- {
- var resetEvent = new AsyncAutoResetEvent(signaled: false);
- using var cts = new CancellationTokenSource();
-
- Task waitTask = resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
- Assert.IsFalse(waitTask.IsCompleted);
-
- resetEvent.Set();
-
- Task winner = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(5)));
- Assert.IsTrue(winner == waitTask, "Signal should wake the waiter");
- Assert.IsTrue(waitTask.Result, "Wait result should be true when signaled");
- }
-
- ///
- /// Verifies that the wait returns false on timeout when a cancellation token is provided but not cancelled.
- ///
- [TestMethod]
- public async Task WaitAsync_WithCancellationToken_TimeoutStillWorks()
- {
- var resetEvent = new AsyncAutoResetEvent(signaled: false);
- using var cts = new CancellationTokenSource();
-
- bool result = await resetEvent.WaitAsync(TimeSpan.FromMilliseconds(100), cts.Token);
-
- Assert.IsFalse(result, "Wait should return false on timeout");
- }
-
- ///
- /// Verifies that all queued waiters return false when the token is cancelled.
- ///
- [TestMethod]
- public async Task WaitAsync_CancellationToken_MultipleWaiters()
- {
- var resetEvent = new AsyncAutoResetEvent(signaled: false);
- using var cts = new CancellationTokenSource();
-
- var waiters = new List>();
- for (int i = 0; i < 5; i++)
- {
- waiters.Add(resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token));
- }
-
- foreach (var waiter in waiters)
- {
- Assert.IsFalse(waiter.IsCompleted);
- }
-
- var stopwatch = Stopwatch.StartNew();
- cts.Cancel();
-
- // All waiters should return false (cancelled = not signaled)
- await Task.WhenAll(
- waiters.Select(
- async waiter =>
- {
- bool result = await waiter;
- Assert.IsFalse(result, "Cancelled waiter should return false");
- }));
-
- stopwatch.Stop();
-
- Assert.IsTrue(
- stopwatch.ElapsedMilliseconds < 5000,
- $"All waiters should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms");
- }
-
- ///
- /// Verifies that a pre-cancelled token causes WaitAsync to return false immediately.
- ///
- [TestMethod]
- public async Task WaitAsync_AlreadyCancelledToken_ReturnsFalseImmediately()
- {
- var resetEvent = new AsyncAutoResetEvent(signaled: false);
- using var cts = new CancellationTokenSource();
- cts.Cancel(); // Pre-cancel
-
- var stopwatch = Stopwatch.StartNew();
- bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
- stopwatch.Stop();
-
- Assert.IsFalse(result, "Pre-cancelled token should cause immediate false return");
- Assert.IsTrue(
- stopwatch.ElapsedMilliseconds < 5000,
- $"Should complete immediately, but took {stopwatch.ElapsedMilliseconds}ms");
- }
-
- ///
- /// Verifies that a pre-cancelled token still returns true if the event is already signaled.
- ///
- [TestMethod]
- public async Task WaitAsync_AlreadySignaledAndCancelled_ReturnsTrue()
- {
- var resetEvent = new AsyncAutoResetEvent(signaled: true);
- using var cts = new CancellationTokenSource();
- cts.Cancel();
-
- bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
- Assert.IsTrue(result, "Already signaled event should return true even with cancelled token");
- }
-
- ///
- /// Verifies that clears all active sessions.
- ///
- [TestMethod]
- public void AbortAllSessions_ClearsActiveSessions()
- {
- var settings = new AzureStorageOrchestrationServiceSettings();
- var stats = new AzureStorageOrchestrationServiceStats();
- var trackingStore = new Mock();
-
- using var manager = new OrchestrationSessionManager(
- "testaccount",
- settings,
- stats,
- trackingStore.Object);
-
- // Use reflection to access the internal sessions dictionary.
- var sessionsField = typeof(OrchestrationSessionManager)
- .GetField("activeOrchestrationSessions", BindingFlags.NonPublic | BindingFlags.Instance);
- var sessions = (Dictionary)sessionsField.GetValue(manager);
-
- manager.GetStats(out _, out _, out int initialCount);
- Assert.AreEqual(0, initialCount, "Should start with no active sessions");
-
- sessions["instance1"] = null;
- sessions["instance2"] = null;
- sessions["instance3"] = null;
-
- manager.GetStats(out _, out _, out int activeCount);
- Assert.AreEqual(3, activeCount, "Should have 3 active sessions");
-
- manager.AbortAllSessions();
-
- manager.GetStats(out _, out _, out int afterAbortCount);
- Assert.AreEqual(0, afterAbortCount, "AbortAllSessions should clear all active sessions");
- }
-
- ///
- /// Verifies that is safe to call with no active sessions.
- ///
- [TestMethod]
- public void AbortAllSessions_NoSessions_DoesNotThrow()
- {
- var settings = new AzureStorageOrchestrationServiceSettings();
- var stats = new AzureStorageOrchestrationServiceStats();
- var trackingStore = new Mock();
-
- using var manager = new OrchestrationSessionManager(
- "testaccount",
- settings,
- stats,
- trackingStore.Object);
-
- manager.AbortAllSessions();
-
- manager.GetStats(out _, out _, out int count);
- Assert.AreEqual(0, count, "Should still have no active sessions");
- }
- }
-}
diff --git a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs
index 74798cb45..cfe64a7d4 100644
--- a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs
+++ b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs
@@ -553,7 +553,7 @@ internal async Task OnIntentLeaseAquiredAsync(BlobPartitionLease lease)
{
var controlQueue = new ControlQueue(this.azureStorageClient, lease.PartitionId, this.messageManager);
await controlQueue.CreateIfNotExistsAsync();
- this.orchestrationSessionManager.ResumeListeningIfOwnQueue(lease.PartitionId, controlQueue, this.shutdownSource.Token);
+ await this.orchestrationSessionManager.ResumeListeningIfOwnQueue(lease.PartitionId, controlQueue, this.shutdownSource.Token);
}
internal Task OnIntentLeaseReleasedAsync(BlobPartitionLease lease, CloseReason reason)
@@ -572,21 +572,20 @@ internal async Task OnOwnershipLeaseAquiredAsync(BlobPartitionLease lease)
this.allControlQueues[lease.PartitionId] = controlQueue;
}
- internal void DropLostControlQueue(TablePartitionLease partition)
+ internal async Task DropLostControlQueue(TablePartitionLease partition)
{
// If lease is lost but we're still dequeuing messages, remove the queue
if (this.allControlQueues.TryGetValue(partition.RowKey, out ControlQueue controlQueue) &&
this.OwnedControlQueues.Contains(controlQueue) &&
partition.CurrentOwner != this.settings.WorkerId)
{
- this.orchestrationSessionManager.RemoveQueue(partition.RowKey, CloseReason.LeaseLost, nameof(DropLostControlQueue));
+ await this.orchestrationSessionManager.RemoveQueue(partition.RowKey, CloseReason.LeaseLost, nameof(DropLostControlQueue));
}
}
internal Task OnOwnershipLeaseReleasedAsync(BlobPartitionLease lease, CloseReason reason)
{
- this.orchestrationSessionManager.RemoveQueue(lease.PartitionId, reason, "Ownership LeaseCollectionBalancer");
- return Utils.CompletedTask;
+ return this.orchestrationSessionManager.RemoveQueue(lease.PartitionId, reason, "Ownership LeaseCollectionBalancer");
}
internal async Task OnTableLeaseAcquiredAsync(TablePartitionLease lease)
@@ -1249,7 +1248,13 @@ await this.CommitOutboundQueueMessages(
{
var messages = session.DeferredMessages.ToList();
session.DeferredMessages.Clear();
- this.orchestrationSessionManager.AddMessageToPendingOrchestration(session.ControlQueue, messages, session.TraceActivityId, CancellationToken.None);
+ IReadOnlyList messagesToAbandon = this.orchestrationSessionManager.AddMessageToPendingOrchestration(
+ session.ControlQueue,
+ messages,
+ session.TraceActivityId,
+ CancellationToken.None);
+
+ await this.orchestrationSessionManager.AbandonMessagesForDrainAsync(session.ControlQueue, messagesToAbandon);
}
}
// Handle the case where the 'ETag' has changed, which implies another worker has taken over this work item while
@@ -1524,8 +1529,11 @@ async Task ReleaseSessionAsync(string instanceId)
if (this.orchestrationSessionManager.TryReleaseSession(
instanceId,
this.shutdownSource.Token,
- out OrchestrationSession session))
+ out OrchestrationSession session,
+ out IReadOnlyList messagesToAbandon))
{
+ await this.orchestrationSessionManager.AbandonMessagesForDrainAsync(session.ControlQueue, messagesToAbandon);
+
// Some messages may need to be discarded
await session.DiscardedMessages.ParallelForEachAsync(
this.settings.MaxStorageOperationConcurrency,
diff --git a/src/DurableTask.AzureStorage/Messaging/ControlQueue.cs b/src/DurableTask.AzureStorage/Messaging/ControlQueue.cs
index 9f1c0d2ba..2d27fa3ce 100644
--- a/src/DurableTask.AzureStorage/Messaging/ControlQueue.cs
+++ b/src/DurableTask.AzureStorage/Messaging/ControlQueue.cs
@@ -23,6 +23,7 @@ namespace DurableTask.AzureStorage.Messaging
using DurableTask.AzureStorage.Monitoring;
using DurableTask.AzureStorage.Partitioning;
using DurableTask.AzureStorage.Storage;
+ using DurableTask.Core;
class ControlQueue : TaskHubQueue, IDisposable
{
@@ -209,6 +210,50 @@ public override Task AbandonMessageAsync(MessageData message, SessionBase? sessi
return base.AbandonMessageAsync(message, session);
}
+ ///
+ /// Abandons a message with zero visibility timeout so it becomes immediately visible
+ /// for another partition owner to pick up. This is used during drain to avoid stranding
+ /// messages that were dequeued but not yet promoted to active sessions.
+ ///
+ public async Task AbandonMessageForDrainAsync(MessageData message)
+ {
+ this.stats.PendingOrchestratorMessages.TryRemove(message.OriginalQueueMessage.MessageId, out _);
+
+ QueueMessage queueMessage = message.OriginalQueueMessage;
+ TaskMessage taskMessage = message.TaskMessage;
+ OrchestrationInstance instance = taskMessage.OrchestrationInstance;
+
+ this.settings.Logger.AbandoningMessage(
+ this.storageAccountName,
+ this.settings.TaskHubName,
+ taskMessage.Event.EventType.ToString(),
+ Utils.GetTaskEventId(taskMessage.Event),
+ queueMessage.MessageId,
+ instance.InstanceId,
+ instance.ExecutionId,
+ this.storageQueue.Name,
+ message.SequenceNumber,
+ queueMessage.PopReceipt,
+ visibilityTimeoutSeconds: 0);
+
+ try
+ {
+ await this.storageQueue.UpdateMessageAsync(
+ queueMessage,
+ TimeSpan.Zero,
+ clientRequestId: null);
+ }
+ catch (DurableTaskStorageException e)
+ {
+ this.settings.Logger.PartitionManagerWarning(
+ this.storageAccountName,
+ this.settings.TaskHubName,
+ this.settings.WorkerId,
+ this.Name,
+ $"Failed to abandon message {queueMessage.MessageId} during drain: {e.Message}");
+ }
+ }
+
public override Task DeleteMessageAsync(MessageData message, SessionBase? session = null)
{
this.stats.PendingOrchestratorMessages.TryRemove(message.OriginalQueueMessage.MessageId, out _);
diff --git a/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs b/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs
index abf7a58b2..a9d27aeb4 100644
--- a/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs
+++ b/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs
@@ -30,8 +30,11 @@ namespace DurableTask.AzureStorage
class OrchestrationSessionManager : IDisposable
{
+ static readonly IReadOnlyList EmptyMessageDataList = Array.Empty();
+
readonly Dictionary activeOrchestrationSessions = new Dictionary(StringComparer.OrdinalIgnoreCase);
readonly ConcurrentDictionary ownedControlQueues = new ConcurrentDictionary();
+ readonly ConcurrentDictionary dequeueLoopTasks = new ConcurrentDictionary();
readonly LinkedList pendingOrchestrationMessageBatches = new LinkedList();
readonly AsyncQueue> orchestrationsReadyForProcessingQueue = new AsyncQueue>();
readonly AsyncQueue> entitiesReadyForProcessingQueue = new AsyncQueue>();
@@ -67,7 +70,8 @@ public void AddQueue(string partitionId, ControlQueue controlQueue, Cancellation
if (this.ownedControlQueues.TryAdd(partitionId, controlQueue))
{
- _ = Task.Run(() => this.DequeueLoop(partitionId, controlQueue, cancellationToken));
+ Task dequeueLoopTask = Task.Run(async () => await this.DequeueLoop(partitionId, controlQueue, cancellationToken));
+ this.dequeueLoopTasks[partitionId] = dequeueLoopTask;
}
else
{
@@ -80,11 +84,13 @@ public void AddQueue(string partitionId, ControlQueue controlQueue, Cancellation
}
}
- public void RemoveQueue(string partitionId, CloseReason? reason, string caller)
+ public async Task RemoveQueue(string partitionId, CloseReason? reason, string caller)
{
if (this.ownedControlQueues.TryRemove(partitionId, out ControlQueue controlQueue))
{
controlQueue.Release(reason, caller);
+ this.dequeueLoopTasks.TryRemove(partitionId, out _);
+ await this.AbandonPendingMessagesAsync(partitionId, controlQueue);
}
}
@@ -97,15 +103,15 @@ public void ReleaseQueue(string partitionId, CloseReason? reason, string caller)
}
}
- public bool ResumeListeningIfOwnQueue(string partitionId, ControlQueue controlQueue, CancellationToken shutdownToken)
+ public async Task ResumeListeningIfOwnQueue(string partitionId, ControlQueue controlQueue, CancellationToken cancellationToken)
{
if (this.ownedControlQueues.TryGetValue(partitionId, out ControlQueue ownedControlQueue))
{
if (ownedControlQueue.IsReleased)
{
// The easiest way to resume listening is to re-add a new queue that has not been released.
- this.RemoveQueue(partitionId, null, "OrchestrationSessionManager ResumeListeningIfOwnQueue");
- this.AddQueue(partitionId, controlQueue, shutdownToken);
+ await this.RemoveQueue(partitionId, null, "OrchestrationSessionManager ResumeListeningIfOwnQueue");
+ this.AddQueue(partitionId, controlQueue, cancellationToken);
}
}
@@ -154,7 +160,13 @@ async Task DequeueLoop(string partitionId, ControlQueue controlQueue, Cancellati
traceActivityId,
cancellationToken);
- this.AddMessageToPendingOrchestration(controlQueue, filteredMessages, traceActivityId, cancellationToken);
+ IReadOnlyList messagesToAbandon = this.AddMessageToPendingOrchestration(
+ controlQueue,
+ filteredMessages,
+ traceActivityId,
+ cancellationToken);
+
+ await this.AbandonMessagesForDrainAsync(controlQueue, messagesToAbandon);
}
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
@@ -198,6 +210,8 @@ public async Task DrainAsync(string partitionId, CloseReason reason, Cancellatio
this.ReleaseQueue(partitionId, reason, caller);
try
{
+ await this.WaitForDequeueLoopToStopAsync(partitionId, cancellationToken);
+
// Wait until all messages from this queue have been processed.
while (!cancellationToken.IsCancellationRequested && this.IsControlQueueProcessingMessages(partitionId))
{
@@ -216,8 +230,150 @@ public async Task DrainAsync(string partitionId, CloseReason reason, Cancellatio
}
finally
{
- // Remove the partition from memory
- this.RemoveQueue(partitionId, reason, caller);
+ try
+ {
+ // Make dequeued-but-undispatched messages visible before dropping the partition.
+ await this.AbandonPendingMessagesAsync(partitionId);
+ }
+ finally
+ {
+ await this.RemoveQueue(partitionId, reason, caller);
+ }
+ }
+ }
+
+ async Task WaitForDequeueLoopToStopAsync(string partitionId, CancellationToken cancellationToken)
+ {
+ if (!this.dequeueLoopTasks.TryGetValue(partitionId, out Task dequeueLoopTask))
+ {
+ return;
+ }
+
+ try
+ {
+ bool completed = await WaitForTaskAsync(dequeueLoopTask, cancellationToken);
+ if (!completed)
+ {
+ this.settings.Logger.PartitionManagerWarning(
+ this.storageAccountName,
+ this.settings.TaskHubName,
+ this.settings.WorkerId,
+ partitionId,
+ $"Timed-out waiting for the dequeue loop to stop during drain.");
+ }
+ }
+ catch (OperationCanceledException e)
+ {
+ this.settings.Logger.PartitionManagerWarning(
+ this.storageAccountName,
+ this.settings.TaskHubName,
+ this.settings.WorkerId,
+ partitionId,
+ $"Canceled while waiting for the dequeue loop to stop during drain. Exception: {e}");
+ }
+ catch (AggregateException e) when (e.InnerExceptions.All(ex => ex is OperationCanceledException))
+ {
+ this.settings.Logger.PartitionManagerWarning(
+ this.storageAccountName,
+ this.settings.TaskHubName,
+ this.settings.WorkerId,
+ partitionId,
+ $"Canceled while waiting for the dequeue loop to stop during drain. Exception: {e}");
+ }
+ }
+
+ static async Task WaitForTaskAsync(Task task, CancellationToken cancellationToken)
+ {
+ if (task.IsCompleted)
+ {
+ await task;
+ return true;
+ }
+
+ var cancellationCompletion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ using (cancellationToken.Register(state => ((TaskCompletionSource)state).TrySetResult(true), cancellationCompletion))
+ {
+ Task completedTask = await Task.WhenAny(task, cancellationCompletion.Task);
+ if (completedTask != task)
+ {
+ return false;
+ }
+ }
+
+ await task;
+ return true;
+ }
+
+ ///
+ /// Abandons all pending (dequeued but not yet dispatched) messages for the specified partition,
+ /// making them immediately visible in the Azure Storage queue for the new partition owner.
+ /// This prevents a throughput gap equal to the visibility timeout duration when a partition
+ /// is released during drain.
+ ///
+ async Task AbandonPendingMessagesAsync(string partitionId, ControlQueue? removedControlQueue = null)
+ {
+ var messagesToAbandon = new List<(ControlQueue Queue, MessageData Message)>();
+
+ lock (this.messageAndSessionLock)
+ {
+ var node = this.pendingOrchestrationMessageBatches.First;
+ while (node != null)
+ {
+ LinkedListNode? next = node.Next;
+ PendingMessageBatch batch = node.Value;
+
+ if (string.Equals(batch.ControlQueue.Name, partitionId, StringComparison.OrdinalIgnoreCase))
+ {
+ foreach (MessageData message in batch.Messages)
+ {
+ messagesToAbandon.Add((batch.ControlQueue, message));
+ }
+
+ this.pendingOrchestrationMessageBatches.Remove(node);
+ }
+
+ node = next;
+ }
+ }
+
+ await this.AbandonMessagesForDrainAsync(partitionId, messagesToAbandon);
+
+ if (removedControlQueue != null)
+ {
+ this.SignalSessionWaitersIfNoQueuesRemain(removedControlQueue);
+ }
+ }
+
+ internal Task AbandonMessagesForDrainAsync(ControlQueue controlQueue, IReadOnlyList messages)
+ {
+ if (messages.Count == 0)
+ {
+ return Utils.CompletedTask;
+ }
+
+ var messagesToAbandon = messages
+ .Select(message => (controlQueue, message))
+ .ToList();
+
+ return this.AbandonMessagesForDrainAsync(controlQueue.Name, messagesToAbandon);
+ }
+
+ async Task AbandonMessagesForDrainAsync(
+ string partitionId,
+ IList<(ControlQueue Queue, MessageData Message)> messagesToAbandon)
+ {
+ if (messagesToAbandon.Count > 0)
+ {
+ this.settings.Logger.PartitionManagerInfo(
+ this.storageAccountName,
+ this.settings.TaskHubName,
+ this.settings.WorkerId,
+ partitionId,
+ $"Abandoning {messagesToAbandon.Count} pending message(s) during drain to make them immediately visible for the new partition owner.");
+
+ await messagesToAbandon.ParallelForEachAsync(
+ this.settings.MaxStorageOperationConcurrency,
+ item => item.Queue.AbandonMessageForDrainAsync(item.Message));
}
}
@@ -402,7 +558,7 @@ bool IsScheduledAfterInstanceUpdate(MessageData msg, OrchestrationState? remoteI
/// New messages to assign to orchestrators
/// The "related" ActivityId of this operation.
/// Cancellation token in case the orchestration is terminated.
- internal void AddMessageToPendingOrchestration(
+ internal IReadOnlyList AddMessageToPendingOrchestration(
ControlQueue controlQueue,
IEnumerable queueMessages,
Guid traceActivityId,
@@ -414,6 +570,11 @@ internal void AddMessageToPendingOrchestration(
// 3. Do we need to add messages to a currently executing orchestration?
lock (this.messageAndSessionLock)
{
+ if (controlQueue.IsReleased)
+ {
+ return queueMessages.ToList();
+ }
+
var existingSessionMessages = new Dictionary>();
foreach (MessageData data in queueMessages)
@@ -509,6 +670,8 @@ internal void AddMessageToPendingOrchestration(
session.AddOrReplaceMessages(newMessages);
}
}
+
+ return EmptyMessageDataList;
}
// This method runs on a background task thread
@@ -517,6 +680,11 @@ async Task ScheduleOrchestrationStatePrefetch(
Guid traceActivityId,
CancellationToken cancellationToken)
{
+ if (!this.IsPendingBatchActive(node))
+ {
+ return;
+ }
+
PendingMessageBatch batch = node.Value;
AnalyticsEventSource.SetLogicalTraceActivityId(traceActivityId);
@@ -530,6 +698,11 @@ async Task ScheduleOrchestrationStatePrefetch(
batch.OrchestrationExecutionId,
cancellationToken);
+ if (!this.IsPendingBatchActive(node))
+ {
+ return;
+ }
+
batch.OrchestrationState = new OrchestrationRuntimeState(history.Events);
batch.ETags.HistoryETag = history.ETag;
batch.LastCheckpointTime = history.LastCheckpointTime;
@@ -541,20 +714,34 @@ async Task ScheduleOrchestrationStatePrefetch(
InstanceStatus? instanceStatus = await this.trackingStore.FetchInstanceStatusAsync(
batch.OrchestrationInstanceId,
cancellationToken);
+
+ if (!this.IsPendingBatchActive(node))
+ {
+ return;
+ }
+
// The instance could not exist in the case that these messages are for the first execution of a suborchestration,
// or an entity-started orchestration, for example
batch.ETags.InstanceETag = instanceStatus?.ETag;
}
}
- if (this.settings.UseSeparateQueueForEntityWorkItems
- && DurableTask.Core.Common.Entities.IsEntityInstance(batch.OrchestrationInstanceId))
- {
- this.entitiesReadyForProcessingQueue.Enqueue(node);
- }
- else
+ lock (this.messageAndSessionLock)
{
- this.orchestrationsReadyForProcessingQueue.Enqueue(node);
+ if (!this.IsPendingBatchActiveLocked(node))
+ {
+ return;
+ }
+
+ if (this.settings.UseSeparateQueueForEntityWorkItems
+ && DurableTask.Core.Common.Entities.IsEntityInstance(batch.OrchestrationInstanceId))
+ {
+ this.entitiesReadyForProcessingQueue.Enqueue(node);
+ }
+ else
+ {
+ this.orchestrationsReadyForProcessingQueue.Enqueue(node);
+ }
}
}
catch (OperationCanceledException)
@@ -571,14 +758,37 @@ async Task ScheduleOrchestrationStatePrefetch(
e.ToString());
// Sleep briefly to avoid a tight failure loop.
- await Task.Delay(TimeSpan.FromSeconds(5));
+ try
+ {
+ await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken);
+ }
+ catch (OperationCanceledException)
+ {
+ return;
+ }
// This is a background operation so failure is not an option. All exceptions must be handled.
// To avoid starvation, we need to re-enqueue this async operation instead of retrying in a loop.
- await Task.Run(() => this.ScheduleOrchestrationStatePrefetch(node, traceActivityId, cancellationToken));
+ if (this.IsPendingBatchActive(node))
+ {
+ await Task.Run(async () => await this.ScheduleOrchestrationStatePrefetch(node, traceActivityId, cancellationToken));
+ }
}
}
+ bool IsPendingBatchActive(LinkedListNode node)
+ {
+ lock (this.messageAndSessionLock)
+ {
+ return this.IsPendingBatchActiveLocked(node);
+ }
+ }
+
+ bool IsPendingBatchActiveLocked(LinkedListNode node)
+ {
+ return node.List == this.pendingOrchestrationMessageBatches && !node.Value.ControlQueue.IsReleased;
+ }
+
public async Task GetNextSessionAsync(bool entitiesOnly, CancellationToken cancellationToken)
{
var readyForProcessingQueue = entitiesOnly? this.entitiesReadyForProcessingQueue : this.orchestrationsReadyForProcessingQueue;
@@ -589,78 +799,166 @@ async Task ScheduleOrchestrationStatePrefetch(
// 1) a batch of messages has been received for a particular instance and
// 2) the history for that instance has been fetched
LinkedListNode node = await readyForProcessingQueue.DequeueAsync(cancellationToken);
+ ControlQueue? queueToAbandon = null;
+ IReadOnlyList messagesToAbandon = EmptyMessageDataList;
+ bool shouldStopWaitingForSessions = false;
lock (this.messageAndSessionLock)
{
- PendingMessageBatch nextBatch = node.Value;
- this.pendingOrchestrationMessageBatches.Remove(node);
-
- if (!this.activeOrchestrationSessions.TryGetValue(nextBatch.OrchestrationInstanceId, out var existingSession))
+ // Drain may have removed this batch after it was queued for dispatch.
+ if (node.List != this.pendingOrchestrationMessageBatches)
{
- OrchestrationInstance instance = nextBatch.OrchestrationState?.OrchestrationInstance ??
- new OrchestrationInstance
- {
- InstanceId = nextBatch.OrchestrationInstanceId,
- ExecutionId = nextBatch.OrchestrationExecutionId,
- };
-
- Guid traceActivityId = AzureStorageOrchestrationService.StartNewLogicalTraceScope(useExisting: true);
-
- OrchestrationSession session = new OrchestrationSession(
- this.settings,
- this.storageAccountName,
- instance,
- nextBatch.ControlQueue,
- nextBatch.Messages,
- nextBatch.OrchestrationState,
- nextBatch.ETags,
- nextBatch.LastCheckpointTime,
- nextBatch.TrackingStoreContext,
- this.settings.ExtendedSessionIdleTimeout,
- this.shutdownToken,
- traceActivityId);
-
- this.activeOrchestrationSessions.Add(instance.InstanceId, session);
-
- return session;
- }
- else if (nextBatch.OrchestrationExecutionId == existingSession.Instance?.ExecutionId)
- {
- // there is already an active session with the same execution id.
- // The session might be waiting for more messages. If it is, signal them.
- existingSession.AddOrReplaceMessages(node.Value.Messages);
+ shouldStopWaitingForSessions = this.ShouldStopWaitingForSessions(readyForProcessingQueue);
}
else
{
- // A message arrived for a different generation of an existing orchestration instance.
- // Put it back into the ready queue so that it can be processed once the current generation
- // is done executing.
- if (readyForProcessingQueue.Count == 0)
+ PendingMessageBatch nextBatch = node.Value;
+ this.pendingOrchestrationMessageBatches.Remove(node);
+
+ if (nextBatch.ControlQueue.IsReleased)
{
- // To avoid a tight dequeue loop, delay for a bit before putting this node back into the queue.
- // This is only necessary when the queue is empty. The main dequeue thread must not be blocked
- // by this delay, which is why we use Task.Delay(...).ContinueWith(...) instead of await.
- Task.Delay(millisecondsDelay: 200).ContinueWith(_ =>
- {
- lock (this.messageAndSessionLock)
+ queueToAbandon = nextBatch.ControlQueue;
+ messagesToAbandon = nextBatch.Messages.ToList();
+ shouldStopWaitingForSessions = this.ShouldStopWaitingForSessions(readyForProcessingQueue);
+ }
+ else if (!this.activeOrchestrationSessions.TryGetValue(nextBatch.OrchestrationInstanceId, out var existingSession))
+ {
+ OrchestrationInstance instance = nextBatch.OrchestrationState?.OrchestrationInstance ??
+ new OrchestrationInstance
{
- this.pendingOrchestrationMessageBatches.AddLast(node);
- readyForProcessingQueue.Enqueue(node);
- }
- });
+ InstanceId = nextBatch.OrchestrationInstanceId,
+ ExecutionId = nextBatch.OrchestrationExecutionId,
+ };
+
+ Guid traceActivityId = AzureStorageOrchestrationService.StartNewLogicalTraceScope(useExisting: true);
+
+ OrchestrationSession session = new OrchestrationSession(
+ this.settings,
+ this.storageAccountName,
+ instance,
+ nextBatch.ControlQueue,
+ nextBatch.Messages,
+ nextBatch.OrchestrationState,
+ nextBatch.ETags,
+ nextBatch.LastCheckpointTime,
+ nextBatch.TrackingStoreContext,
+ this.settings.ExtendedSessionIdleTimeout,
+ this.shutdownToken,
+ traceActivityId);
+
+ this.activeOrchestrationSessions.Add(instance.InstanceId, session);
+
+ return session;
+ }
+ else if (nextBatch.OrchestrationExecutionId == existingSession.Instance?.ExecutionId)
+ {
+ // there is already an active session with the same execution id.
+ // The session might be waiting for more messages. If it is, signal them.
+ existingSession.AddOrReplaceMessages(node.Value.Messages);
}
else
{
- this.pendingOrchestrationMessageBatches.AddLast(node);
- readyForProcessingQueue.Enqueue(node);
+ // A message arrived for a different generation of an existing orchestration instance.
+ // Put it back into the ready queue so that it can be processed once the current generation
+ // is done executing.
+ if (readyForProcessingQueue.Count == 0)
+ {
+ // To avoid a tight dequeue loop, delay for a bit before putting this node back into the queue.
+ // This is only necessary when the queue is empty. The main dequeue thread must not be blocked
+ // by this delay, which is why it is scheduled without awaiting here.
+ _ = this.RequeuePendingBatchAfterDelayAsync(node, readyForProcessingQueue);
+ }
+ else
+ {
+ this.RequeueOrAbandonPendingBatchLocked(node, readyForProcessingQueue, out queueToAbandon, out messagesToAbandon);
+ }
}
}
}
+
+ if (queueToAbandon != null)
+ {
+ await this.AbandonMessagesForDrainAsync(queueToAbandon, messagesToAbandon);
+ }
+
+ if (shouldStopWaitingForSessions)
+ {
+ return null;
+ }
}
return null;
}
+ async Task RequeuePendingBatchAfterDelayAsync(
+ LinkedListNode node,
+ AsyncQueue> readyForProcessingQueue)
+ {
+ await Task.Delay(millisecondsDelay: 200);
+
+ ControlQueue? queueToAbandon;
+ IReadOnlyList messagesToAbandon;
+ bool shouldStopWaitingForSessions;
+
+ lock (this.messageAndSessionLock)
+ {
+ this.RequeueOrAbandonPendingBatchLocked(node, readyForProcessingQueue, out queueToAbandon, out messagesToAbandon);
+ shouldStopWaitingForSessions = queueToAbandon != null && this.ShouldStopWaitingForSessions(readyForProcessingQueue);
+ }
+
+ if (queueToAbandon != null)
+ {
+ await this.AbandonMessagesForDrainAsync(queueToAbandon, messagesToAbandon);
+ if (shouldStopWaitingForSessions)
+ {
+ readyForProcessingQueue.Enqueue(node);
+ }
+ }
+ }
+
+ void RequeueOrAbandonPendingBatchLocked(
+ LinkedListNode node,
+ AsyncQueue> readyForProcessingQueue,
+ out ControlQueue? queueToAbandon,
+ out IReadOnlyList messagesToAbandon)
+ {
+ if (node.Value.ControlQueue.IsReleased)
+ {
+ queueToAbandon = node.Value.ControlQueue;
+ messagesToAbandon = node.Value.Messages.ToList();
+ return;
+ }
+
+ this.pendingOrchestrationMessageBatches.AddLast(node);
+ readyForProcessingQueue.Enqueue(node);
+ queueToAbandon = null;
+ messagesToAbandon = EmptyMessageDataList;
+ }
+
+ bool ShouldStopWaitingForSessions(AsyncQueue> readyForProcessingQueue)
+ {
+ return readyForProcessingQueue.Count == 0 &&
+ !this.ownedControlQueues.Values.Any(queue => !queue.IsReleased);
+ }
+
+ void SignalSessionWaitersIfNoQueuesRemain(ControlQueue releasedControlQueue)
+ {
+ lock (this.messageAndSessionLock)
+ {
+ if (this.ownedControlQueues.Values.Any(queue => !queue.IsReleased))
+ {
+ return;
+ }
+
+ var orchestratorSentinel = new LinkedListNode(
+ new PendingMessageBatch(releasedControlQueue, string.Empty, executionId: null));
+ var entitySentinel = new LinkedListNode(
+ new PendingMessageBatch(releasedControlQueue, string.Empty, executionId: null));
+ this.orchestrationsReadyForProcessingQueue.Enqueue(orchestratorSentinel);
+ this.entitiesReadyForProcessingQueue.Enqueue(entitySentinel);
+ }
+ }
+
///
/// Immediately removes all active sessions, causing
/// to return false for all partitions. This unblocks so that
@@ -682,8 +980,14 @@ public bool TryGetExistingSession(string instanceId, out OrchestrationSession se
}
}
- public bool TryReleaseSession(string instanceId, CancellationToken cancellationToken, out OrchestrationSession session)
+ public bool TryReleaseSession(
+ string instanceId,
+ CancellationToken cancellationToken,
+ out OrchestrationSession session,
+ out IReadOnlyList messagesToAbandon)
{
+ messagesToAbandon = EmptyMessageDataList;
+
// Taking this lock ensures we don't add new messages to a session we're about to release.
lock (this.messageAndSessionLock)
{
@@ -693,7 +997,7 @@ public bool TryReleaseSession(string instanceId, CancellationToken cancellationT
this.activeOrchestrationSessions.Remove(instanceId))
{
// Put any unprocessed messages back into the pending buffer.
- this.AddMessageToPendingOrchestration(
+ messagesToAbandon = this.AddMessageToPendingOrchestration(
session.ControlQueue,
session.PendingMessages.Concat(session.DeferredMessages),
session.TraceActivityId,
diff --git a/src/DurableTask.AzureStorage/Partitioning/TablePartitionManager.cs b/src/DurableTask.AzureStorage/Partitioning/TablePartitionManager.cs
index 3fad8c98d..a55685ca4 100644
--- a/src/DurableTask.AzureStorage/Partitioning/TablePartitionManager.cs
+++ b/src/DurableTask.AzureStorage/Partitioning/TablePartitionManager.cs
@@ -362,7 +362,7 @@ public async Task ReadAndWriteTableAsync(bool isShuttingDown,
// In a worker becomes unhealthy, it may lose a lease without realizing it and continue listening
// for messages. We check for that case here and stop dequeuing messages if we discover that
// another worker currently owns the lease.
- this.service.DropLostControlQueue(partition);
+ await this.service.DropLostControlQueue(partition);
bool claimedLease = false;
bool stoleLease = false;
diff --git a/test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs b/test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs
new file mode 100644
index 000000000..bb6dafdb2
--- /dev/null
+++ b/test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs
@@ -0,0 +1,783 @@
+// ----------------------------------------------------------------------------------
+// Copyright Microsoft Corporation
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ----------------------------------------------------------------------------------
+
+namespace DurableTask.AzureStorage.Tests
+{
+ using System;
+ using System.Collections.Concurrent;
+ using System.Collections.Generic;
+ using System.Diagnostics;
+ using System.Linq;
+ using System.Reflection;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Azure;
+ using Azure.Storage.Queues;
+ using Azure.Storage.Queues.Models;
+ using DurableTask.AzureStorage.Messaging;
+ using DurableTask.AzureStorage.Monitoring;
+ using DurableTask.AzureStorage.Partitioning;
+ using DurableTask.AzureStorage.Storage;
+ using DurableTask.AzureStorage.Tracking;
+ using DurableTask.Core;
+ using DurableTask.Core.History;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+ using Moq;
+
+ ///
+ /// Tests for shutdown cancellation behavior with extended sessions.
+ ///
+ [TestClass]
+ public class OrchestrationSessionTests
+ {
+ ///
+ /// Verifies that
+ /// exits immediately when the cancellation token is cancelled.
+ ///
+ [TestMethod]
+ public async Task WaitAsync_CancellationToken_ExitsImmediately()
+ {
+ var resetEvent = new AsyncAutoResetEvent(signaled: false);
+ using var cts = new CancellationTokenSource();
+
+ TimeSpan longTimeout = TimeSpan.FromSeconds(30);
+ Task waitTask = resetEvent.WaitAsync(longTimeout, cts.Token);
+
+ Assert.IsFalse(waitTask.IsCompleted, "Wait should not complete immediately");
+
+ var stopwatch = Stopwatch.StartNew();
+ cts.Cancel();
+
+ bool result = await waitTask;
+ stopwatch.Stop();
+
+ Assert.IsFalse(result, "Cancellation should return false (no signal received)");
+ Assert.IsTrue(
+ stopwatch.ElapsedMilliseconds < 5000,
+ $"Cancellation should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms");
+ }
+
+ ///
+ /// Verifies that signaling still returns true when a cancellation token is provided.
+ ///
+ [TestMethod]
+ public async Task WaitAsync_WithCancellationToken_SignalStillWorks()
+ {
+ var resetEvent = new AsyncAutoResetEvent(signaled: false);
+ using var cts = new CancellationTokenSource();
+
+ Task waitTask = resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
+ Assert.IsFalse(waitTask.IsCompleted);
+
+ resetEvent.Set();
+
+ Task winner = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(5)));
+ Assert.IsTrue(winner == waitTask, "Signal should wake the waiter");
+ Assert.IsTrue(waitTask.Result, "Wait result should be true when signaled");
+ }
+
+ ///
+ /// Verifies that the wait returns false on timeout when a cancellation token is provided but not cancelled.
+ ///
+ [TestMethod]
+ public async Task WaitAsync_WithCancellationToken_TimeoutStillWorks()
+ {
+ var resetEvent = new AsyncAutoResetEvent(signaled: false);
+ using var cts = new CancellationTokenSource();
+
+ bool result = await resetEvent.WaitAsync(TimeSpan.FromMilliseconds(100), cts.Token);
+
+ Assert.IsFalse(result, "Wait should return false on timeout");
+ }
+
+ ///
+ /// Verifies that all queued waiters return false when the token is cancelled.
+ ///
+ [TestMethod]
+ public async Task WaitAsync_CancellationToken_MultipleWaiters()
+ {
+ var resetEvent = new AsyncAutoResetEvent(signaled: false);
+ using var cts = new CancellationTokenSource();
+
+ var waiters = new List>();
+ for (int i = 0; i < 5; i++)
+ {
+ waiters.Add(resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token));
+ }
+
+ foreach (var waiter in waiters)
+ {
+ Assert.IsFalse(waiter.IsCompleted);
+ }
+
+ var stopwatch = Stopwatch.StartNew();
+ cts.Cancel();
+
+ // All waiters should return false (cancelled = not signaled)
+ await Task.WhenAll(
+ waiters.Select(
+ async waiter =>
+ {
+ bool result = await waiter;
+ Assert.IsFalse(result, "Cancelled waiter should return false");
+ }));
+
+ stopwatch.Stop();
+
+ Assert.IsTrue(
+ stopwatch.ElapsedMilliseconds < 5000,
+ $"All waiters should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms");
+ }
+
+ ///
+ /// Verifies that a pre-cancelled token causes WaitAsync to return false immediately.
+ ///
+ [TestMethod]
+ public async Task WaitAsync_AlreadyCancelledToken_ReturnsFalseImmediately()
+ {
+ var resetEvent = new AsyncAutoResetEvent(signaled: false);
+ using var cts = new CancellationTokenSource();
+ cts.Cancel(); // Pre-cancel
+
+ var stopwatch = Stopwatch.StartNew();
+ bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
+ stopwatch.Stop();
+
+ Assert.IsFalse(result, "Pre-cancelled token should cause immediate false return");
+ Assert.IsTrue(
+ stopwatch.ElapsedMilliseconds < 5000,
+ $"Should complete immediately, but took {stopwatch.ElapsedMilliseconds}ms");
+ }
+
+ ///
+ /// Verifies that a pre-cancelled token still returns true if the event is already signaled.
+ ///
+ [TestMethod]
+ public async Task WaitAsync_AlreadySignaledAndCancelled_ReturnsTrue()
+ {
+ var resetEvent = new AsyncAutoResetEvent(signaled: true);
+ using var cts = new CancellationTokenSource();
+ cts.Cancel();
+
+ bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
+ Assert.IsTrue(result, "Already signaled event should return true even with cancelled token");
+ }
+
+ ///
+ /// Verifies that clears all active sessions.
+ ///
+ [TestMethod]
+ public void AbortAllSessions_ClearsActiveSessions()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings();
+ var stats = new AzureStorageOrchestrationServiceStats();
+ var trackingStore = new Mock();
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ // Use reflection to access the internal sessions dictionary.
+ var sessionsField = typeof(OrchestrationSessionManager)
+ .GetField("activeOrchestrationSessions", BindingFlags.NonPublic | BindingFlags.Instance);
+ var sessions = (Dictionary)sessionsField.GetValue(manager);
+
+ manager.GetStats(out _, out _, out int initialCount);
+ Assert.AreEqual(0, initialCount, "Should start with no active sessions");
+
+ sessions["instance1"] = null;
+ sessions["instance2"] = null;
+ sessions["instance3"] = null;
+
+ manager.GetStats(out _, out _, out int activeCount);
+ Assert.AreEqual(3, activeCount, "Should have 3 active sessions");
+
+ manager.AbortAllSessions();
+
+ manager.GetStats(out _, out _, out int afterAbortCount);
+ Assert.AreEqual(0, afterAbortCount, "AbortAllSessions should clear all active sessions");
+ }
+
+ ///
+ /// Verifies that is safe to call with no active sessions.
+ ///
+ [TestMethod]
+ public void AbortAllSessions_NoSessions_DoesNotThrow()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings();
+ var stats = new AzureStorageOrchestrationServiceStats();
+ var trackingStore = new Mock();
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ manager.AbortAllSessions();
+
+ manager.GetStats(out _, out _, out int count);
+ Assert.AreEqual(0, count, "Should still have no active sessions");
+ }
+
+ [TestMethod]
+ public async Task GetNextSessionAsync_DrainedReadyQueueNode_ReturnsNullWhenNoQueuesRemain()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings
+ {
+ StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"),
+ };
+ var stats = new AzureStorageOrchestrationServiceStats();
+ var trackingStore = new Mock();
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ var storageClient = new AzureStorageClient(settings);
+ var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName);
+ var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager);
+
+ object pendingBatch = CreatePendingBatch(controlQueue);
+ object node = AddPendingBatchNode(manager, pendingBatch);
+ RemovePendingBatchNode(manager, node);
+ EnqueueReadyForProcessingNode(manager, node);
+
+ using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(100));
+ OrchestrationSession session = await manager.GetNextSessionAsync(entitiesOnly: false, cts.Token);
+
+ Assert.IsNull(session, "Detached ready nodes should not block dispatch when no queues remain.");
+ }
+
+ [TestMethod]
+ public async Task ScheduleOrchestrationStatePrefetch_DetachedNode_DoesNotFetchHistory()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings
+ {
+ StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"),
+ };
+ var stats = new AzureStorageOrchestrationServiceStats();
+ int fetchCount = 0;
+ var trackingStore = new Mock();
+ trackingStore
+ .Setup(t => t.GetHistoryEventsAsync("instance1", "execution1", It.IsAny()))
+ .Callback(() => fetchCount++)
+ .ThrowsAsync(new OperationCanceledException());
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ var storageClient = new AzureStorageClient(settings);
+ var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName);
+ var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager);
+
+ object pendingBatch = CreatePendingBatch(controlQueue);
+ object node = AddPendingBatchNode(manager, pendingBatch);
+ RemovePendingBatchNode(manager, node);
+
+ await InvokeScheduleOrchestrationStatePrefetch(manager, node, CancellationToken.None);
+
+ Assert.AreEqual(0, fetchCount, "Detached pending batches should not fetch orchestration history.");
+ }
+
+ [TestMethod]
+ public void AddMessageToPendingOrchestration_ReleasedControlQueue_ReturnsMessagesToAbandon()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings
+ {
+ StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"),
+ };
+ var stats = new AzureStorageOrchestrationServiceStats();
+ var trackingStore = new Mock();
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ var storageClient = new AzureStorageClient(settings);
+ var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName);
+ var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager);
+ controlQueue.Release(null, "test");
+
+ MessageData message = CreateMessageData();
+ MethodInfo addMessage = typeof(OrchestrationSessionManager).GetMethod(
+ "AddMessageToPendingOrchestration",
+ BindingFlags.Instance | BindingFlags.NonPublic);
+
+ var messagesToAbandon = (IReadOnlyList)addMessage.Invoke(
+ manager,
+ new object[] { controlQueue, new[] { message }, Guid.NewGuid(), CancellationToken.None });
+
+ Assert.IsNotNull(messagesToAbandon, "Released queue messages should be returned for immediate abandon.");
+ Assert.AreEqual(1, messagesToAbandon.Count);
+ Assert.AreSame(message, messagesToAbandon[0]);
+
+ manager.GetStats(out int pendingOrchestratorInstances, out _, out _);
+ Assert.AreEqual(0, pendingOrchestratorInstances, "Released queue messages should not be added to pending batches.");
+ }
+
+ [TestMethod]
+ public async Task RemoveQueue_PendingBatch_AbandonsMessages()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings
+ {
+ StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"),
+ };
+ var stats = new AzureStorageOrchestrationServiceStats();
+ var trackingStore = new Mock();
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ var storageClient = new AzureStorageClient(settings);
+ var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName);
+ var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager);
+ AddOwnedControlQueue(manager, "partition-0", controlQueue);
+
+ MessageData message = CreateMessageData();
+ int abandonCount = 0;
+ var queueClient = new Mock();
+ queueClient.SetupGet(q => q.Name).Returns("partition-0");
+ queueClient
+ .Setup(
+ q => q.UpdateMessageAsync(
+ message.OriginalQueueMessage.MessageId,
+ message.OriginalQueueMessage.PopReceipt,
+ It.IsAny(),
+ TimeSpan.Zero,
+ It.IsAny()))
+ .Callback(() => abandonCount++)
+ .ReturnsAsync(Response.FromValue(
+ QueuesModelFactory.UpdateReceipt("newPopReceipt", DateTimeOffset.UtcNow),
+ Mock.Of()));
+ SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object);
+
+ object pendingBatch = CreatePendingBatch(controlQueue);
+ AddMessageToPendingBatch(pendingBatch, message);
+ AddPendingBatchNode(manager, pendingBatch);
+
+ await InvokeRemoveQueue(manager, "partition-0");
+
+ Assert.AreEqual(0, GetPendingBatchCount(manager), "RemoveQueue should remove pending batches for the released queue.");
+ Assert.AreEqual(1, abandonCount, "RemoveQueue should immediately abandon pending messages for the released queue.");
+ }
+
+ [TestMethod]
+ public async Task RemoveQueue_PendingBatch_ReturnsNullToWaitingDispatcher()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings
+ {
+ StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"),
+ };
+ var stats = new AzureStorageOrchestrationServiceStats();
+ var trackingStore = new Mock();
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ var storageClient = new AzureStorageClient(settings);
+ var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName);
+ var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager);
+ AddOwnedControlQueue(manager, "partition-0", controlQueue);
+
+ MessageData message = CreateMessageData();
+ var queueClient = new Mock();
+ queueClient.SetupGet(q => q.Name).Returns("partition-0");
+ queueClient
+ .Setup(
+ q => q.UpdateMessageAsync(
+ message.OriginalQueueMessage.MessageId,
+ message.OriginalQueueMessage.PopReceipt,
+ It.IsAny(),
+ TimeSpan.Zero,
+ It.IsAny()))
+ .ReturnsAsync(Response.FromValue(
+ QueuesModelFactory.UpdateReceipt("newPopReceipt", DateTimeOffset.UtcNow),
+ Mock.Of()));
+ SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object);
+
+ object pendingBatch = CreatePendingBatch(controlQueue);
+ AddMessageToPendingBatch(pendingBatch, message);
+ AddPendingBatchNode(manager, pendingBatch);
+
+ using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1));
+ Task getNextTask = manager.GetNextSessionAsync(entitiesOnly: false, cts.Token);
+
+ await InvokeRemoveQueue(manager, "partition-0");
+ OrchestrationSession session = await getNextTask;
+
+ Assert.IsNull(session, "Removing the last queue should unblock dispatchers waiting for a session.");
+ }
+
+ [TestMethod]
+ public async Task WaitForDequeueLoopToStopAsync_FaultedDequeueLoop_PropagatesUnexpectedException()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings();
+ var stats = new AzureStorageOrchestrationServiceStats();
+ var trackingStore = new Mock();
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ var dequeueLoopTasks = (ConcurrentDictionary)GetPrivateField(manager, "dequeueLoopTasks");
+ var expected = new InvalidOperationException("unexpected dequeue loop failure");
+ dequeueLoopTasks["partition-0"] = Task.FromException(expected);
+
+ MethodInfo wait = typeof(OrchestrationSessionManager).GetMethod(
+ "WaitForDequeueLoopToStopAsync",
+ BindingFlags.NonPublic | BindingFlags.Instance);
+
+ Task waitTask = (Task)wait.Invoke(manager, new object[] { "partition-0", CancellationToken.None });
+
+ InvalidOperationException actual = await Assert.ThrowsExceptionAsync(
+ () => waitTask);
+
+ Assert.AreSame(expected, actual);
+ }
+
+ [TestMethod]
+ public async Task AbandonMessageForDrainAsync_DurableTaskStorageException_DoesNotThrow()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings
+ {
+ StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"),
+ };
+ var storageClient = new AzureStorageClient(settings);
+ var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName);
+ var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager);
+
+ var queueClient = new Mock();
+ queueClient.SetupGet(q => q.Name).Returns("partition-0");
+ queueClient
+ .Setup(
+ q => q.UpdateMessageAsync(
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny()))
+ .ThrowsAsync(new RequestFailedException(404, "queue update failed"));
+ SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object);
+
+ await controlQueue.AbandonMessageForDrainAsync(CreateMessageData());
+ }
+
+ [TestMethod]
+ public async Task GetNextSessionAsync_ReleasedDelayedRequeueNode_AbandonsMessagesAndReturnsNullWhenNoQueuesRemain()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings
+ {
+ StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"),
+ };
+ var stats = new AzureStorageOrchestrationServiceStats();
+ var trackingStore = new Mock();
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ var storageClient = new AzureStorageClient(settings);
+ var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName);
+ var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager);
+ MessageData message = CreateMessageData();
+ int abandonCount = 0;
+ var queueClient = new Mock();
+ queueClient.SetupGet(q => q.Name).Returns("partition-0");
+ queueClient
+ .Setup(
+ q => q.UpdateMessageAsync(
+ message.OriginalQueueMessage.MessageId,
+ message.OriginalQueueMessage.PopReceipt,
+ It.IsAny(),
+ TimeSpan.Zero,
+ It.IsAny()))
+ .Callback(() => abandonCount++)
+ .ReturnsAsync(Response.FromValue(
+ QueuesModelFactory.UpdateReceipt("newPopReceipt", DateTimeOffset.UtcNow),
+ Mock.Of()));
+ SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object);
+
+ AddActiveSession(manager, settings, controlQueue, "instance1", "activeExecution");
+ object pendingBatch = CreatePendingBatch(controlQueue);
+ AddMessageToPendingBatch(pendingBatch, message);
+ object node = AddPendingBatchNode(manager, pendingBatch);
+ EnqueueReadyForProcessingNode(manager, node);
+
+ using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2));
+ Task getNextTask = manager.GetNextSessionAsync(entitiesOnly: false, cts.Token);
+
+ await WaitUntilAsync(() => IsNodeDetached(node), TimeSpan.FromSeconds(2));
+ controlQueue.Release(null, "test");
+ OrchestrationSession session = await getNextTask;
+
+ Assert.IsNull(session, "Released delayed requeue nodes should not block dispatch when no queues remain.");
+ await WaitUntilAsync(() => abandonCount == 1, TimeSpan.FromSeconds(2));
+
+ Assert.AreEqual(0, GetPendingBatchCount(manager), "Released queue nodes should not be requeued after a delay.");
+ Assert.AreEqual(0, GetReadyQueueCount(manager), "Released queue nodes should not be made ready for dispatch.");
+ Assert.AreEqual(1, abandonCount, "Messages from a released delayed requeue node should be immediately abandoned.");
+ }
+
+ [TestMethod]
+ public async Task GetNextSessionAsync_ReleasedReadyQueueNode_AbandonsMessagesAndReturnsNullWhenNoQueuesRemain()
+ {
+ var settings = new AzureStorageOrchestrationServiceSettings
+ {
+ StorageAccountClientProvider = new StorageAccountClientProvider("UseDevelopmentStorage=true"),
+ };
+ var stats = new AzureStorageOrchestrationServiceStats();
+ var trackingStore = new Mock();
+
+ using var manager = new OrchestrationSessionManager(
+ "testaccount",
+ settings,
+ stats,
+ trackingStore.Object);
+
+ var storageClient = new AzureStorageClient(settings);
+ var messageManager = new MessageManager(settings, storageClient, settings.TaskHubName);
+ var controlQueue = new ControlQueue(storageClient, "partition-0", messageManager);
+ MessageData message = CreateMessageData();
+ int abandonCount = 0;
+ var queueClient = new Mock();
+ queueClient.SetupGet(q => q.Name).Returns("partition-0");
+ queueClient
+ .Setup(
+ q => q.UpdateMessageAsync(
+ message.OriginalQueueMessage.MessageId,
+ message.OriginalQueueMessage.PopReceipt,
+ It.IsAny(),
+ TimeSpan.Zero,
+ It.IsAny()))
+ .Callback(() => abandonCount++)
+ .ReturnsAsync(Response.FromValue(
+ QueuesModelFactory.UpdateReceipt("newPopReceipt", DateTimeOffset.UtcNow),
+ Mock.Of()));
+ SetPrivateField(controlQueue.InnerQueue, "queueClient", queueClient.Object);
+
+ object pendingBatch = CreatePendingBatch(controlQueue);
+ AddMessageToPendingBatch(pendingBatch, message);
+ object node = AddPendingBatchNode(manager, pendingBatch);
+ EnqueueReadyForProcessingNode(manager, node);
+ controlQueue.Release(null, "test");
+
+ using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1));
+ OrchestrationSession session = await manager.GetNextSessionAsync(entitiesOnly: false, cts.Token);
+
+ Assert.IsNull(session, "Released queue nodes should not block dispatch when no queues remain.");
+ Assert.AreEqual(0, GetPendingBatchCount(manager), "Released ready nodes should be removed from pending batches.");
+ Assert.AreEqual(1, abandonCount, "Messages from released ready nodes should be immediately abandoned.");
+ }
+
+ static object CreatePendingBatch(ControlQueue controlQueue)
+ {
+ Type pendingBatchType = typeof(OrchestrationSessionManager)
+ .GetNestedType("PendingMessageBatch", BindingFlags.NonPublic);
+
+ return Activator.CreateInstance(
+ pendingBatchType,
+ BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic,
+ binder: null,
+ args: new object[] { controlQueue, "instance1", "execution1" },
+ culture: null);
+ }
+
+ static object AddPendingBatchNode(OrchestrationSessionManager manager, object pendingBatch)
+ {
+ object pendingBatches = GetPrivateField(manager, "pendingOrchestrationMessageBatches");
+ MethodInfo addLast = pendingBatches.GetType().GetMethod("AddLast", new[] { pendingBatch.GetType() });
+ return addLast.Invoke(pendingBatches, new[] { pendingBatch });
+ }
+
+ static void RemovePendingBatchNode(OrchestrationSessionManager manager, object node)
+ {
+ object pendingBatches = GetPrivateField(manager, "pendingOrchestrationMessageBatches");
+ MethodInfo remove = pendingBatches.GetType().GetMethod("Remove", new[] { node.GetType() });
+ remove.Invoke(pendingBatches, new[] { node });
+ }
+
+ static void EnqueueReadyForProcessingNode(OrchestrationSessionManager manager, object node)
+ {
+ object readyQueue = GetPrivateField(manager, "orchestrationsReadyForProcessingQueue");
+ MethodInfo enqueue = readyQueue.GetType().GetMethod("Enqueue");
+ enqueue.Invoke(readyQueue, new[] { node });
+ }
+
+ static void AddOwnedControlQueue(OrchestrationSessionManager manager, string partitionId, ControlQueue controlQueue)
+ {
+ var ownedQueues = (ConcurrentDictionary)GetPrivateField(manager, "ownedControlQueues");
+ ownedQueues[partitionId] = controlQueue;
+ }
+
+ static async Task InvokeRemoveQueue(OrchestrationSessionManager manager, string partitionId)
+ {
+ MethodInfo removeQueue = typeof(OrchestrationSessionManager).GetMethod("RemoveQueue");
+ object result = removeQueue.Invoke(manager, new object[] { partitionId, CloseReason.LeaseLost, "test" });
+ if (result is Task task)
+ {
+ await task;
+ }
+ }
+
+ static void AddMessageToPendingBatch(object pendingBatch, MessageData message)
+ {
+ var messages = (ICollection)pendingBatch.GetType().GetProperty("Messages").GetValue(pendingBatch);
+ messages.Add(message);
+ }
+
+ static Task InvokeScheduleOrchestrationStatePrefetch(
+ OrchestrationSessionManager manager,
+ object node,
+ CancellationToken cancellationToken)
+ {
+ MethodInfo schedule = typeof(OrchestrationSessionManager).GetMethod(
+ "ScheduleOrchestrationStatePrefetch",
+ BindingFlags.NonPublic | BindingFlags.Instance);
+
+ return (Task)schedule.Invoke(manager, new[] { node, Guid.NewGuid(), cancellationToken });
+ }
+
+ static void AddActiveSession(
+ OrchestrationSessionManager manager,
+ AzureStorageOrchestrationServiceSettings settings,
+ ControlQueue controlQueue,
+ string instanceId,
+ string executionId)
+ {
+ var sessions = (Dictionary)GetPrivateField(manager, "activeOrchestrationSessions");
+ var instance = new OrchestrationInstance
+ {
+ InstanceId = instanceId,
+ ExecutionId = executionId,
+ };
+ var runtimeState = new OrchestrationRuntimeState();
+ runtimeState.AddEvent(new ExecutionStartedEvent(-1, string.Empty)
+ {
+ OrchestrationInstance = instance,
+ });
+
+ sessions[instanceId] = new OrchestrationSession(
+ settings,
+ "testaccount",
+ instance,
+ controlQueue,
+ new List(),
+ runtimeState,
+ eTags: null,
+ DateTime.UtcNow,
+ trackingStoreContext: null,
+ TimeSpan.FromSeconds(30),
+ CancellationToken.None,
+ Guid.NewGuid());
+ }
+
+ static bool IsNodeDetached(object node)
+ {
+ object list = node.GetType().GetProperty("List").GetValue(node);
+ return list == null;
+ }
+
+ static int GetPendingBatchCount(OrchestrationSessionManager manager)
+ {
+ object pendingBatches = GetPrivateField(manager, "pendingOrchestrationMessageBatches");
+ return (int)pendingBatches.GetType().GetProperty("Count").GetValue(pendingBatches);
+ }
+
+ static int GetReadyQueueCount(OrchestrationSessionManager manager)
+ {
+ object readyQueue = GetPrivateField(manager, "orchestrationsReadyForProcessingQueue");
+ return (int)readyQueue.GetType().GetProperty("Count").GetValue(readyQueue);
+ }
+
+ static async Task WaitUntilAsync(Func condition, TimeSpan timeout)
+ {
+ var stopwatch = Stopwatch.StartNew();
+ while (!condition())
+ {
+ if (stopwatch.Elapsed > timeout)
+ {
+ Assert.Fail("Condition was not reached before timeout.");
+ }
+
+ await Task.Delay(TimeSpan.FromMilliseconds(10));
+ }
+ }
+
+ static MessageData CreateMessageData()
+ {
+ var instance = new OrchestrationInstance
+ {
+ InstanceId = "instance1",
+ ExecutionId = "execution1",
+ };
+
+ var taskMessage = new TaskMessage
+ {
+ OrchestrationInstance = instance,
+ Event = new TimerFiredEvent(0),
+ };
+
+ var message = new MessageData(
+ taskMessage,
+ Guid.NewGuid(),
+ "partition-0",
+ orchestrationEpisode: null,
+ sender: instance);
+
+ message.OriginalQueueMessage = QueuesModelFactory.QueueMessage(
+ Guid.NewGuid().ToString("N"),
+ Guid.NewGuid().ToString("N"),
+ string.Empty,
+ 1,
+ DateTimeOffset.UtcNow,
+ DateTimeOffset.UtcNow.AddHours(1),
+ DateTimeOffset.UtcNow.AddMinutes(5));
+
+ return message;
+ }
+
+ static object GetPrivateField(object target, string fieldName)
+ {
+ FieldInfo field = target.GetType().GetField(fieldName, BindingFlags.NonPublic | BindingFlags.Instance);
+ Assert.IsNotNull(field);
+ return field.GetValue(target);
+ }
+
+ static void SetPrivateField(object target, string fieldName, object value)
+ {
+ FieldInfo field = target.GetType().GetField(fieldName, BindingFlags.NonPublic | BindingFlags.Instance);
+ Assert.IsNotNull(field);
+ field.SetValue(target, value);
+ }
+ }
+}