diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterRestorationRaceIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterRestorationRaceIntegrationTest.java new file mode 100644 index 0000000000000..c30bb047a3f7d --- /dev/null +++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterRestorationRaceIntegrationTest.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.LogCaptureAppender; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.TaskManager; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.AbstractStoreBuilder; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockKeyValueStore; +import org.apache.kafka.test.TestUtils; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.junit.jupiter.api.Timeout; + +import java.io.IOException; +import java.time.Duration; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Integration test that verifies the deferred future tracking fix for the race 
condition + * between the StateUpdater thread and the StreamThread when a rebalance triggers task + * removal while the StateUpdater is blocked during changelog restoration. + * + *

+ * <p>Without the fix, the race condition chain is:
+ * <ol>
+ *   <li>StateUpdater thread is blocked in restoration (e.g., RocksDB write stall)</li>
+ *   <li>A rebalance occurs, StreamThread calls {@code waitForFuture()} which times out</li>
+ *   <li>The task is silently dropped — nobody tracks it</li>
+ *   <li>StateUpdater eventually processes the REMOVE, suspends the task (stores NOT closed)</li>
+ *   <li>The orphaned task holds the RocksDB file LOCK</li>
+ *   <li>On restart, {@code RocksDB.open()} fails with {@code ProcessorStateException}</li>
+ * </ol>
+ * + *

+ * <p>With the fix ({@code pendingRemoveFutures} tracking in {@code TaskManager}):
+ * <ol>
+ *   <li>When {@code waitForFuture()} times out, the future is stashed (not discarded)</li>
+ *   <li>On the next {@code checkStateUpdater()} call, completed futures are polled</li>
+ *   <li>The returned task is closed dirty, releasing the RocksDB LOCK</li>
+ *   <li>On restart, {@code RocksDB.open()} succeeds — no LOCK conflict</li>
+ * </ol>
+ * + * @see KIP-1035 + */ +@Timeout(120) +public class StateUpdaterRestorationRaceIntegrationTest { + + private static final int NUM_BROKERS = 1; + private static final String INPUT_TOPIC = "input-topic"; + private static final String BLOCKING_STORE_NAME = "blocking-store"; + private static final String ROCKSDB_STORE_NAME = "rocksdb-store"; + private static final int NUM_PARTITIONS = 6; + + private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS); + + private String appId; + private KafkaStreams streams1; + private KafkaStreams streams2; + + // Controls whether the restore callback should block + private final AtomicBoolean blockDuringRestore = new AtomicBoolean(false); + // Ensures only the first restore record triggers the block + private final AtomicBoolean hasBlocked = new AtomicBoolean(false); + // Signaled when restoration has started (StateUpdater is in restore) + private final CountDownLatch restorationStartedLatch = new CountDownLatch(1); + // Released to unblock the StateUpdater's restore callback + private final CountDownLatch restoreBlockLatch = new CountDownLatch(1); + + @BeforeEach + public void before(final TestInfo testInfo) throws InterruptedException, IOException { + cluster.start(); + cluster.createTopic(INPUT_TOPIC, NUM_PARTITIONS, 1); + appId = "app-" + safeUniqueTestName(testInfo); + } + + @AfterEach + public void after() { + // Release the block latch in case the test failed before doing so + restoreBlockLatch.countDown(); + if (streams1 != null) { + streams1.close(Duration.ofSeconds(30)); + } + if (streams2 != null) { + streams2.close(Duration.ofSeconds(30)); + } + cluster.stop(); + } + + /** + * Verifies that when {@code waitForFuture()} times out during a rebalance, the deferred + * future tracking in {@code TaskManager} cleans up the leaked task and releases the RocksDB + * LOCK, allowing a subsequent restart to succeed without {@code ProcessorStateException}. + * + *

