From 83dee51915338b6f45ff73d0f6c1a6d351137424 Mon Sep 17 00:00:00 2001 From: Sanskar Modi Date: Thu, 11 Jun 2026 00:14:49 +0530 Subject: [PATCH 1/3] Fast fail incase of worker crashes --- .../celeborn/client/CommitManager.scala | 34 ++- .../celeborn/client/LifecycleManager.scala | 2 + .../client/commit/CommitHandler.scala | 2 + .../commit/ReducePartitionCommitHandler.scala | 18 +- .../celeborn/client/CommitManagerSuite.scala | 249 ++++++++++++++++++ .../ReducePartitionCommitHandlerSuite.scala | 114 ++++++++ .../apache/celeborn/common/CelebornConf.scala | 2 +- 7 files changed, 404 insertions(+), 17 deletions(-) create mode 100644 client/src/test/scala/org/apache/celeborn/client/CommitManagerSuite.scala create mode 100644 client/src/test/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandlerSuite.scala diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala index 64a7c95e9d8..03e8b3c3aca 100644 --- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala @@ -98,6 +98,7 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage def start(): Unit = { lifecycleManager.registerWorkerStatusListener(new ShutdownWorkerListener) + lifecycleManager.registerWorkerStatusListener(new UnknownWorkerListener) batchHandleCommitPartition = batchHandleCommitPartitionSchedulerThread.map { _.scheduleWithFixedDelay( @@ -286,6 +287,10 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage getCommitHandler(shuffleId).setStageEnd(shuffleId) } + def markShuffleDataLost(shuffleId: Int): Unit = { + getCommitHandler(shuffleId).markShuffleDataLost(shuffleId) + } + def waitStageEnd(shuffleId: Int): (Boolean, Long) = { getCommitHandler(shuffleId).waitStageEnd(shuffleId) } @@ -337,10 +342,9 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage (totalWritten + written, totalFileCount + fileCount) } - class ShutdownWorkerListener extends WorkerStatusListener { - + private class ShutdownWorkerListener extends WorkerStatusListener { override def notifyChangedWorkersStatus(workersStatus: WorkersStatus): Unit = { - if (workersStatus.shutdownWorkers != null) { + if (workersStatus.shutdownWorkers != null && !workersStatus.shutdownWorkers.isEmpty) { lifecycleManager.shuffleAllocatedWorkers.asScala.foreach { case (shuffleId, workerIdToPartitionLocationInfos) => if (!isStageEndOrInProcess(shuffleId)) { @@ -367,6 +371,30 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage } } + private class UnknownWorkerListener extends WorkerStatusListener { + private val shuffleDataLostOnUnknownWorkerEnabled = conf.clientShuffleDataLostOnUnknownWorkerEnabled + private val pushReplicateEnabled = conf.clientPushReplicateEnabled + + override def notifyChangedWorkersStatus(workersStatus: WorkersStatus): Unit = { + if (shuffleDataLostOnUnknownWorkerEnabled && !pushReplicateEnabled) { + if (workersStatus.unknownWorkers != null && !workersStatus.unknownWorkers.isEmpty) { + lifecycleManager.shuffleAllocatedWorkers.asScala.foreach { + case (shuffleId, workerIdToPartitionLocationInfos) => + val hasDataOnLostWorker = workersStatus.unknownWorkers.asScala.exists { worker => + workerIdToPartitionLocationInfos.containsKey(worker.toUniqueId) + } + if (hasDataOnLostWorker) { + logWarning(s"Shuffle $shuffleId has data on lost worker(s) " + + s"${workersStatus.unknownWorkers.asScala.map(_.toUniqueId).mkString("[", ",", "]")}," + + s" marking data as lost immediately.") + markShuffleDataLost(shuffleId) + } + } + } + } + } + } + def finishPartition( shuffleId: Int, partitionId: Int, diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index f9508cfe6c8..0571b4780cc 100644 --- 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -86,6 +86,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends private val pushRackAwareEnabled = conf.clientReserveSlotsRackAwareEnabled private val partitionSplitThreshold = conf.shufflePartitionSplitThreshold private val partitionSplitMode = conf.shufflePartitionSplitMode + private val shuffleDataLostOnUnknownWorkerEnabled = + conf.clientShuffleDataLostOnUnknownWorkerEnabled // shuffle id -> partition type private val shufflePartitionType = JavaUtils.newConcurrentHashMap[Int, PartitionType]() private val rangeReadFilter = conf.shuffleRangeReadFilterEnabled diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala index f54c2b990d9..c868292510f 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala @@ -95,6 +95,8 @@ abstract class CommitHandler( def isStageDataLost(shuffleId: Int): Boolean = false + def markShuffleDataLost(shuffleId: Int): Unit = {} + def setStageEnd(shuffleId: Int): Unit = { throw new UnsupportedOperationException( "Failed when do setStageEnd Operation, MapPartition shuffleType don't " + diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index b3e7aa90ab5..1461f23ec01 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -144,22 +144,14 @@ class ReducePartitionCommitHandler( if (mockShuffleLost) { mockShuffleLostShuffle == shuffleId } else { - 
dataLostShuffleSet.contains(shuffleId) || isStageDataLostInUnknownWorker(shuffleId) + dataLostShuffleSet.contains(shuffleId) } } - private def isStageDataLostInUnknownWorker(shuffleId: Int): Boolean = { - if (conf.clientShuffleDataLostOnUnknownWorkerEnabled && !conf.clientPushReplicateEnabled) { - val allocatedWorkers = shuffleAllocatedWorkers.get(shuffleId) - if (allocatedWorkers != null) { - return workerStatusTracker.excludedWorkers.asScala.collect { - case (workerId, (status, _)) - if status == StatusCode.WORKER_UNKNOWN && allocatedWorkers.contains(workerId) => - workerId - }.nonEmpty - } - } - false + override def markShuffleDataLost(shuffleId: Int): Unit = { + logWarning(s"Marking shuffle $shuffleId data as lost due to unknown/crashed worker.") + dataLostShuffleSet.add(shuffleId) + setStageEnd(shuffleId) // unblocks all pending GetReducerFileGroup waiters immediately } override def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean = { diff --git a/client/src/test/scala/org/apache/celeborn/client/CommitManagerSuite.scala b/client/src/test/scala/org/apache/celeborn/client/CommitManagerSuite.scala new file mode 100644 index 00000000000..3551eceaeea --- /dev/null +++ b/client/src/test/scala/org/apache/celeborn/client/CommitManagerSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client + +import java.util +import java.util.concurrent.ThreadPoolExecutor + +import scala.collection.JavaConverters._ +import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ + +import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.Mockito.{doAnswer, mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.client.LifecycleManager.ShuffleAllocatedWorkers +import org.apache.celeborn.client.listener.WorkerStatusListener +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.CelebornConf.{ + CLIENT_BATCH_HANDLE_COMMIT_PARTITION_ENABLED, + CLIENT_PUSH_REPLICATE_ENABLED, + CLIENT_SHUFFLE_DATA_LOST_ON_UNKNOWN_WORKER_ENABLED +} +import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo} +import org.apache.celeborn.common.network.protocol.SerdeVersion +import org.apache.celeborn.common.protocol.PartitionType +import org.apache.celeborn.common.protocol.message.ControlMessages.{ + GetReducerFileGroupResponse, + HeartbeatFromApplicationResponse +} +import org.apache.celeborn.common.protocol.message.StatusCode +import org.apache.celeborn.common.rpc.RpcAddress +import org.apache.celeborn.common.rpc.netty.LocalNettyRpcCallContext +import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils} + +class CommitManagerSuite extends CelebornFunSuite { + + // Background daemon pools are created inside CommitManager; skip thread audit. 
+ override protected val enableAutoThreadAudit = false + + private var rpcPool: ThreadPoolExecutor = _ + + override def beforeAll(): Unit = { + super.beforeAll() + rpcPool = ThreadUtils.newDaemonCachedThreadPool("test-cm-rpc") + } + + override def afterAll(): Unit = { + if (rpcPool != null) rpcPool.shutdownNow() + super.afterAll() + } + + private def worker(host: String): WorkerInfo = new WorkerInfo(host, 1, 2, 3, 4) + + private def pendingContext(): (LocalNettyRpcCallContext, Promise[Any]) = { + val p = Promise[Any]() + (new LocalNettyRpcCallContext(RpcAddress("localhost", 0), p), p) + } + + private def makeManager( + conf: CelebornConf, + allocatedWorkers: ShuffleAllocatedWorkers): (CommitManager, WorkerStatusTracker) = { + val tracker = new WorkerStatusTracker(conf, null) + val lm = mock(classOf[LifecycleManager]) + + doAnswer(new Answer[Unit] { + override def answer(inv: InvocationOnMock): Unit = + tracker.registerWorkerStatusListener(inv.getArgument[WorkerStatusListener](0)) + }).when(lm).registerWorkerStatusListener(any(classOf[WorkerStatusListener])) + + when(lm.shuffleAllocatedWorkers).thenReturn(allocatedWorkers) + when(lm.getPartitionType(anyInt())).thenReturn(PartitionType.REDUCE) + when(lm.workerStatusTracker).thenReturn(tracker) + when(lm.rpcSharedThreadPool).thenReturn(rpcPool) + + val mgr = new CommitManager("test-app", conf, lm) + mgr.start() + (mgr, tracker) + } + + private def baseConf( + dataLostEnabled: Boolean = true, + replicateEnabled: Boolean = false): CelebornConf = { + val c = new CelebornConf() + c.set(CLIENT_SHUFFLE_DATA_LOST_ON_UNKNOWN_WORKER_ENABLED, dataLostEnabled) + c.set(CLIENT_PUSH_REPLICATE_ENABLED, replicateEnabled) + c.set(CLIENT_BATCH_HANDLE_COMMIT_PARTITION_ENABLED, false) + c + } + + private def unknownHeartbeat(tracker: WorkerStatusTracker, workers: WorkerInfo*): Unit = + tracker.handleHeartbeatResponse(HeartbeatFromApplicationResponse( + StatusCode.SUCCESS, + new util.ArrayList[WorkerInfo](), + new 
util.ArrayList[WorkerInfo](workers.asJava), + new util.ArrayList[WorkerInfo](), + new util.ArrayList[Integer](), + null)) + + private def allocate( + alloc: ShuffleAllocatedWorkers, + shuffleId: Int, + w: WorkerInfo): Unit = { + val m = JavaUtils.newConcurrentHashMap[String, ShufflePartitionLocationInfo]() + m.put(w.toUniqueId, new ShufflePartitionLocationInfo(w)) + alloc.put(shuffleId, m) + } + + test("UnknownWorkerListener replies SHUFFLE_DATA_LOST to pending GetReducerFileGroup when worker goes unknown") { + val w = worker("crashed") + val alloc = new ShuffleAllocatedWorkers() + val shuffleId = 1 + allocate(alloc, shuffleId, w) + + val (mgr, tracker) = makeManager(baseConf(), alloc) + mgr.registerShuffle(shuffleId, 2, false, 4) + + val (ctx, promise) = pendingContext() + mgr.handleGetReducerFileGroup(ctx, shuffleId, SerdeVersion.V1) + assert(!promise.isCompleted, "request must be pending before heartbeat") + + unknownHeartbeat(tracker, w) + + assert(promise.isCompleted) + val resp = Await.result(promise.future, 1.second).asInstanceOf[GetReducerFileGroupResponse] + assert(resp.status == StatusCode.SHUFFLE_DATA_LOST) + } + + test("UnknownWorkerListener is a no-op when replication is enabled") { + val w = worker("crashed") + val alloc = new ShuffleAllocatedWorkers() + val shuffleId = 2 + allocate(alloc, shuffleId, w) + + val (mgr, tracker) = makeManager(baseConf(replicateEnabled = true), alloc) + mgr.registerShuffle(shuffleId, 2, false, 4) + + val (ctx, promise) = pendingContext() + mgr.handleGetReducerFileGroup(ctx, shuffleId, SerdeVersion.V1) + + unknownHeartbeat(tracker, w) + assert(!promise.isCompleted) + } + + test("UnknownWorkerListener is a no-op when feature is disabled") { + val w = worker("crashed") + val alloc = new ShuffleAllocatedWorkers() + val shuffleId = 3 + allocate(alloc, shuffleId, w) + + val (mgr, tracker) = makeManager(baseConf(dataLostEnabled = false), alloc) + mgr.registerShuffle(shuffleId, 2, false, 4) + + val (ctx, promise) = 
pendingContext() + mgr.handleGetReducerFileGroup(ctx, shuffleId, SerdeVersion.V1) + + unknownHeartbeat(tracker, w) + assert(!promise.isCompleted) + } + + test("UnknownWorkerListener is a no-op when the crashed worker holds no shuffle data") { + val dataWorker = worker("healthy") + val crashedWorker = worker("crashed") + val alloc = new ShuffleAllocatedWorkers() + val shuffleId = 4 + allocate(alloc, shuffleId, dataWorker) + + val (mgr, tracker) = makeManager(baseConf(), alloc) + mgr.registerShuffle(shuffleId, 2, false, 4) + + val (ctx, promise) = pendingContext() + mgr.handleGetReducerFileGroup(ctx, shuffleId, SerdeVersion.V1) + + unknownHeartbeat(tracker, crashedWorker) + + assert(!promise.isCompleted) + } + + test("UnknownWorkerListener marks data lost even when stage already ended (post-commit crash)") { + // The write-side commit succeeded before the crash, so stage ended as SUCCESS. + // But committed data on a crashed worker is unreadable — restarted reducer tasks + // must get SHUFFLE_DATA_LOST immediately rather than discovering it mid-read. 
+ val w = worker("crashed-after-commit") + val alloc = new ShuffleAllocatedWorkers() + val shuffleId = 5 + allocate(alloc, shuffleId, w) + + val (mgr, tracker) = makeManager(baseConf(), alloc) + mgr.registerShuffle(shuffleId, 1, false, 2) + + mgr.setStageEnd(shuffleId) + assert(mgr.isStageEnd(shuffleId)) + assert(!mgr.getCommitHandler(shuffleId).isStageDataLost(shuffleId)) + + unknownHeartbeat(tracker, w) + + assert(mgr.getCommitHandler(shuffleId).isStageDataLost(shuffleId)) + + val (ctx, promise) = pendingContext() + mgr.handleGetReducerFileGroup(ctx, shuffleId, SerdeVersion.V1) + assert(promise.isCompleted) + val resp = Await.result(promise.future, 1.second).asInstanceOf[GetReducerFileGroupResponse] + assert(resp.status == StatusCode.SHUFFLE_DATA_LOST) + } + + test("UnknownWorkerListener only fast-fails shuffles whose data is on the crashed worker") { + val crashedWorker = worker("crashed") + val healthyWorker = worker("healthy") + val alloc = new ShuffleAllocatedWorkers() + val affectedId = 10 + val unaffectedId = 11 + allocate(alloc, affectedId, crashedWorker) + allocate(alloc, unaffectedId, healthyWorker) + + val (mgr, tracker) = makeManager(baseConf(), alloc) + mgr.registerShuffle(affectedId, 2, false, 4) + mgr.registerShuffle(unaffectedId, 2, false, 4) + + val (ctx1, p1) = pendingContext() + val (ctx2, p2) = pendingContext() + mgr.handleGetReducerFileGroup(ctx1, affectedId, SerdeVersion.V1) + mgr.handleGetReducerFileGroup(ctx2, unaffectedId, SerdeVersion.V1) + + unknownHeartbeat(tracker, crashedWorker) + + assert(p1.isCompleted) + assert( + Await.result(p1.future, 1.second).asInstanceOf[GetReducerFileGroupResponse].status + == StatusCode.SHUFFLE_DATA_LOST) + assert(!p2.isCompleted) + } +} diff --git a/client/src/test/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandlerSuite.scala b/client/src/test/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandlerSuite.scala new file mode 100644 index 00000000000..a92800c7824 --- /dev/null 
+++ b/client/src/test/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandlerSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.commit + +import java.util.concurrent.{ScheduledExecutorService, ThreadPoolExecutor} + +import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ + +import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.client.WorkerStatusTracker +import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo +import org.apache.celeborn.client.LifecycleManager.ShuffleAllocatedWorkers +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.network.protocol.SerdeVersion +import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse +import org.apache.celeborn.common.protocol.message.StatusCode +import org.apache.celeborn.common.rpc.RpcAddress +import org.apache.celeborn.common.rpc.netty.LocalNettyRpcCallContext +import org.apache.celeborn.common.util.ThreadUtils + +class ReducePartitionCommitHandlerSuite extends CelebornFunSuite { + + // The handler spins up daemon pools; skip the thread audit to avoid flaky leak warnings. 
+ override protected val enableAutoThreadAudit = false + + private var rpcPool: ThreadPoolExecutor = _ + private var commitScheduler: ScheduledExecutorService = _ + + override def beforeAll(): Unit = { + super.beforeAll() + rpcPool = ThreadUtils.newDaemonCachedThreadPool("test-reduce-commit-rpc") + commitScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("test-reduce-commit-scheduler") + } + + override def afterAll(): Unit = { + if (rpcPool != null) rpcPool.shutdownNow() + if (commitScheduler != null) commitScheduler.shutdownNow() + super.afterAll() + } + + private def newHandler(): ReducePartitionCommitHandler = { + val conf = new CelebornConf() + new ReducePartitionCommitHandler( + "test-app", + conf, + new ShuffleAllocatedWorkers(), + new CommittedPartitionInfo(), + new WorkerStatusTracker(conf, null), + rpcPool, + commitScheduler, + null) + } + + private def pendingContext(): (LocalNettyRpcCallContext, Promise[Any]) = { + val p = Promise[Any]() + (new LocalNettyRpcCallContext(RpcAddress("localhost", 0), p), p) + } + + test("markShuffleDataLost replies SHUFFLE_DATA_LOST to GetReducerFileGroup contexts") { + val handler = newHandler() + val shuffleId = 1 + handler.registerShuffle(shuffleId, numMappers = 2, isSegmentGranularityVisible = false, numPartitions = 4) + + val (ctx1, p1) = pendingContext() + handler.handleGetReducerFileGroup(ctx1, shuffleId, SerdeVersion.V1) + assert(!handler.isStageEnd(shuffleId)) + assert(!handler.isStageDataLost(shuffleId)) + + handler.markShuffleDataLost(shuffleId) + + val (ctx2, p2) = pendingContext() + handler.handleGetReducerFileGroup(ctx2, shuffleId, SerdeVersion.V1) + assert(handler.isStageEnd(shuffleId)) + assert(handler.isStageDataLost(shuffleId)) + + Seq(p1, p2).foreach { p => + assert(p.isCompleted) + val resp = Await.result(p.future, 1.second).asInstanceOf[GetReducerFileGroupResponse] + assert(resp.status == StatusCode.SHUFFLE_DATA_LOST) + } + } + + test("markShuffleDataLost marks data lost even when stage 
already ended (worker crash after commit)") { + val handler = newHandler() + val shuffleId = 1 + handler.registerShuffle(shuffleId, numMappers = 1, isSegmentGranularityVisible = false, numPartitions = 2) + + // Clean stage-end + handler.setStageEnd(shuffleId) + assert(handler.isStageEnd(shuffleId)) + assert(!handler.isStageDataLost(shuffleId)) + + // Worker crashes after commit + handler.markShuffleDataLost(shuffleId) + assert(handler.isStageDataLost(shuffleId)) + } +} diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 668fbcf79e4..b6b9a60076b 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -7000,7 +7000,7 @@ object CelebornConf extends Logging { .version("0.6.3") .doc("Whether to mark shuffle data lost when unknown worker is detected.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val ENDPOINT_VERIFIER_SEPARATE_ENABLED: ConfigEntry[Boolean] = buildConf("celeborn.rpc.RpcEndpointVerifier.separate.enabled") From 3094c92a04964bbe340e9d13aa545487b01bb4a6 Mon Sep 17 00:00:00 2001 From: Sanskar Modi Date: Thu, 11 Jun 2026 00:43:33 +0530 Subject: [PATCH 2/3] Fixes --- .../apache/celeborn/client/CommitManager.scala | 3 ++- .../celeborn/client/CommitManagerSuite.scala | 11 ++--------- .../ReducePartitionCommitHandlerSuite.scala | 17 +++++++++++++---- .../apache/celeborn/common/CelebornConf.scala | 6 +++++- docs/configuration/client.md | 2 +- docs/migration.md | 2 ++ 6 files changed, 25 insertions(+), 16 deletions(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala index 03e8b3c3aca..d56aecb927a 100644 --- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala +++ 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala @@ -372,7 +372,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage } private class UnknownWorkerListener extends WorkerStatusListener { - private val shuffleDataLostOnUnknownWorkerEnabled = conf.clientShuffleDataLostOnUnknownWorkerEnabled + private val shuffleDataLostOnUnknownWorkerEnabled = + conf.clientShuffleDataLostOnUnknownWorkerEnabled private val pushReplicateEnabled = conf.clientPushReplicateEnabled override def notifyChangedWorkersStatus(workersStatus: WorkersStatus): Unit = { diff --git a/client/src/test/scala/org/apache/celeborn/client/CommitManagerSuite.scala b/client/src/test/scala/org/apache/celeborn/client/CommitManagerSuite.scala index 3551eceaeea..758ddb4a7d4 100644 --- a/client/src/test/scala/org/apache/celeborn/client/CommitManagerSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/CommitManagerSuite.scala @@ -33,18 +33,11 @@ import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.client.LifecycleManager.ShuffleAllocatedWorkers import org.apache.celeborn.client.listener.WorkerStatusListener import org.apache.celeborn.common.CelebornConf -import org.apache.celeborn.common.CelebornConf.{ - CLIENT_BATCH_HANDLE_COMMIT_PARTITION_ENABLED, - CLIENT_PUSH_REPLICATE_ENABLED, - CLIENT_SHUFFLE_DATA_LOST_ON_UNKNOWN_WORKER_ENABLED -} +import org.apache.celeborn.common.CelebornConf.{CLIENT_BATCH_HANDLE_COMMIT_PARTITION_ENABLED, CLIENT_PUSH_REPLICATE_ENABLED, CLIENT_SHUFFLE_DATA_LOST_ON_UNKNOWN_WORKER_ENABLED} import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo} import org.apache.celeborn.common.network.protocol.SerdeVersion import org.apache.celeborn.common.protocol.PartitionType -import org.apache.celeborn.common.protocol.message.ControlMessages.{ - GetReducerFileGroupResponse, - HeartbeatFromApplicationResponse -} +import 
org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, HeartbeatFromApplicationResponse} import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.rpc.RpcAddress import org.apache.celeborn.common.rpc.netty.LocalNettyRpcCallContext diff --git a/client/src/test/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandlerSuite.scala b/client/src/test/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandlerSuite.scala index a92800c7824..4e7600cd930 100644 --- a/client/src/test/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandlerSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandlerSuite.scala @@ -23,9 +23,9 @@ import scala.concurrent.{Await, Promise} import scala.concurrent.duration._ import org.apache.celeborn.CelebornFunSuite -import org.apache.celeborn.client.WorkerStatusTracker import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo import org.apache.celeborn.client.LifecycleManager.ShuffleAllocatedWorkers +import org.apache.celeborn.client.WorkerStatusTracker import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.network.protocol.SerdeVersion import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse @@ -76,7 +76,11 @@ class ReducePartitionCommitHandlerSuite extends CelebornFunSuite { test("markShuffleDataLost replies SHUFFLE_DATA_LOST to GetReducerFileGroup contexts") { val handler = newHandler() val shuffleId = 1 - handler.registerShuffle(shuffleId, numMappers = 2, isSegmentGranularityVisible = false, numPartitions = 4) + handler.registerShuffle( + shuffleId, + numMappers = 2, + isSegmentGranularityVisible = false, + numPartitions = 4) val (ctx1, p1) = pendingContext() handler.handleGetReducerFileGroup(ctx1, shuffleId, SerdeVersion.V1) @@ -97,10 +101,15 @@ class ReducePartitionCommitHandlerSuite extends CelebornFunSuite { } } - 
test("markShuffleDataLost marks data lost even when stage already ended (worker crash after commit)") { + test( + "markShuffleDataLost marks data lost even when stage already ended (worker crash after commit)") { val handler = newHandler() val shuffleId = 1 - handler.registerShuffle(shuffleId, numMappers = 1, isSegmentGranularityVisible = false, numPartitions = 2) + handler.registerShuffle( + shuffleId, + numMappers = 1, + isSegmentGranularityVisible = false, + numPartitions = 2) // Clean stage-end handler.setStageEnd(shuffleId) diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index b6b9a60076b..3f6ebed0310 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -6998,7 +6998,11 @@ object CelebornConf extends Logging { buildConf("celeborn.client.shuffleDataLostOnUnknownWorker.enabled") .categories("client") .version("0.6.3") - .doc("Whether to mark shuffle data lost when unknown worker is detected.") + .doc("When enabled, any shuffle that had partitions on the (crashed) " + + "unknown worker is immediately marked as data lost. " + + "On the write flow revive/commit request for that shuffle will fast fail. " + + "GetReducerFileGroup requests are replied with SHUFFLE_DATA_LOST. " + + s"This has no effect when ${CLIENT_PUSH_REPLICATE_ENABLED.key}=true") .booleanConf .createWithDefault(true) diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 5d15c6859fb..b57e803fde0 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -122,7 +122,7 @@ license: | | celeborn.client.shuffle.rangeReadFilter.enabled | false | false | If a spark application have skewed partition, this value can set to true to improve performance. 
| 0.2.0 | celeborn.shuffle.rangeReadFilter.enabled | | celeborn.client.shuffle.register.filterExcludedWorker.enabled | false | false | Whether to filter excluded worker when register shuffle. | 0.4.0 | | | celeborn.client.shuffle.reviseLostShuffles.enabled | false | false | Whether to revise lost shuffles. | 0.6.0 | | -| celeborn.client.shuffleDataLostOnUnknownWorker.enabled | false | false | Whether to mark shuffle data lost when unknown worker is detected. | 0.6.3 | | +| celeborn.client.shuffleDataLostOnUnknownWorker.enabled | true | false | When enabled, any shuffle that had partitions on the (crashed) unknown worker is immediately marked as data lost. On the write flow revive/commit request for that shuffle will fast fail. GetReducerFileGroup requests are replied with SHUFFLE_DATA_LOST. This has no effect when celeborn.client.push.replicate.enabled=true | 0.6.3 | | | celeborn.client.slot.assign.maxWorkers | 10000 | false | Max workers that slots of one shuffle can be allocated on. Will choose the smaller positive one from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`. | 0.3.1 | | | celeborn.client.spark.batch.openStream.parallelClientCreation.enabled | true | false | Whether to create data clients in parallel before sending Spark batch open-stream requests. When false, data clients are created serially. | 0.6.3 | | | celeborn.client.spark.fetch.cleanFailedShuffle | false | false | whether to clean those disk space occupied by shuffles which cannot be fetched | 0.6.0 | | diff --git a/docs/migration.md b/docs/migration.md index 8b3625f4463..4581001fadd 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -37,6 +37,8 @@ license: | - Since 0.7.0, Celeborn changed the default value of `celeborn.port.maxRetries` from `1` to `16`. 
+- Since 0.7.0, Celeborn changed the default value of `celeborn.client.shuffleDataLostOnUnknownWorker.enabled` from `false` to `true`, which means Celeborn will mark shuffle data as lost when an unknown worker is detected by default. + # Upgrading from 0.5 to 0.6 - Since 0.6.0, Celeborn deprecate `celeborn.client.spark.fetch.throwsFetchFailure`. Please use `celeborn.client.spark.stageRerun.enabled` instead. From ac8d5729d5869923e145e0068dee01bc706a3a4a Mon Sep 17 00:00:00 2001 From: Sanskar Modi Date: Thu, 11 Jun 2026 00:54:36 +0530 Subject: [PATCH 3/3] stage end --- .../celeborn/client/commit/ReducePartitionCommitHandler.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index 1461f23ec01..6d9cd9b6904 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -151,7 +151,9 @@ class ReducePartitionCommitHandler( override def markShuffleDataLost(shuffleId: Int): Unit = { logWarning(s"Marking shuffle $shuffleId data as lost due to unknown/crashed worker.") dataLostShuffleSet.add(shuffleId) - setStageEnd(shuffleId) // unblocks all pending GetReducerFileGroup waiters immediately + if (!isStageEnd(shuffleId)) { + setStageEnd(shuffleId) + } } override def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean = {