From c0b60dd30b715979d28f163cfc6989801ea38be6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E7=A6=8F?= Date: Wed, 22 Apr 2026 13:01:08 +0800 Subject: [PATCH 1/2] [CELEBORN] Optimize RegisterShuffle for large partition counts 1. Replace partitionIdList (ArrayList) transmission with a single numPartitions integer via new PbRequestSlotsV2 message type, eliminating ~10MB protobuf payload for 2M-partition shuffles. Old PbRequestSlots is preserved for backward compatibility. 2. Optimize SlotsAllocator.roundRobin(): - Pre-compute per-worker usable slots into long[] arrays, replacing O(N*W) haveUsableSlots() stream calls with O(1) array lookups. - Replace LinkedList iterator + remove with index-based traversal, eliminating O(N^2) element shifting overhead that dominated CPU (90% in flame graph for 2M partitions). --- .../client/ChangePartitionManager.scala | 4 +- .../celeborn/client/LifecycleManager.scala | 8 +- common/src/main/proto/TransportMessages.proto | 18 +++ .../protocol/message/ControlMessages.scala | 32 ++++- .../SlotsAllocatorBenchmark-jdk17-results.txt | 55 +++++++++ .../service/deploy/master/SlotsAllocator.java | 52 ++++++-- .../service/deploy/master/Master.scala | 13 +- .../service/deploy/master/MasterSuite.scala | 2 +- .../master/SlotsAllocatorBenchmark.scala | 111 ++++++++++++++++++ ...gePartitionManagerUpdateWorkersSuite.scala | 31 +---- .../LifecycleManagerCommitFilesSuite.scala | 24 +--- .../LifecycleManagerDestroySlotsSuite.scala | 18 +-- .../LifecycleManagerSetupEndpointSuite.scala | 12 +- .../tests/client/LifecycleManagerSuite.scala | 27 +---- ...fecycleManagerUnregisterShuffleSuite.scala | 12 +- .../tests/client/ShuffleClientSuite.scala | 7 +- 16 files changed, 286 insertions(+), 140 deletions(-) create mode 100644 master/benchmarks/SlotsAllocatorBenchmark-jdk17-results.txt create mode 100644 master/src/test/scala/org/apache/celeborn/service/deploy/master/SlotsAllocatorBenchmark.scala diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala index 71096a952fb..93304fd6ada 100644 --- a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala @@ -305,11 +305,9 @@ class ChangePartitionManager( || (unavailableWorkerRatio >= dynamicResourceUnavailableFactor)) { // get new available workers for the request partition ids - val partitionIds = new util.ArrayList[Integer]( - changePartitions.map(_.partitionId).map(Integer.valueOf).toList.asJava) // The partition id value is not important here because we're just trying to get the workers to use val requestSlotsRes = - lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, partitionIds) + lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, changePartitions.size) requestSlotsRes.status match { case StatusCode.REQUEST_FAILED => diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index a37513a236f..8a5bb6259f8 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -776,9 +776,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } // First, request to get allocated slots from Primary - val ids = new util.ArrayList[Integer](numPartitions) - (0 until numPartitions).foreach(idx => ids.add(Integer.valueOf(idx))) - val res = requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = requestMasterRequestSlotsWithRetry(shuffleId, numPartitions) res.status match { case StatusCode.REQUEST_FAILED => @@ -1832,7 +1830,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends def requestMasterRequestSlotsWithRetry( shuffleId: 
Int, - ids: util.ArrayList[Integer]): RequestSlotsResponse = { + numPartitions: Int): RequestSlotsResponse = { val excludedWorkerSet = if (excludedWorkersFilter) { workerStatusTracker.excludedWorkers.asScala.keys.toSet @@ -1845,7 +1843,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends RequestSlots( appUniqueId, shuffleId, - ids, + numPartitions, lifecycleHost, pushReplicateEnabled, pushRackAwareEnabled, diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index a813a9e5015..63c1d26853e 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -117,6 +117,8 @@ enum MessageType { READ_REDUCER_PARTITION_END = 94; READ_REDUCER_PARTITION_END_RESPONSE = 95; REGISTER_APPLICATION_INFO = 96; + + REQUEST_SLOTS_V2 = 97; } enum StreamType { @@ -325,6 +327,22 @@ message PbRequestSlots { string tagsExpr = 14; } +message PbRequestSlotsV2 { + string applicationId = 1; + int32 shuffleId = 2; + int32 numPartitions = 3; + string hostname = 4; + bool shouldReplicate = 5; + string requestId = 6; + PbUserIdentifier userIdentifier = 7; + bool shouldRackAware = 8; + int32 maxWorkers = 9; + int32 availableStorageTypes = 10; + repeated PbWorkerInfo excludedWorkerSet = 11; + bool packed = 12; + string tagsExpr = 13; +} + message PbSlotInfo { map slot = 1; } diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 36f164d697e..ebe400222bc 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -164,7 +164,7 @@ object ControlMessages extends Logging { case class RequestSlots( applicationId: String, shuffleId: Int, - partitionIdList: util.ArrayList[Integer], + numPartitions: 
Int, hostname: String, shouldReplicate: Boolean, shouldRackAware: Boolean, @@ -650,7 +650,7 @@ object ControlMessages extends Logging { case RequestSlots( applicationId, shuffleId, - partitionIdList, + numPartitions, hostname, shouldReplicate, shouldRackAware, @@ -661,10 +661,10 @@ object ControlMessages extends Logging { packed, tagsExpr, requestId) => - val payload = PbRequestSlots.newBuilder() + val payload = PbRequestSlotsV2.newBuilder() .setApplicationId(applicationId) .setShuffleId(shuffleId) - .addAllPartitionIdList(partitionIdList) + .setNumPartitions(numPartitions) .setHostname(hostname) .setShouldReplicate(shouldReplicate) .setShouldRackAware(shouldRackAware) @@ -677,7 +677,7 @@ object ControlMessages extends Logging { .setPacked(packed) .setTagsExpr(tagsExpr) .build().toByteArray - new TransportMessage(MessageType.REQUEST_SLOTS, payload) + new TransportMessage(MessageType.REQUEST_SLOTS_V2, payload) case RequestSlotsResponse(status, workerResource, packed) => val builder = PbRequestSlotsResponse.newBuilder() @@ -1151,7 +1151,7 @@ object ControlMessages extends Logging { RequestSlots( pbRequestSlots.getApplicationId, pbRequestSlots.getShuffleId, - new util.ArrayList[Integer](pbRequestSlots.getPartitionIdListList), + pbRequestSlots.getPartitionIdListList.size(), pbRequestSlots.getHostname, pbRequestSlots.getShouldReplicate, pbRequestSlots.getShouldRackAware, @@ -1163,6 +1163,26 @@ object ControlMessages extends Logging { pbRequestSlots.getTagsExpr, pbRequestSlots.getRequestId) + case REQUEST_SLOTS_V2_VALUE => + val pb = PbRequestSlotsV2.parseFrom(message.getPayload) + val userIdentifier = PbSerDeUtils.fromPbUserIdentifier(pb.getUserIdentifier) + val excludedWorkerInfoSet = + pb.getExcludedWorkerSetList.asScala.map(PbSerDeUtils.fromPbWorkerInfo).toSet + RequestSlots( + pb.getApplicationId, + pb.getShuffleId, + pb.getNumPartitions, + pb.getHostname, + pb.getShouldReplicate, + pb.getShouldRackAware, + userIdentifier, + pb.getMaxWorkers, + 
pb.getAvailableStorageTypes, + excludedWorkerInfoSet, + pb.getPacked, + pb.getTagsExpr, + pb.getRequestId) + case REQUEST_SLOTS_RESPONSE_VALUE => val pbRequestSlotsResponse = PbRequestSlotsResponse.parseFrom(message.getPayload) val workerResource = diff --git a/master/benchmarks/SlotsAllocatorBenchmark-jdk17-results.txt b/master/benchmarks/SlotsAllocatorBenchmark-jdk17-results.txt new file mode 100644 index 00000000000..7c1f9cbad04 --- /dev/null +++ b/master/benchmarks/SlotsAllocatorBenchmark-jdk17-results.txt @@ -0,0 +1,55 @@ +================================================================================================ +200 workers, 10K partitions, no replication +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.17+10 on Mac OS X 15.4 +Apple M2 Pro +200 workers, 10K partitions, no replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------- +offerSlotsRoundRobin 1 1 0 15.6 64.3 1.0X + + +================================================================================================ +200 workers, 100K partitions, no replication +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.17+10 on Mac OS X 15.4 +Apple M2 Pro +200 workers, 100K partitions, no replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +---------------------------------------------------------------------------------------------------------------------------- +offerSlotsRoundRobin 6 7 0 15.8 63.4 1.0X + + +================================================================================================ +500 workers, 100K partitions, with replication +================================================================================================ + +OpenJDK 64-Bit Server 
VM 17.0.17+10 on Mac OS X 15.4 +Apple M2 Pro +500 workers, 100K partitions, with replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------ +offerSlotsRoundRobin 12 15 2 8.0 124.5 1.0X + + +================================================================================================ +500 workers, 2M partitions, no replication +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.17+10 on Mac OS X 15.4 +Apple M2 Pro +500 workers, 2M partitions, no replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------- +offerSlotsRoundRobin 252 351 102 7.9 126.1 1.0X + + +================================================================================================ +1000 workers, 500K partitions, with replication +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.17+10 on Mac OS X 15.4 +Apple M2 Pro +1000 workers, 500K partitions, with replication: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------- +offerSlotsRoundRobin 77 159 46 6.5 154.7 1.0X + + diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java index 5580cd341a4..432af5d795a 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java @@ -517,7 +517,6 @@ private static List 
roundRobin( } // workerInfo -> (diskIndexForPrimaryAndReplica) Map workerDiskIndex = new HashMap<>(); - List partitionIdList = new LinkedList<>(partitionIds); final int primaryWorkersSize = primaryWorkers.size(); final int replicaWorkersSize = replicaWorkers.size(); @@ -533,19 +532,27 @@ private static List roundRobin( replicaIndex = -1; } - ListIterator iter = partitionIdList.listIterator(partitionIdList.size()); - // Iterate from the end to preserve O(1) removal of processed partitions. - // This is important when we have a high number of concurrent apps that have a - // high number of partitions. + // Pre-compute usable slots per worker to avoid repeated stream operations O(N*W) -> O(W) + long[] primaryUsableSlots = null; + long[] replicaUsableSlots = null; + if (slotsRestrictions != null && !slotsRestrictions.isEmpty()) { + primaryUsableSlots = computeUsableSlots(primaryWorkers, slotsRestrictions); + if (shouldReplicate) { + replicaUsableSlots = computeUsableSlots(replicaWorkers, slotsRestrictions); + } + } + + // Iterate by index and count allocations instead of copying into a LinkedList.
+ int allocatedCount = 0; outer: - while (iter.hasPrevious()) { + for (int pidIdx = 0; pidIdx < partitionIds.size(); pidIdx++) { int nextPrimaryInd = primaryIndex; - int partitionId = iter.previous(); + int partitionId = partitionIds.get(pidIdx); StorageInfo storageInfo; - if (slotsRestrictions != null && !slotsRestrictions.isEmpty()) { + if (primaryUsableSlots != null) { // this means that we'll select a mount point - while (!haveUsableSlots(slotsRestrictions, primaryWorkers, nextPrimaryInd)) { + while (primaryUsableSlots[nextPrimaryInd] <= 0) { nextPrimaryInd = primaryWorkersIncrementIndex.applyAsInt(nextPrimaryInd); if (nextPrimaryInd == primaryIndex) { break outer; @@ -558,6 +565,7 @@ private static List roundRobin( slotsRestrictions, workerDiskIndex, availableStorageTypes); + primaryUsableSlots[nextPrimaryInd]--; } else { if (StorageInfo.localDiskAvailable(availableStorageTypes)) { while (!primaryWorkers.get(nextPrimaryInd).haveDisk()) { @@ -576,9 +584,9 @@ private static List roundRobin( if (shouldReplicate) { int nextReplicaInd = replicaIndex; - if (slotsRestrictions != null) { + if (replicaUsableSlots != null) { while ((nextReplicaInd == nextPrimaryInd && skipLocationsOnSameWorkerCheck) - || !haveUsableSlots(slotsRestrictions, replicaWorkers, nextReplicaInd) + || replicaUsableSlots[nextReplicaInd] <= 0 || !satisfyRackAware( shouldRackAware, primaryWorkers, @@ -597,6 +605,7 @@ private static List roundRobin( slotsRestrictions, workerDiskIndex, availableStorageTypes); + replicaUsableSlots[nextReplicaInd]--; } else if (shouldRackAware) { while ((nextReplicaInd == nextPrimaryInd && skipLocationsOnSameWorkerCheck) || !satisfyRackAware( @@ -642,9 +651,26 @@ private static List roundRobin( v -> new Tuple2<>(new ArrayList<>(), new ArrayList<>())); locations._1.add(primaryPartition); primaryIndex = primaryWorkersIncrementIndex.applyAsInt(nextPrimaryInd); - iter.remove(); + allocatedCount++; + } + if (allocatedCount == partitionIds.size()) { + return 
Collections.emptyList(); } - return partitionIdList; + return new ArrayList<>(partitionIds.subList(allocatedCount, partitionIds.size())); + } + + private static long[] computeUsableSlots( + List<WorkerInfo> workers, Map<WorkerInfo, List<UsableDiskInfo>> restrictions) { + long[] slots = new long[workers.size()]; + for (int i = 0; i < workers.size(); i++) { + List<UsableDiskInfo> disks = restrictions.get(workers.get(i)); + if (disks != null) { + for (UsableDiskInfo d : disks) { + slots[i] += d.usableSlots; + } + } + } + return slots; } private static boolean haveUsableSlots( diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala index 2ff7cfadeeb..1a0318b30c1 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala @@ -930,7 +930,7 @@ private[celeborn] class Master( } def handleRequestSlots(context: RpcCallContext, requestSlots: RequestSlots): Unit = { - val numReducers = requestSlots.partitionIdList.size() + val numReducers = requestSlots.numPartitions val shuffleKey = Utils.makeShuffleKey(requestSlots.applicationId, requestSlots.shuffleId) var availableWorkers = workersAvailable(requestSlots.excludedWorkerSet) @@ -966,6 +966,13 @@ private[celeborn] class Master( 0, startIndex + numWorkers - numAvailableWorkers)) } + // Build partitionIds list locally from numPartitions + val partitionIds = new util.ArrayList[Integer](numReducers) + var i = 0 + while (i < numReducers) { + partitionIds.add(Integer.valueOf(i)) + i += 1 + } // offer slots val slots = masterSource.sample(MasterSource.OFFER_SLOTS_TIME, s"offerSlots-${Random.nextInt()}") { @@ -973,7 +980,7 @@ private[celeborn] class Master( if (slotsAssignPolicy == SlotsAssignPolicy.LOADAWARE) { SlotsAllocator.offerSlotsLoadAware( selectedWorkers, - requestSlots.partitionIdList, + partitionIds, requestSlots.shouldReplicate, requestSlots.shouldRackAware, 
slotsAssignLoadAwareDiskGroupNum, @@ -986,7 +993,7 @@ private[celeborn] class Master( } else { SlotsAllocator.offerSlotsRoundRobin( selectedWorkers, - requestSlots.partitionIdList, + partitionIds, requestSlots.shouldReplicate, requestSlots.shouldRackAware, requestSlots.availableStorageTypes, diff --git a/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterSuite.scala b/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterSuite.scala index 0a8a7b592f6..8524701524a 100644 --- a/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterSuite.scala +++ b/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterSuite.scala @@ -177,7 +177,7 @@ class MasterSuite extends AnyFunSuite val requestSlots = RequestSlots( "app1", 0, - new util.ArrayList[Integer](), + 0, "localhost", shouldReplicate = false, shouldRackAware = false, diff --git a/master/src/test/scala/org/apache/celeborn/service/deploy/master/SlotsAllocatorBenchmark.scala b/master/src/test/scala/org/apache/celeborn/service/deploy/master/SlotsAllocatorBenchmark.scala new file mode 100644 index 00000000000..a02d43c1f7f --- /dev/null +++ b/master/src/test/scala/org/apache/celeborn/service/deploy/master/SlotsAllocatorBenchmark.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.master + +import java.util +import java.util.{HashMap => JHashMap, Random} +import java.util.stream.{Collectors, IntStream} + +import org.apache.celeborn.benchmark.{Benchmark, BenchmarkBase} +import org.apache.celeborn.common.meta.WorkerInfo +import org.apache.celeborn.common.protocol.StorageInfo + +/** + * SlotsAllocator roundRobin benchmark. + * + * To run this benchmark: + * {{{ + * 1. build/sbt "celeborn-master/test:runMain + * org.apache.celeborn.service.deploy.master.SlotsAllocatorBenchmark" + * 2. generate result: + * CELEBORN_GENERATE_BENCHMARK_FILES=1 build/sbt "celeborn-master/test:runMain + * org.apache.celeborn.service.deploy.master.SlotsAllocatorBenchmark" + * Results will be written to "benchmarks/SlotsAllocatorBenchmark-results.txt". + * }}} + */ +object SlotsAllocatorBenchmark extends BenchmarkBase { + + private val PARTITION_SIZE = 64 * 1024 * 1024L + private val DISK_PATH = "/mnt/disk" + private val DISK_SPACE = 1024L * 1024 * 1024 * 1024 + private val NUM_NETWORK_LOCATIONS = 20 + private val random = new Random(42) + + private def prepareWorkers(numWorkers: Int): util.List[WorkerInfo] = { + val diskPartitionToSize = new JHashMap[String, java.lang.Long]() + diskPartitionToSize.put(DISK_PATH, java.lang.Long.valueOf(DISK_SPACE)) + SlotsAllocatorSuiteJ.basePrepareWorkers( + numWorkers, + true, + diskPartitionToSize, + PARTITION_SIZE, + NUM_NETWORK_LOCATIONS, + false, + random) + } + + private def preparePartitionIds(numPartitions: Int): util.List[Integer] = { + java.util.Collections.unmodifiableList( + IntStream.range(0, numPartitions).boxed().collect(Collectors.toList())) + } + + private def benchmarkRoundRobin( + name: String, + numWorkers: Int, + numPartitions: Int, + shouldReplicate: Boolean): Unit = { + runBenchmark(name) { + val benchmark = new Benchmark(name, numPartitions, output = 
output) + + benchmark.addTimerCase("offerSlotsRoundRobin") { timer => + val workers = prepareWorkers(numWorkers) + val partitionIds = preparePartitionIds(numPartitions) + timer.startTiming() + SlotsAllocator.offerSlotsRoundRobin( + workers, + partitionIds, + shouldReplicate, + false, + StorageInfo.ALL_TYPES_AVAILABLE_MASK, + false, + 0) + timer.stopTiming() + } + + benchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + benchmarkRoundRobin( + "200 workers, 10K partitions, no replication", + 200, 10000, shouldReplicate = false) + benchmarkRoundRobin( + "200 workers, 100K partitions, no replication", + 200, 100000, shouldReplicate = false) + benchmarkRoundRobin( + "500 workers, 100K partitions, with replication", + 500, 100000, shouldReplicate = true) + benchmarkRoundRobin( + "500 workers, 2M partitions, no replication", + 500, 2000000, shouldReplicate = false) + benchmarkRoundRobin( + "1000 workers, 500K partitions, with replication", + 1000, 500000, shouldReplicate = true) + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala index 3c0303f4064..b14f9d46da8 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala @@ -61,11 +61,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) val changePartitionManager: ChangePartitionManager = new ChangePartitionManager(conf, lifecycleManager) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = 
lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) assert(res.status == StatusCode.SUCCESS) assert(res.workerResource.keySet().size() == 1) @@ -135,11 +131,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) val changePartitionManager: ChangePartitionManager = new ChangePartitionManager(conf, lifecycleManager) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) assert(res.status == StatusCode.SUCCESS) val workerNum = res.workerResource.keySet().size() assert(workerNum == 2) @@ -241,11 +233,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) val changePartitionManager: ChangePartitionManager = new ChangePartitionManager(conf, lifecycleManager) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) assert(res.status == StatusCode.SUCCESS) // workerNum is 1 @@ -312,11 +300,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) val changePartitionManager: ChangePartitionManager = new ChangePartitionManager(conf, lifecycleManager) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) assert(res.status == StatusCode.SUCCESS) // workerNum is 1 @@ -388,12 +372,7 @@ class 
ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite val changePartitionManager: ChangePartitionManager = new ChangePartitionManager(conf, lifecycleManager) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) lifecycleManager.setupEndpoints( res.workerResource.keySet(), diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala index fa505e29ff3..f79b0577250 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala @@ -51,11 +51,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC val conf = celebornConf.clone conf.set(CelebornConf.TEST_MOCK_COMMIT_FILES_FAILURE.key, "false") val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) assert(res.status == StatusCode.SUCCESS) assert(res.workerResource.keySet().size() == 3) @@ -108,11 +104,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC val conf = celebornConf.clone conf.set(CelebornConf.TEST_MOCK_COMMIT_FILES_FAILURE.key, "true") val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val 
res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) assert(res.status == StatusCode.SUCCESS) assert(res.workerResource.keySet().size() == 3) @@ -180,11 +172,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC val conf = celebornConf.clone conf.set(CelebornConf.TEST_MOCK_COMMIT_FILES_FAILURE.key, "true") val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) - val ids = new util.ArrayList[Integer](1000) - 0 until 1000 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 1000) assert(res.status == StatusCode.SUCCESS) lifecycleManager.setupEndpoints( @@ -240,11 +228,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC val shuffleClient = new ShuffleClientImpl(APP, conf, userIdentifier) shuffleClient.setupLifecycleManagerRef(lifecycleManager.self) - val ids = new util.ArrayList[Integer](3) - 0 until 3 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 3) assert(res.status == StatusCode.SUCCESS) assert(res.workerResource.keySet().size() == 3) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala index e8326820209..b4755dd4b93 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerDestroySlotsSuite.scala @@ -48,11 +48,7 @@ class LifecycleManagerDestroySlotsSuite extends WithShuffleClientSuite with Mini val conf = celebornConf.clone conf.set(CelebornConf.TEST_CLIENT_MOCK_DESTROY_SLOTS_FAILURE.key, "false") 
val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) assert(res.status == StatusCode.SUCCESS) assert(res.workerResource.keySet().size() == 3) @@ -90,11 +86,7 @@ class LifecycleManagerDestroySlotsSuite extends WithShuffleClientSuite with Mini conf.set(CelebornConf.TEST_CLIENT_MOCK_DESTROY_SLOTS_FAILURE.key, "true") .set(CelebornConf.CLIENT_RPC_MAX_RETIRES.key, "5") val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) assert(res.status == StatusCode.SUCCESS) assert(res.workerResource.keySet().size() == 3) @@ -132,11 +124,7 @@ class LifecycleManagerDestroySlotsSuite extends WithShuffleClientSuite with Mini val conf = celebornConf.clone conf.set(CelebornConf.TEST_CLIENT_MOCK_DESTROY_SLOTS_FAILURE.key, "false") val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) - val ids = new util.ArrayList[Integer](10) - 0 until 10 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 10) assert(res.status == StatusCode.SUCCESS) assert(res.workerResource.keySet().size() == 3) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala index 9ad455756d4..8bb1760a68e 100644 --- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSetupEndpointSuite.scala @@ -17,8 +17,6 @@ package org.apache.celeborn.tests.client -import java.util - import scala.collection.JavaConverters._ import org.apache.celeborn.client.{LifecycleManager, WithShuffleClientSuite} @@ -45,9 +43,7 @@ class LifecycleManagerSetupEndpointSuite extends WithShuffleClientSuite with Min test("test setup endpoints with all workers good") { val lifecycleManager: LifecycleManager = new LifecycleManager(APP, celebornConf) - val ids = new util.ArrayList[Integer](100) - 0 until 100 foreach { ids.add(_) } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(0, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(0, 100) assert(res.status == StatusCode.SUCCESS) assert(res.workerResource.keySet().size() == 3) @@ -60,11 +56,7 @@ class LifecycleManagerSetupEndpointSuite extends WithShuffleClientSuite with Min test("test setup endpoints with one worker down") { val lifecycleManager: LifecycleManager = new LifecycleManager(APP, celebornConf) - val ids = new util.ArrayList[Integer](100) - 0 until 100 foreach { - ids.add(_) - } - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(0, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(0, 100) assert(res.status == StatusCode.SUCCESS) assert(res.workerResource.keySet().size() == 3) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala index d34c79419b8..04426f6f0fe 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerSuite.scala @@ -17,8 +17,6 @@ package 
org.apache.celeborn.tests.client -import java.util - import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.{interval, timeout} import org.scalatest.time.SpanSugar.convertIntToGrainOfTime @@ -52,14 +50,9 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu celebornConf.set(CelebornConf.REGISTER_SHUFFLE_FILTER_EXCLUDED_WORKER_ENABLED, true) val lifecycleManager: LifecycleManager = new LifecycleManager(APP, celebornConf) - val arrayList = new util.ArrayList[Integer]() - (0 to 10).foreach(i => { - arrayList.add(i) - }) - // test request slots without worker excluded val headWorkerInfo = workerInfos.keySet.head.workerInfo - val res1 = lifecycleManager.requestMasterRequestSlotsWithRetry(0, arrayList) + val res1 = lifecycleManager.requestMasterRequestSlotsWithRetry(0, 11) .workerResource.keySet() assert(res1.contains(headWorkerInfo)) @@ -69,7 +62,7 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu workerInfos.keySet.head.workerInfo, (StatusCode.PUSH_DATA_TIMEOUT_PRIMARY, System.currentTimeMillis())) lifecycleManager.workerStatusTracker.recordWorkerFailure(commitFilesFailedWorkers) - val res2 = lifecycleManager.requestMasterRequestSlotsWithRetry(1, arrayList) + val res2 = lifecycleManager.requestMasterRequestSlotsWithRetry(1, 11) .workerResource.keySet() assert(!res2.contains(headWorkerInfo)) @@ -79,7 +72,7 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu worker.workerInfo, (StatusCode.PUSH_DATA_TIMEOUT_PRIMARY, System.currentTimeMillis()))) lifecycleManager.workerStatusTracker.recordWorkerFailure(commitFilesFailedWorkers) - val status = lifecycleManager.requestMasterRequestSlotsWithRetry(2, arrayList).status + val status = lifecycleManager.requestMasterRequestSlotsWithRetry(2, 11).status assert(status == StatusCode.WORKER_EXCLUDED) lifecycleManager.stop() @@ -89,11 +82,6 @@ class LifecycleManagerSuite extends WithShuffleClientSuite 
with MiniClusterFeatu celebornConf.set(CelebornConf.REGISTER_SHUFFLE_FILTER_EXCLUDED_WORKER_ENABLED, false) val lifecycleManager: LifecycleManager = new LifecycleManager(APP, celebornConf) - val arrayList = new util.ArrayList[Integer]() - (0 to 10).foreach(i => { - arrayList.add(i) - }) - // test request slots with all workers excluded, response should not excluded any worker val commitFilesFailedWorkers = new LifecycleManager.ShuffleFailedWorkers() workerInfos.keySet.foreach(worker => @@ -101,7 +89,7 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu worker.workerInfo, (StatusCode.PUSH_DATA_TIMEOUT_PRIMARY, System.currentTimeMillis()))) lifecycleManager.workerStatusTracker.recordWorkerFailure(commitFilesFailedWorkers) - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(0, arrayList) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(0, 11) .workerResource.keySet() assert(res.size() == workerInfos.size) assert(res.contains(workerInfos.keySet.head.workerInfo)) @@ -111,12 +99,7 @@ class LifecycleManagerSuite extends WithShuffleClientSuite with MiniClusterFeatu test("CELEBORN-1258: Support to register application info with user identifier and extra info") { val lifecycleManager: LifecycleManager = new LifecycleManager(APP, celebornConf) - val arrayList = new util.ArrayList[Integer]() - (0 to 10).foreach(i => { - arrayList.add(i) - }) - - lifecycleManager.requestMasterRequestSlotsWithRetry(0, arrayList) + lifecycleManager.requestMasterRequestSlotsWithRetry(0, 11) eventually(timeout(3.seconds), interval(0.milliseconds)) { val appInfo = masterInfo._1.statusSystem.applicationInfos.get(APP) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala index d689ead7e0c..347a3475c6d 100644 --- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala @@ -17,10 +17,6 @@ package org.apache.celeborn.tests.client -import java.util - -import scala.collection.JavaConverters._ - import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.{interval, timeout} import org.scalatest.time.SpanSugar.convertIntToGrainOfTime @@ -50,12 +46,10 @@ class LifecycleManagerUnregisterShuffleSuite extends WithShuffleClientSuite conf.set(CelebornConf.CLIENT_BATCH_REMOVE_EXPIRED_SHUFFLE_ENABLED.key, "true") val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) val counts = 10 - val ids = - new util.ArrayList[Integer]((0 until counts).toList.map(x => Integer.valueOf(x)).asJava) val shuffleIds = (1 to counts).toList shuffleIds.foreach { shuffleId: Int => - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, counts) assert(res.status == StatusCode.SUCCESS) lifecycleManager.registeredShuffle.add(shuffleId) assert(lifecycleManager.registeredShuffle.contains(shuffleId)) @@ -84,12 +78,10 @@ class LifecycleManagerUnregisterShuffleSuite extends WithShuffleClientSuite val conf = celebornConf.clone val lifecycleManager: LifecycleManager = new LifecycleManager(APP, conf) val counts = 10 - val ids = - new util.ArrayList[Integer]((0 until counts).toList.map(x => Integer.valueOf(x)).asJava) val shuffleIds = (1 to counts).toList shuffleIds.foreach { shuffleId: Int => - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, counts) assert(res.status == StatusCode.SUCCESS) lifecycleManager.registeredShuffle.add(shuffleId) assert(lifecycleManager.registeredShuffle.contains(shuffleId)) diff 
--git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala index 88eb52ef56e..ef8ab9bc7ec 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala @@ -18,9 +18,6 @@ package org.apache.celeborn.tests.client import java.io.IOException -import java.util - -import scala.collection.JavaConverters._ import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl, WithShuffleClientSuite} import org.apache.celeborn.common.CelebornConf @@ -70,9 +67,7 @@ class ShuffleClientSuite extends WithShuffleClientSuite with MiniClusterFeature prepareService() val shuffleId = 0 val counts = 10 - val ids = - new util.ArrayList[Integer]((0 until counts).toList.map(x => Integer.valueOf(x)).asJava) - val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, ids) + val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, counts) assert(res.status == StatusCode.SUCCESS) lifecycleManager.registeredShuffle.add(shuffleId) assert(!shuffleClient.isShuffleStageEnd(shuffleId)) From b683bacc0c1793eb970c24b7fa9d3c0eba341323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E7=A6=8F?= Date: Thu, 14 May 2026 13:18:18 +0800 Subject: [PATCH 2/2] fix style --- .../master/SlotsAllocatorBenchmark.scala | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/master/src/test/scala/org/apache/celeborn/service/deploy/master/SlotsAllocatorBenchmark.scala b/master/src/test/scala/org/apache/celeborn/service/deploy/master/SlotsAllocatorBenchmark.scala index a02d43c1f7f..7afd7b3ea8c 100644 --- a/master/src/test/scala/org/apache/celeborn/service/deploy/master/SlotsAllocatorBenchmark.scala +++ b/master/src/test/scala/org/apache/celeborn/service/deploy/master/SlotsAllocatorBenchmark.scala 
@@ -94,18 +94,28 @@ object SlotsAllocatorBenchmark extends BenchmarkBase {
   override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
     benchmarkRoundRobin(
       "200 workers, 10K partitions, no replication",
-      200, 10000, shouldReplicate = false)
+      200,
+      10000,
+      shouldReplicate = false)
     benchmarkRoundRobin(
       "200 workers, 100K partitions, no replication",
-      200, 100000, shouldReplicate = false)
+      200,
+      100000,
+      shouldReplicate = false)
     benchmarkRoundRobin(
       "500 workers, 100K partitions, with replication",
-      500, 100000, shouldReplicate = true)
+      500,
+      100000,
+      shouldReplicate = true)
     benchmarkRoundRobin(
       "500 workers, 2M partitions, no replication",
-      500, 2000000, shouldReplicate = false)
+      500,
+      2000000,
+      shouldReplicate = false)
     benchmarkRoundRobin(
       "1000 workers, 500K partitions, with replication",
-      1000, 500000, shouldReplicate = true)
+      1000,
+      500000,
+      shouldReplicate = true)
   }
 }