+ * <p>Uses a two-store topology: a blocking {@code MockKeyValueStore} (to halt the
+ * StateUpdater during restoration) and a real RocksDB store (whose file LOCK is the
+ * failure vector on restart).
+ *
+ * <p>Test flow:
+ * <ol>
+ *   <li>Start instance 1 — both stores initialized, RocksDB LOCK acquired, restoration blocks</li>
+ *   <li>Start instance 2 → rebalance → {@code waitForFuture()} times out → future stashed</li>
+ *   <li>Unblock restoration → StateUpdater processes REMOVE → future completes</li>
+ *   <li>Next {@code checkStateUpdater()} call → {@code processPendingRemoveFutures()} →
+ *       task closed dirty → RocksDB LOCK released</li>
+ *   <li>Close both instances, restart instance 1 with same state directory</li>
+ *   <li>Assert: instance starts successfully (no ProcessorStateException)</li>
+ * </ol>
+ * + *

Without the fix, step 6 would fail with {@code ProcessorStateException: Error opening store} + * because the orphaned task's RocksDB handle would still hold the file LOCK. + */ + @Test + public void shouldCleanUpLeakedTaskAndReleaseRocksDBLockAfterWaitForFutureTimeout() throws Exception { + blockDuringRestore.set(true); + + // Pre-create changelog topics for both stores + final String blockingChangelog = appId + "-" + BLOCKING_STORE_NAME + "-changelog"; + final String rocksdbChangelog = appId + "-" + ROCKSDB_STORE_NAME + "-changelog"; + cluster.createTopic(blockingChangelog, NUM_PARTITIONS, 1); + cluster.createTopic(rocksdbChangelog, NUM_PARTITIONS, 1); + populateChangelog(blockingChangelog, 50); + populateChangelog(rocksdbChangelog, 50); + + final String stateDir1 = TestUtils.tempDirectory().getPath(); + final Properties props1 = props(stateDir1); + // waitForFuture timeout = maxPollIntervalMs / 2 = 7.5s + props1.put(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 15_000); + + try (final LogCaptureAppender taskManagerAppender = LogCaptureAppender.createAndRegister(TaskManager.class)) { + streams1 = new KafkaStreams(buildTopologyWithRocksDB(), props1); + streams1.start(); + + // Wait for the StateUpdater to begin restoration (and block on the latch) + assertTrue( + restorationStartedLatch.await(30, TimeUnit.SECONDS), + "Restoration never started on instance 1" + ); + + // Start instance 2 to trigger a rebalance while restoration is blocked. + // StreamThread will call handleTasksInStateUpdater() → waitForFuture() → timeout + // → future stashed in pendingRemoveFutures + final String stateDir2 = TestUtils.tempDirectory().getPath(); + streams2 = new KafkaStreams(buildTopologyWithRocksDB(), props(stateDir2)); + streams2.start(); + + // Wait for waitForFuture timeout — proves the deferred tracking path was exercised. 
+ TestUtils.waitForCondition( + () -> taskManagerAppender.getMessages().stream() + .anyMatch(msg -> msg.contains("Deferring cleanup to next checkStateUpdater()")), + 30_000, + "Expected waitForFuture() to time out and defer cleanup" + ); + + // Unblock restoration so the StateUpdater can process the pending REMOVE action. + // With the fix: the completed future is picked up by processPendingRemoveFutures() + // on the next checkStateUpdater() call, the task is closed dirty, and the RocksDB + // LOCK is released. + restoreBlockLatch.countDown(); + + // Wait for processPendingRemoveFutures() to clean up the leaked task: + // 1. StateUpdater finishes restoration and processes the REMOVE + // 2. checkStateUpdater() calls processPendingRemoveFutures() + // 3. closeTaskDirty() closes the RocksDB handle and releases the LOCK + TestUtils.waitForCondition( + () -> taskManagerAppender.getMessages().stream() + .anyMatch(msg -> msg.contains("Processing deferred removal of task")), + 30_000, + "Expected processPendingRemoveFutures() to clean up the leaked task" + ); + + // Close both instances + streams2.close(Duration.ofSeconds(10)); + streams2 = null; + streams1.close(Duration.ofSeconds(10)); + streams1 = null; + + // Restart instance 1 with the SAME state directory. + // Without the fix: the orphaned task's RocksDB handle still holds the file LOCK + // → RocksDB.open() fails → ProcessorStateException + // With the fix: the task was closed dirty, LOCK released + // → RocksDB.open() succeeds → instance starts normally + blockDuringRestore.set(false); + streams1 = new KafkaStreams(buildTopologyWithRocksDB(), props1); + streams1.start(); + + // Assert: instance starts successfully and reaches RUNNING. + // Without the fix, this would fail with ProcessorStateException from the + // RocksDB LOCK conflict. With the fix, processPendingRemoveFutures() closed + // the leaked task dirty, releasing the LOCK before restart. 
+ TestUtils.waitForCondition( + () -> streams1.state() == KafkaStreams.State.RUNNING, + 30_000, + "Instance should reach RUNNING state — the deferred future tracking " + + "should have cleaned up the leaked task and released the RocksDB LOCK" + ); + } + } + + /** + * Builds a topology with two stores: a blocking MockKeyValueStore (to halt the StateUpdater + * during restoration) and a real RocksDB store (whose file LOCK causes ProcessorStateException + * when the task is leaked and later reassigned). + */ + private TopologyWrapper buildTopologyWithRocksDB() { + // Store 1: MockKeyValueStore with blocking restore callback + final StoreBuilder> blockingStoreBuilder = + new AbstractStoreBuilder<>(BLOCKING_STORE_NAME, Serdes.Integer(), Serdes.String(), new MockTime()) { + @Override + public KeyValueStore build() { + return new MockKeyValueStore(name, true) { + @Override + public void init(final StateStoreContext stateStoreContext, final StateStore root) { + stateStoreContext.register(root, (key, value) -> { + if (blockDuringRestore.get() && !hasBlocked.getAndSet(true)) { + restorationStartedLatch.countDown(); + try { + restoreBlockLatch.await(60, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); + initialized = true; + closed = false; + } + }; + } + }; + + // Store 2: Real RocksDB store — acquires file LOCK during init + final StoreBuilder> rocksdbStoreBuilder = + Stores.keyValueStoreBuilder( + Stores.persistentKeyValueStore(ROCKSDB_STORE_NAME), + Serdes.Integer(), Serdes.String()); + + final TopologyWrapper topology = new TopologyWrapper(); + topology.addSource("source", INPUT_TOPIC); + topology.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + topology.addStateStore(blockingStoreBuilder, "processor"); + topology.addStateStore(rocksdbStoreBuilder, "processor"); + return topology; + } + + private Properties props(final String stateDir) { + final Properties props = new Properties(); + 
props.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()); + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + props.put(StreamsConfig.STATE_DIR_CONFIG, stateDir); + props.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0); + props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1); + props.put(StreamsConfig.mainConsumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), 10_000); + props.put(StreamsConfig.mainConsumerPrefix(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG), 3_000); + return props; + } + + /** + * Produce records directly to the changelog topic so that restoration is needed + * when an instance starts with an empty state directory. + */ + private void populateChangelog(final String changelogTopic, final int recordsPerPartition) { + final Properties producerConfig = new Properties(); + producerConfig.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()); + producerConfig.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class); + producerConfig.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + + try (final KafkaProducer producer = new KafkaProducer<>(producerConfig)) { + for (int partition = 0; partition < NUM_PARTITIONS; partition++) { + for (int i = 0; i < recordsPerPartition; i++) { + producer.send(new ProducerRecord<>(changelogTopic, partition, i, "value-" + i)); + } + } + producer.flush(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java index 833d42aeae4ea..0c5577153cd14 100644 --- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -483,6 +483,10 @@ public static StreamThread create(final TopologyMetadata topologyMetadata, threadIdx ); + final long maxPollIntervalMs = Integer.parseInt( + config.getMainConsumerConfigs("dummy", "dummy", threadIdx) + .getOrDefault(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "300000").toString()); + final TaskManager taskManager = new TaskManager( time, changelogReader, @@ -495,7 +499,8 @@ public static StreamThread create(final TopologyMetadata topologyMetadata, adminClient, stateDirectory, stateUpdater, - schedulingTaskManager + schedulingTaskManager, + maxPollIntervalMs * 3 / 4 ); referenceContainer.taskManager = taskManager; diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java index 4eb2ad36fc8ad..f55f332a44926 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java @@ -110,6 +110,11 @@ public class TaskManager { private final StandbyTaskCreator standbyTaskCreator; private final StateUpdater stateUpdater; private final DefaultTaskManager schedulingTaskManager; + private final long waitForFutureTimeoutMs; + + + private final Map> pendingRemoveFutures = new LinkedHashMap<>(); + TaskManager(final Time time, final ChangelogReader changelogReader, final ProcessId processId, @@ -121,7 +126,8 @@ public class TaskManager { final Admin adminClient, final StateDirectory stateDirectory, final StateUpdater stateUpdater, - final DefaultTaskManager schedulingTaskManager + final DefaultTaskManager schedulingTaskManager, + final long waitForFutureTimeoutMs ) { this.time = time; this.processId = processId; @@ -139,6 +145,7 @@ public class 
TaskManager { this.stateUpdater = stateUpdater; this.schedulingTaskManager = schedulingTaskManager; + this.waitForFutureTimeoutMs = waitForFutureTimeoutMs; this.tasks = tasks; this.taskExecutor = new TaskExecutor( this.tasks, @@ -155,6 +162,16 @@ void setMainConsumer(final Consumer mainConsumer) { this.mainConsumer = mainConsumer; } + // visible for testing + long waitForFutureTimeoutMs() { + return waitForFutureTimeoutMs; + } + + // visible for testing + Map> pendingRemoveFutures() { + return pendingRemoveFutures; + } + public double totalProducerBlockedTime() { return activeTaskCreator.totalProducerBlockedTime(); } @@ -604,6 +621,14 @@ private void handleTasksInStateUpdater(final Map> ac final Map> futuresForTasksToClose = new LinkedHashMap<>(); for (final Task task : stateUpdater.tasks()) { final TaskId taskId = task.id(); + if (pendingRemoveFutures.containsKey(taskId)) { + // A previous waitForFuture() timed out for this task — a REMOVE is already + // in flight. Skip it so we don't enqueue a duplicate REMOVE. The pending + // future will be cleaned up in processPendingRemoveFutures(). + activeTasksToCreate.remove(taskId); + standbyTasksToCreate.remove(taskId); + continue; + } if (activeTasksToCreate.containsKey(taskId)) { if (task.isActive()) { if (!task.inputPartitions().equals(activeTasksToCreate.get(taskId))) { @@ -705,10 +730,12 @@ private StateUpdater.RemovedTaskResult waitForFuture(final TaskId taskId, final CompletableFuture future) { final StateUpdater.RemovedTaskResult removedTaskResult; try { - removedTaskResult = future.get(5, TimeUnit.MINUTES); + removedTaskResult = future.get(waitForFutureTimeoutMs, TimeUnit.MILLISECONDS); if (removedTaskResult == null) { - throw new IllegalStateException("Task " + taskId + " was not found in the state updater. " - + BUG_ERROR_MESSAGE); + log.warn("Task {} was not found in the state updater. 
" + + "Deferring cleanup to next checkStateUpdater() call.", taskId); + pendingRemoveFutures.put(taskId, future); + return null; } return removedTaskResult; } catch (final ExecutionException executionException) { @@ -721,12 +748,33 @@ private StateUpdater.RemovedTaskResult waitForFuture(final TaskId taskId, log.error(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen); throw new IllegalStateException(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen); } catch (final java.util.concurrent.TimeoutException timeoutException) { - log.warn("The state updater wasn't able to remove task {} in time. The state updater thread may be dead. " - + BUG_ERROR_MESSAGE, taskId, timeoutException); + log.warn("The state updater wasn't able to remove task {} in time. " + + "Deferring cleanup to next checkStateUpdater() call.", taskId, timeoutException); + pendingRemoveFutures.put(taskId, future); return null; } } + private void processPendingRemoveFutures() { + final Iterator>> it = + pendingRemoveFutures.entrySet().iterator(); + while (it.hasNext()) { + final Map.Entry> entry = it.next(); + if (entry.getValue().isDone()) { + try { + final StateUpdater.RemovedTaskResult result = entry.getValue().get(); + if (result != null) { + log.info("Processing deferred removal of task {}", entry.getKey()); + closeTaskDirty(result.task(), false); + } + } catch (final Exception e) { + log.warn("Exception processing deferred removal of task {}", entry.getKey(), e); + } + it.remove(); + } + } + } + private Map> pendingTasksToCreate(final Map> tasksToCreate) { final Map> pendingTasks = new HashMap<>(); final Iterator>> iter = tasksToCreate.entrySet().iterator(); @@ -845,6 +893,7 @@ private StreamTask convertStandbyToActive(final StandbyTask standbyTask, final S public boolean checkStateUpdater(final long now, final java.util.function.Consumer> offsetResetter) { + processPendingRemoveFutures(); addTasksToStateUpdater(); if (stateUpdater.hasExceptionsAndFailedTasks()) { handleExceptionsFromStateUpdater(); diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java index 21e06b19cc129..8d6414072b097 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -18,6 +18,7 @@ import org.apache.kafka.clients.admin.MockAdminClient; import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; import org.apache.kafka.clients.consumer.ConsumerRecord; @@ -1067,7 +1068,8 @@ public void shouldCommitAfterCommitInterval(final boolean processingThreadsEnabl null, null, null, - null + null, + 300_000L ) { @Override int commit(final Collection tasksToCommit) { @@ -1175,7 +1177,8 @@ public void shouldRecordCommitLatency(final boolean processingThreadsEnabled) { null, stateDirectory, stateUpdater, - schedulingTaskManager + schedulingTaskManager, + 300_000L ) { @Override int commit(final Collection tasksToCommit) { @@ -4197,6 +4200,23 @@ private void runUntilTimeoutOrException(final Runnable action) { } } + @Test + public void shouldSetWaitForFutureTimeoutFromMaxPollIntervalMs() { + final Properties properties = configProps(false, false); + properties.put(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "20000"); + final StreamsConfig config = new StreamsConfig(properties); + thread = createStreamThread(CLIENT_ID, config); + + assertThat(thread.taskManager().waitForFutureTimeoutMs(), equalTo(15_000L)); + } + + @Test + public void shouldSetDefaultWaitForFutureTimeoutFromDefaultMaxPollIntervalMs() { + thread = createStreamThread(CLIENT_ID, false); + + assertThat(thread.taskManager().waitForFutureTimeoutMs(), equalTo(225_000L)); + } + private boolean 
runUntilTimeoutOrCondition(final Runnable action, final TestCondition testCondition) throws Exception { final long expectedEnd = System.currentTimeMillis() + DEFAULT_MAX_WAIT_MS; while (System.currentTimeMillis() < expectedEnd) { diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java index cc3cd12b91810..67edefa6e3aea 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java @@ -215,6 +215,13 @@ private TaskManager setUpTaskManager(final ProcessingMode processingMode, final private TaskManager setUpTaskManager(final ProcessingMode processingMode, final TasksRegistry tasks, final boolean processingThreadsEnabled) { + return setUpTaskManager(processingMode, tasks, processingThreadsEnabled, 300_000L); + } + + private TaskManager setUpTaskManager(final ProcessingMode processingMode, + final TasksRegistry tasks, + final boolean processingThreadsEnabled, + final long waitForFutureTimeoutMs) { topologyMetadata = new TopologyMetadata(topologyBuilder, new DummyStreamsConfig(processingMode)); final TaskManager taskManager = new TaskManager( time, @@ -228,7 +235,8 @@ private TaskManager setUpTaskManager(final ProcessingMode processingMode, adminClient, stateDirectory, stateUpdater, - processingThreadsEnabled ? schedulingTaskManager : null + processingThreadsEnabled ? 
schedulingTaskManager : null, + waitForFutureTimeoutMs ); taskManager.setMainConsumer(consumer); return taskManager; @@ -443,6 +451,105 @@ public void shouldRemoveUnusedFailedStandbyTaskFromStateUpdaterAndCloseDirty() { verify(standbyTaskCreator).createTasks(Collections.emptyMap()); } + @Test + public void shouldStashFutureOnWaitForFutureTimeout() { + final StreamTask activeTask = statefulTask(taskId03, taskId03ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId03Partitions).build(); + final TasksRegistry tasks = mock(TasksRegistry.class); + // Use a short timeout so the test doesn't block for 5 minutes + final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, false, 100L); + when(stateUpdater.tasks()).thenReturn(Set.of(activeTask)); + // Future that never completes — will cause TimeoutException + final CompletableFuture future = new CompletableFuture<>(); + when(stateUpdater.remove(eq(activeTask.id()), eq(SuspendReason.MIGRATED))).thenReturn(future); + + taskManager.handleAssignment(Collections.emptyMap(), Collections.emptyMap()); + + // The timed-out future should be stashed in pendingRemoveFutures + assertTrue(taskManager.pendingRemoveFutures().containsKey(taskId03)); + assertEquals(future, taskManager.pendingRemoveFutures().get(taskId03)); + } + + @Test + public void shouldNotThrowOnNullRemovedTaskResult() { + final StreamTask activeTask = statefulTask(taskId03, taskId03ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId03Partitions).build(); + final TasksRegistry tasks = mock(TasksRegistry.class); + final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks); + when(stateUpdater.tasks()).thenReturn(Set.of(activeTask)); + final CompletableFuture future = new CompletableFuture<>(); + when(stateUpdater.remove(eq(activeTask.id()), eq(SuspendReason.MIGRATED))).thenReturn(future); + // Complete with null — previously threw IllegalStateException + 
future.complete(null); + + // Should not throw — the null result is stashed as a pending future + taskManager.handleAssignment(Collections.emptyMap(), Collections.emptyMap()); + + assertTrue(taskManager.pendingRemoveFutures().containsKey(taskId03)); + } + + @Test + public void shouldProcessPendingRemoveFuturesAndCloseTaskDirty() { + final StreamTask activeTask = statefulTask(taskId03, taskId03ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId03Partitions).build(); + final TasksRegistry tasks = mock(TasksRegistry.class); + // Use a short timeout so the test doesn't block for 5 minutes + final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, false, 100L); + when(stateUpdater.tasks()).thenReturn(Set.of(activeTask)); + final CompletableFuture future = new CompletableFuture<>(); + when(stateUpdater.remove(eq(activeTask.id()), eq(SuspendReason.MIGRATED))).thenReturn(future); + + // Trigger handleAssignment — future times out, gets stashed + taskManager.handleAssignment(Collections.emptyMap(), Collections.emptyMap()); + assertTrue(taskManager.pendingRemoveFutures().containsKey(taskId03)); + + // Now complete the future (StateUpdater processed the REMOVE) + future.complete(new StateUpdater.RemovedTaskResult(activeTask)); + + // Reset stateUpdater mock for checkStateUpdater + when(stateUpdater.tasks()).thenReturn(Collections.emptySet()); + when(stateUpdater.hasExceptionsAndFailedTasks()).thenReturn(false); + when(stateUpdater.restoresActiveTasks()).thenReturn(false); + + // checkStateUpdater should process the pending future and close the task dirty + taskManager.checkStateUpdater(time.milliseconds(), noOpResetter); + + verify(activeTask).prepareCommit(false); + verify(activeTask).suspend(); + verify(activeTask).closeDirty(); + assertTrue(taskManager.pendingRemoveFutures().isEmpty()); + } + + @Test + public void shouldSkipTaskInStateUpdaterWithPendingFuture() { + final StreamTask activeTask = statefulTask(taskId03, 
taskId03ChangelogPartitions) + .inState(State.RESTORING) + .withInputPartitions(taskId03Partitions).build(); + final TasksRegistry tasks = mock(TasksRegistry.class); + // Use a short timeout so the test doesn't block for 5 minutes + final TaskManager taskManager = setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks, false, 100L); + when(stateUpdater.tasks()).thenReturn(Set.of(activeTask)); + final CompletableFuture future = new CompletableFuture<>(); + when(stateUpdater.remove(eq(activeTask.id()), eq(SuspendReason.MIGRATED))).thenReturn(future); + + // First assignment — future times out, gets stashed + taskManager.handleAssignment(Collections.emptyMap(), Collections.emptyMap()); + assertTrue(taskManager.pendingRemoveFutures().containsKey(taskId03)); + + // Second assignment — same task still in stateUpdater but has a pending future. + // Should not enqueue another REMOVE. + taskManager.handleAssignment( + Collections.singletonMap(taskId03, taskId03Partitions), + Collections.emptyMap() + ); + + // remove() should have been called only once (from the first assignment) + verify(stateUpdater, times(1)).remove(eq(taskId03), any()); + } + @Test public void shouldCollectFailedTaskFromStateUpdaterAndRethrow() { final StandbyTask failedStandbyTask = standbyTask(taskId02, taskId02ChangelogPartitions)