diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 155fc088616..e2d760439ef 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -166,8 +166,7 @@ class CelebornShuffleReader[K, C]( true } try { - // startPartition is irrelevant - fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition) + fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition, endPartition) } catch { case ce: CelebornIOException if ce.getCause != null && ce.getCause.isInstanceOf[ diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 8390a9c1b5a..0b4a1929f30 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -183,6 +183,12 @@ public ShuffleClientImpl.ReduceFileGroups updateFileGroup(int shuffleId, int par return null; } + @Override + public ShuffleClientImpl.ReduceFileGroups updateFileGroup( + int shuffleId, int startPartition, int endPartition) throws CelebornIOException { + return null; + } + @Override public boolean isShuffleStageEnd(int shuffleId) throws Exception { return true; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 7035478ebf3..aded1b28005 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -326,6 +326,37 @@ public abstract void mapPartitionMapperEnd( public abstract ShuffleClientImpl.ReduceFileGroups updateFileGroup(int shuffleId, int partitionId) throws CelebornIOException; + public ShuffleClientImpl.ReduceFileGroups updateFileGroup( + int shuffleId, int startPartition, int endPartition) throws CelebornIOException { + if (startPartition < 0 || endPartition < startPartition) { + throw new IllegalArgumentException( + String.format("Invalid reducer file group range [%d, %d)", startPartition, endPartition)); + } + + ShuffleClientImpl.ReduceFileGroups merged = + new ShuffleClientImpl.ReduceFileGroups( + new ConcurrentHashMap<>(), + null, + ConcurrentHashMap.newKeySet(), + new ConcurrentHashMap<>()); + for (int partitionId = startPartition; partitionId < endPartition; partitionId++) { + ShuffleClientImpl.ReduceFileGroups current = updateFileGroup(shuffleId, partitionId); + if (current.partitionGroups != null) { + merged.partitionGroups.putAll(current.partitionGroups); + } + if (current.partitionIds != null) { + merged.partitionIds.addAll(current.partitionIds); + } + if (current.pushFailedBatches != null) { + merged.pushFailedBatches.putAll(current.pushFailedBatches); + } + if (merged.mapAttempts == null) { + merged.mapAttempts = current.mapAttempts; + } + } + return merged; + } + public abstract boolean isShuffleStageEnd(int shuffleId) throws Exception; // Reduce side read partition which is deduplicated by mapperId+mapperAttemptNum+batchId, batchId diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 358bc227a16..61b98846cf0 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -21,7 +21,9 @@ import java.nio.ByteBuffer; import java.util.*; import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; @@ -37,6 +39,7 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.fs.FileSystem; +import org.roaringbitmap.RoaringBitmap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -154,16 +157,33 @@ public static class ReduceFileGroups { public Map pushFailedBatches; public int[] mapAttempts; public Set partitionIds; + private boolean hasPartitionRange; + private int startPartition; + private int endPartition; ReduceFileGroups( Map> partitionGroups, int[] mapAttempts, Set partitionIds, Map pushFailedBatches) { + this(partitionGroups, mapAttempts, partitionIds, pushFailedBatches, false, 0, 0); + } + + ReduceFileGroups( + Map> partitionGroups, + int[] mapAttempts, + Set partitionIds, + Map pushFailedBatches, + boolean hasPartitionRange, + int startPartition, + int endPartition) { this.partitionGroups = partitionGroups; this.mapAttempts = mapAttempts; this.partitionIds = partitionIds; this.pushFailedBatches = pushFailedBatches; + this.hasPartitionRange = hasPartitionRange; + this.startPartition = startPartition; + this.endPartition = endPartition; } public ReduceFileGroups() { @@ -178,6 +198,9 @@ public void update(ReduceFileGroups fileGroups) { mapAttempts = fileGroups.mapAttempts; partitionIds = fileGroups.partitionIds; pushFailedBatches = fileGroups.pushFailedBatches; + hasPartitionRange = fileGroups.hasPartitionRange; + startPartition = fileGroups.startPartition; + endPartition = fileGroups.endPartition; } } @@ -185,6 +208,153 @@ public void update(ReduceFileGroups fileGroups) { protected final Map> reduceFileGroupsMap = JavaUtils.newConcurrentHashMap(); + private static final class ReducerFileGroupRange { + private final int startPartition; + private final int endPartition; + private final boolean segmentGranularityVisible; + + private ReducerFileGroupRange( + int startPartition, int endPartition, boolean segmentGranularityVisible) { + this.startPartition = startPartition; + this.endPartition = endPartition; + this.segmentGranularityVisible = segmentGranularityVisible; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof ReducerFileGroupRange)) { + return false; + } + ReducerFileGroupRange that = (ReducerFileGroupRange) other; + return startPartition == that.startPartition + && endPartition == that.endPartition + && segmentGranularityVisible == that.segmentGranularityVisible; + } + + @Override + public int hashCode() { + return Objects.hash(startPartition, endPartition, segmentGranularityVisible); + } + } + + private static final class ReducerFileGroupCache { + private final ReduceFileGroups fileGroups = + new ReduceFileGroups( + JavaUtils.newConcurrentHashMap(), + null, + ConcurrentHashMap.newKeySet(), + JavaUtils.newConcurrentHashMap()); + private final RoaringBitmap loadedPartitions = new RoaringBitmap(); + private final RoaringBitmap segmentVisibleLoadedPartitions = new RoaringBitmap(); + private final Map< + ReducerFileGroupRange, CompletableFuture>> + inFlightLoads = JavaUtils.newConcurrentHashMap(); + private boolean active = true; + private boolean allPartitionsLoaded = false; + private boolean segmentVisibleAllPartitionsLoaded = false; + + private synchronized boolean contains( + int startPartition, int endPartition, boolean segmentGranularityVisible) { + if (!active) { + return false; + } + RoaringBitmap loaded = + segmentGranularityVisible ? segmentVisibleLoadedPartitions : loadedPartitions; + boolean allLoaded = + segmentGranularityVisible ? segmentVisibleAllPartitionsLoaded : allPartitionsLoaded; + return allLoaded || loaded.contains((long) startPartition, (long) endPartition); + } + + private synchronized boolean hasMapAttempts() { + return fileGroups.mapAttempts != null; + } + + private synchronized void merge( + int startPartition, + int endPartition, + boolean segmentGranularityVisible, + ReduceFileGroups loadedFileGroups) { + if (loadedFileGroups.hasPartitionRange + && (loadedFileGroups.startPartition != startPartition + || loadedFileGroups.endPartition != endPartition)) { + throw new IllegalStateException( + String.format( + "Reducer file group response range [%d, %d) does not match request [%d, %d)", + loadedFileGroups.startPartition, + loadedFileGroups.endPartition, + startPartition, + endPartition)); + } + fileGroups.partitionGroups.putAll(loadedFileGroups.partitionGroups); + fileGroups.pushFailedBatches.putAll(loadedFileGroups.pushFailedBatches); + fileGroups.partitionIds.addAll(loadedFileGroups.partitionIds); + if (fileGroups.mapAttempts == null) { + fileGroups.mapAttempts = loadedFileGroups.mapAttempts; + } + if (loadedFileGroups.hasPartitionRange) { + RoaringBitmap loaded = + segmentGranularityVisible ? segmentVisibleLoadedPartitions : loadedPartitions; + loaded.add((long) startPartition, (long) endPartition); + } else if (segmentGranularityVisible) { + // Older drivers ignore the range fields and return shuffle-wide metadata. Remember that + // coverage so later requests do not repeat the full download. + segmentVisibleAllPartitionsLoaded = true; + } else { + // See the mixed-version compatibility note above. + allPartitionsLoaded = true; + } + } + + private synchronized ReduceFileGroups getRange(int startPartition, int endPartition) { + Map> partitionGroups = new HashMap<>(); + Set partitionIds = new HashSet<>(); + Map pushFailedBatches = new HashMap<>(); + for (int partitionId = startPartition; partitionId < endPartition; partitionId++) { + Set locations = fileGroups.partitionGroups.get(partitionId); + if (locations != null) { + partitionGroups.put(partitionId, locations); + for (PartitionLocation location : locations) { + LocationPushFailedBatches failedBatches = + fileGroups.pushFailedBatches.get(location.getUniqueId()); + if (failedBatches != null) { + pushFailedBatches.put(location.getUniqueId(), failedBatches); + } + } + } + if (fileGroups.partitionIds.contains(partitionId)) { + partitionIds.add(partitionId); + } + } + return new ReduceFileGroups( + partitionGroups, + fileGroups.mapAttempts, + partitionIds, + pushFailedBatches, + true, + startPartition, + endPartition); + } + + private synchronized void deactivate() { + active = false; + Tuple3 cleanedUp = + Tuple3.apply(null, "Shuffle cleaned up", null); + inFlightLoads.values().forEach(future -> future.complete(cleanedUp)); + inFlightLoads.clear(); + } + } + + // key: shuffleId. Each executor keeps only reducer ranges it has actually read. + private final Map reduceFileGroupRangeCaches = + JavaUtils.newConcurrentHashMap(); + + @VisibleForTesting Runnable reduceFileGroupsAfterCacheMiss = () -> {}; + + @VisibleForTesting Runnable reduceFileGroupsBeforeInFlightRelease = () -> {}; + private final TransportMessagesHelper messagesHelper = new TransportMessagesHelper(); public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier userIdentifier) { @@ -1842,6 +2012,10 @@ public boolean cleanupShuffle(int shuffleId) { // clear status reducePartitionMap.remove(shuffleId); reduceFileGroupsMap.remove(shuffleId); + ReducerFileGroupCache rangeCache = reduceFileGroupRangeCaches.remove(shuffleId); + if (rangeCache != null) { + rangeCache.deactivate(); + } mapperEndMap.remove(shuffleId); stageEndShuffleSet.remove(shuffleId); splitting.remove(shuffleId); @@ -1852,6 +2026,21 @@ public boolean cleanupShuffle(int shuffleId) { protected Tuple3 loadFileGroupInternal( int shuffleId, boolean isSegmentGranularityVisible) { + return loadFileGroupInternal(shuffleId, 0, 0, isSegmentGranularityVisible, false); + } + + protected Tuple3 loadFileGroupInternal( + int shuffleId, int startPartition, int endPartition, boolean isSegmentGranularityVisible) { + return loadFileGroupInternal( + shuffleId, startPartition, endPartition, isSegmentGranularityVisible, true); + } + + private Tuple3 loadFileGroupInternal( + int shuffleId, + int startPartition, + int endPartition, + boolean isSegmentGranularityVisible, + boolean hasPartitionRange) { long getReducerFileGroupStartTime = System.nanoTime(); String exceptionMsg = null; Exception exception = null; @@ -1861,8 +2050,18 @@ protected Tuple3 loadFileGroupInternal( return Tuple3.apply(null, exceptionMsg, exception); } try { + ReducerFileGroupCache rangeCache = reduceFileGroupRangeCaches.get(shuffleId); + boolean omitMapAttempts = + hasPartitionRange && rangeCache != null && rangeCache.hasMapAttempts(); GetReducerFileGroup getReducerFileGroup = - new GetReducerFileGroup(shuffleId, isSegmentGranularityVisible, SerdeVersion.V1); + new GetReducerFileGroup( + shuffleId, + isSegmentGranularityVisible, + SerdeVersion.V1, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts); GetReducerFileGroupResponse response = lifecycleManagerRef.askSync( @@ -1881,9 +2080,12 @@ protected Tuple3 loadFileGroupInternal( "Failed to get GetReducerFileGroupResponse broadcast for shuffle: " + shuffleId); } } - logger.info( - "Shuffle {} request reducer file group success using {} ms, result partition size {}.", + logger.debug( + "Shuffle {} request reducer file group {} success using {} ms, result partition size {}.", shuffleId, + hasPartitionRange + ? String.format("range [%d, %d)", startPartition, endPartition) + : "for all partitions", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - getReducerFileGroupStartTime), response.fileGroup().size()); return Tuple3.apply( @@ -1891,7 +2093,10 @@ protected Tuple3 loadFileGroupInternal( response.fileGroup(), response.attempts(), response.partitionIds(), - response.pushFailedBatches()), + response.pushFailedBatches(), + response.hasPartitionRange(), + response.startPartition(), + response.endPartition()), null, null); case SHUFFLE_UNREGISTERED: @@ -1903,7 +2108,10 @@ protected Tuple3 loadFileGroupInternal( response.fileGroup(), response.attempts(), response.partitionIds(), - response.pushFailedBatches()), + response.pushFailedBatches(), + response.hasPartitionRange(), + response.startPartition(), + response.endPartition()), null, null); case STAGE_END_TIMEOUT: @@ -1933,6 +2141,12 @@ public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId) return updateFileGroup(shuffleId, partitionId, false); } + @Override + public ReduceFileGroups updateFileGroup(int shuffleId, int startPartition, int endPartition) + throws CelebornIOException { + return updateFileGroup(shuffleId, startPartition, endPartition, false); + } + @Override public boolean isShuffleStageEnd(int shuffleId) throws Exception { if (null != lifecycleManagerRef) { @@ -1952,25 +2166,142 @@ public boolean isShuffleStageEnd(int shuffleId) throws Exception { public ReduceFileGroups updateFileGroup( int shuffleId, int partitionId, boolean isSegmentGranularityVisible) throws CelebornIOException { + return updateFileGroup(shuffleId, partitionId, partitionId + 1, isSegmentGranularityVisible); + } + + public ReduceFileGroups updateFileGroup( + int shuffleId, int startPartition, int endPartition, boolean isSegmentGranularityVisible) + throws CelebornIOException { + if (startPartition < 0 || endPartition < startPartition) { + throw new IllegalArgumentException( + String.format("Invalid reducer file group range [%d, %d)", startPartition, endPartition)); + } + if (startPartition == endPartition) { + return super.updateFileGroup(shuffleId, startPartition, endPartition); + } Tuple3 fileGroupTuple = - reduceFileGroupsMap.compute( - shuffleId, - (id, existsTuple) -> { - if (existsTuple == null || existsTuple._1() == null) { - return loadFileGroupInternal(shuffleId, isSegmentGranularityVisible); - } else { - return existsTuple; - } - }); + loadFileGroup(shuffleId, startPartition, endPartition, isSegmentGranularityVisible); if (fileGroupTuple._1() == null) { throw new CelebornIOException( - loadFileGroupException(shuffleId, partitionId, (fileGroupTuple._2())), + loadFileGroupException(shuffleId, startPartition, (fileGroupTuple._2())), fileGroupTuple._3()); } else { return fileGroupTuple._1(); } } + private Tuple3 loadFileGroup( + int shuffleId, int startPartition, int endPartition, boolean isSegmentGranularityVisible) { + ReducerFileGroupCache rangeCache = + reduceFileGroupRangeCaches.computeIfAbsent( + shuffleId, ignored -> new ReducerFileGroupCache()); + ReducerFileGroupRange range = + new ReducerFileGroupRange(startPartition, endPartition, isSegmentGranularityVisible); + while (true) { + synchronized (rangeCache) { + if (!rangeCache.active || reduceFileGroupRangeCaches.get(shuffleId) != rangeCache) { + return Tuple3.apply(null, "Shuffle cleaned up", null); + } + if (rangeCache.contains(startPartition, endPartition, isSegmentGranularityVisible)) { + return Tuple3.apply(rangeCache.getRange(startPartition, endPartition), null, null); + } + } + reduceFileGroupsAfterCacheMiss.run(); + + CompletableFuture> newLoad = + new CompletableFuture<>(); + CompletableFuture> inFlightLoad = null; + boolean waitingForMapAttempts = false; + synchronized (rangeCache) { + if (!rangeCache.active || reduceFileGroupRangeCaches.get(shuffleId) != rangeCache) { + return Tuple3.apply(null, "Shuffle cleaned up", null); + } + inFlightLoad = rangeCache.inFlightLoads.get(range); + if (inFlightLoad == null && !rangeCache.hasMapAttempts()) { + // Let only the first cold range fetch the shuffle-wide mapper attempts. Once it merges, + // unrelated ranges proceed independently and omit that array from their responses. + inFlightLoad = rangeCache.inFlightLoads.values().stream().findAny().orElse(null); + waitingForMapAttempts = inFlightLoad != null; + } + if (inFlightLoad == null) { + inFlightLoad = rangeCache.inFlightLoads.putIfAbsent(range, newLoad); + } + } + if (inFlightLoad != null) { + Tuple3 loadedTuple = + waitForFileGroupLoad(inFlightLoad); + boolean retryAfterOwnerInterrupt = + loadedTuple._3() instanceof InterruptedException + && !Thread.currentThread().isInterrupted(); + if ((waitingForMapAttempts && loadedTuple._1() != null) || retryAfterOwnerInterrupt) { + continue; + } + return loadedTuple; + } + + if (rangeCache.contains(startPartition, endPartition, isSegmentGranularityVisible)) { + Tuple3 cached = + Tuple3.apply(rangeCache.getRange(startPartition, endPartition), null, null); + rangeCache.inFlightLoads.remove(range, newLoad); + newLoad.complete(cached); + return cached; + } + + try { + Tuple3 loadedTuple = + loadFileGroupInternal( + shuffleId, startPartition, endPartition, isSegmentGranularityVisible); + Tuple3 completedTuple = loadedTuple; + synchronized (rangeCache) { + boolean shouldPublish = + reduceFileGroupRangeCaches.get(shuffleId) == rangeCache + && rangeCache.active + && rangeCache.inFlightLoads.get(range) == newLoad; + if (shouldPublish && loadedTuple._1() != null) { + rangeCache.merge( + startPartition, endPartition, isSegmentGranularityVisible, loadedTuple._1()); + completedTuple = + Tuple3.apply(rangeCache.getRange(startPartition, endPartition), null, null); + reduceFileGroupsBeforeInFlightRelease.run(); + } else if (!shouldPublish) { + completedTuple = Tuple3.apply(null, "Shuffle cleaned up", null); + } + rangeCache.inFlightLoads.remove(range, newLoad); + } + newLoad.complete(completedTuple); + return completedTuple; + } catch (Throwable e) { + rangeCache.inFlightLoads.remove(range, newLoad); + newLoad.completeExceptionally(e); + if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } else if (e instanceof Error) { + throw (Error) e; + } else { + throw new RuntimeException(e); + } + } + } + } + + private Tuple3 waitForFileGroupLoad( + CompletableFuture> inFlightLoad) { + try { + return inFlightLoad.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return Tuple3.apply(null, e.getMessage(), e); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof Error) { + throw (Error) cause; + } + Exception exception = + cause instanceof Exception ? (Exception) cause : new RuntimeException(cause); + return Tuple3.apply(null, exception.getMessage(), exception); + } + } + protected String loadFileGroupException(int shuffleId, int partitionId, String exceptionMsg) { return String.format( "Failed to load file group of shuffle %d partition %d! %s", @@ -2048,6 +2379,11 @@ public Map> getReduceFileGr return reduceFileGroupsMap; } + @VisibleForTesting + public boolean hasReducerFileGroupRangeCache(int shuffleId) { + return reduceFileGroupRangeCaches.containsKey(shuffleId); + } + @Override public void shutdown() { if (null != reviveManager) { diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala index 64a7c95e9d8..d108bde6b00 100644 --- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala @@ -293,8 +293,19 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage def handleGetReducerFileGroup( context: RpcCallContext, shuffleId: Int, + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean, + omitMapAttempts: Boolean, serdeVersion: SerdeVersion): Unit = { - getCommitHandler(shuffleId).handleGetReducerFileGroup(context, shuffleId, serdeVersion) + getCommitHandler(shuffleId).handleGetReducerFileGroup( + context, + shuffleId, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts, + serdeVersion) } def handleGetStageEnd(context: RpcCallContext, shuffleId: Int): Unit = { diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index f9508cfe6c8..095e9d7bb9c 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -495,10 +495,24 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends case GetReducerFileGroup( shuffleId: Int, isSegmentGranularityVisible: Boolean, - serdeVersion: SerdeVersion) => + serdeVersion: SerdeVersion, + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean, + omitMapAttempts: Boolean) => logDebug( - s"Received GetShuffleFileGroup request for shuffleId $shuffleId, isSegmentGranularityVisible $isSegmentGranularityVisible") - handleGetReducerFileGroup(context, shuffleId, isSegmentGranularityVisible, serdeVersion) + s"Received GetShuffleFileGroup request for shuffleId $shuffleId, " + + s"isSegmentGranularityVisible $isSegmentGranularityVisible, " + + s"partitionRange ${if (hasPartitionRange) s"[$startPartition, $endPartition)" else "all"}") + handleGetReducerFileGroup( + context, + shuffleId, + isSegmentGranularityVisible, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts, + serdeVersion) case pb: PbGetStageEnd => val shuffleId = pb.getShuffleId @@ -967,7 +981,19 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends context: RpcCallContext, shuffleId: Int, isSegmentGranularityVisible: Boolean, + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean, + omitMapAttempts: Boolean, serdeVersion: SerdeVersion): Unit = { + if (hasPartitionRange && (startPartition < 0 || endPartition <= startPartition)) { + logWarning( + s"Reject invalid reducer file group partition range [$startPartition, $endPartition) " + + s"for shuffle $shuffleId") + context.reply( + GetReducerFileGroupResponse(StatusCode.REQUEST_FAILED, serdeVersion = serdeVersion)) + return + } // If isSegmentGranularityVisible is set to true, the downstream reduce task may start early than upstream map task, e.g. flink hybrid shuffle. // Under these circumstances, there's a possibility that the shuffle might not yet be registered when the downstream reduce task send GetReduceFileGroup request, // so we shouldn't send a SHUFFLE_NOT_REGISTERED response directly, should enqueue this request to pending list, and response to the downstream reduce task the ReduceFileGroup when the upstream map task register shuffle done @@ -980,7 +1006,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends serdeVersion = serdeVersion)) return } - commitManager.handleGetReducerFileGroup(context, shuffleId, serdeVersion) + commitManager.handleGetReducerFileGroup( + context, + shuffleId, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts, + serdeVersion) } private def handleGetStageEnd(context: RpcCallContext, shuffleId: Int): Unit = { diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala index f54c2b990d9..9849dc6b76d 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala @@ -184,6 +184,10 @@ abstract class CommitHandler( def handleGetReducerFileGroup( context: RpcCallContext, shuffleId: Int, + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean, + omitMapAttempts: Boolean, serdeVersion: SerdeVersion): Unit /** diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala index 04701f2f478..f2957d8d45a 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala @@ -360,6 +360,10 @@ class MapPartitionCommitHandler( override def handleGetReducerFileGroup( context: RpcCallContext, shuffleId: Int, + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean, + omitMapAttempts: Boolean, serdeVersion: SerdeVersion): Unit = { // TODO: if support the downstream map task start early before the upstream reduce task, it should // waiting the upstream task register shuffle, then reply these GetReducerFileGroup. @@ -368,14 +372,28 @@ class MapPartitionCommitHandler( // we need obtain the last succeed partitionIds val lastSucceedPartitionIds = shuffleSucceedPartitionIds.getOrDefault(shuffleId, new util.HashSet[Integer]()) - val succeedPartitionIds = new util.HashSet[Integer](lastSucceedPartitionIds) + val allFileGroups = + reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()) + val fileGroups = ReducerFileGroupFilter.fileGroupsForRange( + allFileGroups, + startPartition, + endPartition, + hasPartitionRange) + val succeedPartitionIds = ReducerFileGroupFilter.partitionIdsForRange( + lastSucceedPartitionIds, + startPartition, + endPartition, + hasPartitionRange) context.reply(GetReducerFileGroupResponse( StatusCode.SUCCESS, - reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()), - getMapperAttempts(shuffleId), + fileGroups, + if (omitMapAttempts) Array.emptyIntArray else getMapperAttempts(shuffleId), succeedPartitionIds, - serdeVersion = serdeVersion)) + serdeVersion = serdeVersion, + startPartition = startPartition, + endPartition = endPartition, + hasPartitionRange = hasPartitionRange)) } override def releasePartitionResource(shuffleId: Int, partitionId: Int): Unit = { diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index b3e7aa90ab5..8d7f9c6a190 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -69,7 +69,13 @@ class ReducePartitionCommitHandler( commitRetryScheduler) with Logging { - class MultiSerdeVersionRpcContext(val ctx: RpcCallContext, val serdeVersion: SerdeVersion) {} + class MultiSerdeVersionRpcContext( + val ctx: RpcCallContext, + val serdeVersion: SerdeVersion, + val startPartition: Int, + val endPartition: Int, + val hasPartitionRange: Boolean, + val omitMapAttempts: Boolean) {} private val getReducerFileGroupRequest = JavaUtils.newConcurrentHashMap[Int, util.Set[MultiSerdeVersionRpcContext]]() @@ -456,12 +462,65 @@ class ReducePartitionCommitHandler( private def replyGetReducerFileGroup( context: MultiSerdeVersionRpcContext, shuffleId: Int): Unit = { - replyGetReducerFileGroup(context.ctx, shuffleId, context.serdeVersion) + replyGetReducerFileGroup( + context.ctx, + shuffleId, + context.startPartition, + context.endPartition, + context.hasPartitionRange, + context.omitMapAttempts, + context.serdeVersion) + } + + private def buildGetReducerFileGroupResponse( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean, + omitMapAttempts: Boolean, + serdeVersion: SerdeVersion): GetReducerFileGroupResponse = { + val allFileGroups = + reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()) + val allPushFailedBatches = shufflePushFailedBatches.getOrDefault( + shuffleId, + new util.HashMap[String, LocationPushFailedBatches]()) + + if (!hasPartitionRange) { + return GetReducerFileGroupResponse( + StatusCode.SUCCESS, + allFileGroups, + getMapperAttempts(shuffleId), + pushFailedBatches = allPushFailedBatches, + serdeVersion = serdeVersion) + } + + val fileGroups = ReducerFileGroupFilter.fileGroupsForRange( + allFileGroups, + startPartition, + endPartition, + hasPartitionRange) + val pushFailedBatches = ReducerFileGroupFilter.pushFailedBatchesForFileGroups( + fileGroups, + allPushFailedBatches) + + GetReducerFileGroupResponse( + StatusCode.SUCCESS, + fileGroups, + if (omitMapAttempts) Array.emptyIntArray else getMapperAttempts(shuffleId), + pushFailedBatches = pushFailedBatches, + serdeVersion = serdeVersion, + startPartition = startPartition, + endPartition = endPartition, + hasPartitionRange = hasPartitionRange) } private def replyGetReducerFileGroup( context: RpcCallContext, shuffleId: Int, + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean, + omitMapAttempts: Boolean, serdeVersion: SerdeVersion): Unit = { if (isStageDataLost(shuffleId)) { context.reply( @@ -469,36 +528,51 @@ class ReducePartitionCommitHandler( StatusCode.SHUFFLE_DATA_LOST, JavaUtils.newConcurrentHashMap(), Array.empty, - new util.HashSet[Integer]())) + new util.HashSet[Integer](), + serdeVersion = serdeVersion)) } else { // LocalNettyRpcCallContext is for the UTs if (context.isInstanceOf[LocalNettyRpcCallContext]) { - var response = GetReducerFileGroupResponse( - StatusCode.SUCCESS, - reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()), - getMapperAttempts(shuffleId), - serdeVersion = serdeVersion) + var response = buildGetReducerFileGroupResponse( + shuffleId, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts, + serdeVersion) // only check whether broadcast enabled for the UTs - if (getReducerFileGroupResponseBroadcastEnabled) { + if (getReducerFileGroupResponseBroadcastEnabled && !hasPartitionRange) { response = broadcastGetReducerFileGroup(shuffleId, response) } context.reply(response) + } else if (hasPartitionRange) { + val returnedMsg = buildGetReducerFileGroupResponse( + shuffleId, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts, + serdeVersion) + val serializedMsg = + context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg) + logDebug( + s"Shuffle $shuffleId reducer range [$startPartition, $endPartition) " + + s"GetReducerFileGroupResponse size ${serializedMsg.capacity()}") + context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(serializedMsg) } else { val cachedMsg = getReducerFileGroupRpcCache.get( shuffleId, new Callable[ByteBuffer]() { override def call(): ByteBuffer = { - val returnedMsg = GetReducerFileGroupResponse( - StatusCode.SUCCESS, - reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()), - getMapperAttempts(shuffleId), - pushFailedBatches = - shufflePushFailedBatches.getOrDefault( - shuffleId, - new util.HashMap[String, LocationPushFailedBatches]()), - serdeVersion = serdeVersion) + val returnedMsg = buildGetReducerFileGroupResponse( + shuffleId, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts, + serdeVersion) val serializedMsg = context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg) @@ -545,19 +619,41 @@ class ReducePartitionCommitHandler( override def handleGetReducerFileGroup( context: RpcCallContext, shuffleId: Int, + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean, + omitMapAttempts: Boolean, serdeVersion: SerdeVersion): Unit = { // Quick return for ended stage, avoid occupy sync lock. if (isStageEnd(shuffleId)) { - replyGetReducerFileGroup(context, shuffleId, serdeVersion) + replyGetReducerFileGroup( + context, + shuffleId, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts, + serdeVersion) } else { getReducerFileGroupRequest.synchronized { // If setStageEnd() called after isStageEnd and before got lock, should reply here. if (isStageEnd(shuffleId)) { - replyGetReducerFileGroup(context, shuffleId, serdeVersion) + replyGetReducerFileGroup( + context, + shuffleId, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts, + serdeVersion) } else { getReducerFileGroupRequest.get(shuffleId).add(new MultiSerdeVersionRpcContext( context, - serdeVersion)) + serdeVersion, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts)) } } } diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducerFileGroupFilter.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducerFileGroupFilter.scala new file mode 100644 index 00000000000..e8d9d0b38ec --- /dev/null +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducerFileGroupFilter.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.commit + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.celeborn.common.protocol.PartitionLocation +import org.apache.celeborn.common.write.LocationPushFailedBatches + +private[celeborn] object ReducerFileGroupFilter { + + def fileGroupsForRange( + allFileGroups: util.Map[Integer, util.Set[PartitionLocation]], + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean): util.Map[Integer, util.Set[PartitionLocation]] = { + if (!hasPartitionRange) { + return allFileGroups + } + + val fileGroups = new util.HashMap[Integer, util.Set[PartitionLocation]]() + var partitionId = startPartition + while (partitionId < endPartition) { + val locations = allFileGroups.get(partitionId) + if (locations != null) { + fileGroups.put(partitionId, locations) + } + partitionId += 1 + } + fileGroups + } + + def partitionIdsForRange( + allPartitionIds: util.Set[Integer], + startPartition: Int, + endPartition: Int, + hasPartitionRange: Boolean): util.Set[Integer] = { + if (!hasPartitionRange) { + return new util.HashSet[Integer](allPartitionIds) + } + + val partitionIds = new util.HashSet[Integer]() + var partitionId = startPartition + while (partitionId < endPartition) { + if (allPartitionIds.contains(partitionId)) { + partitionIds.add(partitionId) + } + partitionId += 1 + } + partitionIds + } + + def pushFailedBatchesForFileGroups( + fileGroups: util.Map[Integer, util.Set[PartitionLocation]], + allPushFailedBatches: util.Map[String, LocationPushFailedBatches]) + : util.Map[String, LocationPushFailedBatches] = { + val pushFailedBatches = new util.HashMap[String, LocationPushFailedBatches]() + fileGroups.values().asScala.foreach(_.asScala.foreach { location => + val failedBatches = allPushFailedBatches.get(location.getUniqueId) + if (failedBatches != null) { + pushFailedBatches.put(location.getUniqueId, failedBatches) + } + }) + pushFailedBatches + } +} diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java index 74fe6379c11..bd9e840e81e 100644 --- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java @@ -29,8 +29,10 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import scala.reflect.ClassTag; @@ -57,6 +59,7 @@ import org.apache.celeborn.common.protocol.PartitionLocation; import org.apache.celeborn.common.protocol.PbReadReducerPartitionEnd; import org.apache.celeborn.common.protocol.PbReadReducerPartitionEndResponse; +import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroup; import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse$; import org.apache.celeborn.common.protocol.message.ControlMessages.RegisterShuffleResponse$; import org.apache.celeborn.common.protocol.message.StatusCode; @@ -110,6 +113,41 @@ public class ShuffleClientSuiteJ { private static final byte[] TEST_BUF1 = "hello world".getBytes(StandardCharsets.UTF_8); private final int BATCH_HEADER_SIZE = 4 * 4; + private static void assertWaitingOnInFlightLoad(Thread thread) throws InterruptedException { + long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + Thread.State state = thread.getState(); + while (System.nanoTime() < deadlineNanos) { + state = thread.getState(); + if (state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING) { + return; + } + Thread.sleep(10); + } + Assert.fail("Expected thread to wait on the in-flight reducer file group load, state=" + state); + } + + private static void assertThreadBlocked(Thread thread) throws InterruptedException { + long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + Thread.State state = thread.getState(); + while (System.nanoTime() < deadlineNanos) { + state = thread.getState(); + if (state == Thread.State.BLOCKED) { + return; + } + Thread.sleep(10); + } + Assert.fail("Expected thread to block on reducer file group publication, state=" + state); + } + + private static void awaitLatch(CountDownLatch latch) { + try { + Assert.assertTrue(latch.await(10, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError(e); + } + } + @Test public void testPushData() throws IOException, InterruptedException { for (CompressionCodec codec : CompressionCodec.values()) { @@ -439,7 +477,10 @@ public void testUpdateReducerFileGroupInterrupted() throws InterruptedException Collections.emptySet(), Collections.emptyMap(), new byte[0], - SerdeVersion.V1); + SerdeVersion.V1, + 0, + 0, + false); }); when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) @@ -453,7 +494,10 @@ public void testUpdateReducerFileGroupInterrupted() throws InterruptedException Collections.emptySet(), Collections.emptyMap(), new byte[0], - SerdeVersion.V1); + SerdeVersion.V1, + 0, + 0, + false); }); shuffleClient = @@ -484,13 +528,15 @@ public void run() { } @Test - public void testUpdateReducerFileGroupNonFetchFailureExceptions() { + public void testUpdateReducerFileGroupNonFetchFailureExceptions() throws CelebornIOException { CelebornConf conf = new CelebornConf(); conf.set("celeborn.client.spark.stageRerun.enabled", "true"); Map> locations = new HashMap<>(); + AtomicInteger unregisteredRpcCalls = new AtomicInteger(); when(endpointRef.askSync(any(), any(), any())) .thenAnswer( t -> { + unregisteredRpcCalls.incrementAndGet(); return GetReducerFileGroupResponse$.MODULE$.apply( StatusCode.SHUFFLE_UNREGISTERED, locations, @@ -498,12 +544,16 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { Collections.emptySet(), Collections.emptyMap(), new byte[0], - SerdeVersion.V1); + SerdeVersion.V1, + 0, + 0, + false); }); when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) .thenAnswer( t -> { + unregisteredRpcCalls.incrementAndGet(); return GetReducerFileGroupResponse$.MODULE$.apply( StatusCode.SHUFFLE_UNREGISTERED, locations, @@ -511,18 +561,19 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { Collections.emptySet(), Collections.emptyMap(), new byte[0], - SerdeVersion.V1); + SerdeVersion.V1, + 0, + 0, + false); }); shuffleClient = new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); shuffleClient.setupLifecycleManagerRef(endpointRef); - try { - shuffleClient.updateFileGroup(0, 0); - } catch (CelebornIOException e) { - Assert.assertTrue(e.getCause() == null); - } + Assert.assertNotNull(shuffleClient.updateFileGroup(0, 0)); + Assert.assertNotNull(shuffleClient.updateFileGroup(0, 1)); + Assert.assertEquals(1, unregisteredRpcCalls.get()); when(endpointRef.askSync(any(), any(), any())) .thenAnswer( @@ -534,7 +585,10 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { Collections.emptySet(), Collections.emptyMap(), new byte[0], - SerdeVersion.V1); + SerdeVersion.V1, + 0, + 0, + false); }); when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) @@ -547,18 +601,19 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { Collections.emptySet(), Collections.emptyMap(), new byte[0], - SerdeVersion.V1); + SerdeVersion.V1, + 0, + 0, + false); }); shuffleClient = new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); shuffleClient.setupLifecycleManagerRef(endpointRef); - try { - shuffleClient.updateFileGroup(0, 0); - } catch (CelebornIOException e) { - Assert.assertTrue(e.getCause() == null); - } + CelebornIOException stageEndTimeout = + Assert.assertThrows(CelebornIOException.class, () -> shuffleClient.updateFileGroup(0, 0)); + Assert.assertNull(stageEndTimeout.getCause()); when(endpointRef.askSync(any(), any(), any())) .thenAnswer( @@ -570,7 +625,10 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { Collections.emptySet(), Collections.emptyMap(), new byte[0], - SerdeVersion.V1); + SerdeVersion.V1, + 0, + 0, + false); }); when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) @@ -583,18 +641,19 @@ public void testUpdateReducerFileGroupNonFetchFailureExceptions() { Collections.emptySet(), Collections.emptyMap(), new byte[0], - SerdeVersion.V1); + SerdeVersion.V1, + 0, + 0, + false); }); shuffleClient = new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); shuffleClient.setupLifecycleManagerRef(endpointRef); - try { - shuffleClient.updateFileGroup(0, 0); - } catch (CelebornIOException e) { - Assert.assertTrue(e.getCause() == null); - } + CelebornIOException shuffleDataLost = + Assert.assertThrows(CelebornIOException.class, () -> shuffleClient.updateFileGroup(0, 0)); + Assert.assertNull(shuffleDataLost.getCause()); } @Test @@ -806,4 +865,1083 @@ public void testComputeBatchCRCAttemptIdConsistency() { int[] crc1 = pushState1.getCRC32PerPartition(true, 2); assertEquals(0, crc1[0]); } + + @Test + public void testUpdateReducerFileGroupEmptyRangeDoesNotLoadMetadata() throws Exception { + CelebornConf conf = new CelebornConf(); + AtomicInteger rpcCalls = new AtomicInteger(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + rpcCalls.incrementAndGet(); + return null; + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + ShuffleClientImpl.ReduceFileGroups fileGroups = shuffleClient.updateFileGroup(7, 2, 2); + + Assert.assertTrue(fileGroups.partitionGroups.isEmpty()); + Assert.assertNull(fileGroups.mapAttempts); + Assert.assertTrue(fileGroups.partitionIds.isEmpty()); + Assert.assertTrue(fileGroups.pushFailedBatches.isEmpty()); + Assert.assertEquals(0, rpcCalls.get()); + Assert.assertFalse(shuffleClient.hasReducerFileGroupRangeCache(7)); + } + + @Test + public void testLegacyUpdateReducerFileGroupOverloadPreservesSubclassDispatch() throws Exception { + CelebornConf conf = new CelebornConf(); + AtomicInteger calls = new AtomicInteger(); + ShuffleClientImpl.ReduceFileGroups expected = new ShuffleClientImpl.ReduceFileGroups(); + ShuffleClientImpl overriddenClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")) { + @Override + public ShuffleClientImpl.ReduceFileGroups updateFileGroup( + int shuffleId, int partitionId, boolean isSegmentGranularityVisible) { + Assert.assertEquals(7, shuffleId); + Assert.assertEquals(2, partitionId); + Assert.assertFalse(isSegmentGranularityVisible); + calls.incrementAndGet(); + return expected; + } + }; + + try { + Assert.assertSame(expected, overriddenClient.updateFileGroup(7, 2)); + Assert.assertEquals(1, calls.get()); + } finally { + overriddenClient.shutdown(); + } + } + + @Test + public void testUpdateReducerFileGroupConcurrentLoadsShareSingleRpc() throws Exception { + CelebornConf conf = new CelebornConf(); + Map> locations = new HashMap<>(); + CountDownLatch rpcStarted = new CountDownLatch(1); + CountDownLatch releaseRpc = new CountDownLatch(1); + CountDownLatch secondThreadStarted = new CountDownLatch(1); + AtomicInteger rpcCalls = new AtomicInteger(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + GetReducerFileGroup request = invocation.getArgument(0); + rpcCalls.incrementAndGet(); + rpcStarted.countDown(); + releaseRpc.await(10, TimeUnit.SECONDS); + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + request.startPartition(), + request.endPartition(), + true); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference firstResult = new AtomicReference<>(); + AtomicReference secondResult = new AtomicReference<>(); + AtomicReference firstException = new AtomicReference<>(); + AtomicReference secondException = new AtomicReference<>(); + + Thread firstThread = + new Thread( + () -> { + try { + firstResult.set(shuffleClient.updateFileGroup(0, 0)); + } catch (Exception e) { + firstException.set(e); + } + }); + Thread secondThread = + new Thread( + () -> { + secondThreadStarted.countDown(); + try { + secondResult.set(shuffleClient.updateFileGroup(0, 0)); + } catch (Exception e) { + secondException.set(e); + } + }); + + firstThread.start(); + Assert.assertTrue(rpcStarted.await(10, TimeUnit.SECONDS)); + secondThread.start(); + Assert.assertTrue(secondThreadStarted.await(10, TimeUnit.SECONDS)); + assertWaitingOnInFlightLoad(secondThread); + + Assert.assertEquals(1, rpcCalls.get()); + + releaseRpc.countDown(); + firstThread.join(10 * 1000); + secondThread.join(10 * 1000); + + Assert.assertFalse(firstThread.isAlive()); + Assert.assertFalse(secondThread.isAlive()); + Assert.assertNull(firstException.get()); + Assert.assertNull(secondException.get()); + Assert.assertNotNull(firstResult.get()); + Assert.assertNotNull(secondResult.get()); + Assert.assertSame(firstResult.get(), secondResult.get()); + } + + @Test + public void testWaitingReducerFileGroupLoadCanBeInterrupted() throws Exception { + CelebornConf conf = new CelebornConf(); + Map> locations = new HashMap<>(); + CountDownLatch rpcStarted = new CountDownLatch(1); + CountDownLatch releaseRpc = new CountDownLatch(1); + AtomicInteger rpcCalls = new AtomicInteger(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + GetReducerFileGroup request = invocation.getArgument(0); + rpcCalls.incrementAndGet(); + rpcStarted.countDown(); + releaseRpc.await(10, TimeUnit.SECONDS); + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + request.startPartition(), + request.endPartition(), + true); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference ownerException = new AtomicReference<>(); + AtomicReference waiterException = new AtomicReference<>(); + Thread owner = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Exception e) { + ownerException.set(e); + } + }); + Thread waiter = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Exception e) { + waiterException.set(e); + } + }); + + owner.start(); + Assert.assertTrue(rpcStarted.await(10, TimeUnit.SECONDS)); + waiter.start(); + assertWaitingOnInFlightLoad(waiter); + + waiter.interrupt(); + waiter.join(10 * 1000); + + Assert.assertFalse(waiter.isAlive()); + Assert.assertTrue(owner.isAlive()); + Assert.assertTrue(waiterException.get() instanceof CelebornIOException); + Assert.assertTrue(waiterException.get().getCause() instanceof InterruptedException); + Assert.assertEquals(1, rpcCalls.get()); + + releaseRpc.countDown(); + owner.join(10 * 1000); + Assert.assertFalse(owner.isAlive()); + Assert.assertNull(ownerException.get()); + } + + @Test + public void testInterruptedReducerFileGroupOwnerDoesNotFailExactRangeWaiter() throws Exception { + assertInterruptedReducerFileGroupOwnerDoesNotFailWaiter(0); + } + + @Test + public void testInterruptedReducerFileGroupOwnerDoesNotFailColdRangeWaiter() throws Exception { + assertInterruptedReducerFileGroupOwnerDoesNotFailWaiter(1); + } + + private void assertInterruptedReducerFileGroupOwnerDoesNotFailWaiter(int waiterPartition) + throws Exception { + CelebornConf conf = new CelebornConf(); + CountDownLatch firstRpcStarted = new CountDownLatch(1); + CountDownLatch releaseFirstRpc = new CountDownLatch(1); + AtomicInteger rpcCalls = new AtomicInteger(); + AtomicReference retryRequest = new AtomicReference<>(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + GetReducerFileGroup request = invocation.getArgument(0); + if (rpcCalls.incrementAndGet() == 1) { + firstRpcStarted.countDown(); + releaseFirstRpc.await(10, TimeUnit.SECONDS); + } else { + retryRequest.set(request); + } + Map> locations = new HashMap<>(); + locations.put(request.startPartition(), Collections.emptySet()); + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + new int[] {0}, + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + request.startPartition(), + request.endPartition(), + true); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference ownerException = new AtomicReference<>(); + AtomicReference waiterResult = new AtomicReference<>(); + AtomicReference waiterException = new AtomicReference<>(); + Thread owner = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Throwable e) { + ownerException.set(e); + } + }); + Thread waiter = + new Thread( + () -> { + try { + waiterResult.set(shuffleClient.updateFileGroup(0, waiterPartition)); + } catch (Throwable e) { + waiterException.set(e); + } + }); + + owner.start(); + Assert.assertTrue(firstRpcStarted.await(10, TimeUnit.SECONDS)); + waiter.start(); + assertWaitingOnInFlightLoad(waiter); + + owner.interrupt(); + owner.join(10 * 1000); + waiter.join(10 * 1000); + + Assert.assertFalse(owner.isAlive()); + Assert.assertFalse(waiter.isAlive()); + Assert.assertTrue(ownerException.get() instanceof CelebornIOException); + Assert.assertTrue(ownerException.get().getCause() instanceof InterruptedException); + Assert.assertNull(waiterException.get()); + Assert.assertNotNull(waiterResult.get()); + Assert.assertEquals( + Collections.singleton(waiterPartition), waiterResult.get().partitionGroups.keySet()); + Assert.assertEquals(2, rpcCalls.get()); + Assert.assertEquals(waiterPartition, retryRequest.get().startPartition()); + Assert.assertEquals(waiterPartition + 1, retryRequest.get().endPartition()); + Assert.assertFalse(retryRequest.get().omitMapAttempts()); + } + + @Test + public void testUpdateReducerFileGroupWarmDifferentRangesDoNotWaitForEachOther() + throws Exception { + CelebornConf conf = new CelebornConf(); + Map> locations = new HashMap<>(); + int[] mapAttempts = new int[] {0}; + CountDownLatch slowRpcStarted = new CountDownLatch(1); + CountDownLatch releaseSlowRpc = new CountDownLatch(1); + CountDownLatch fastRpcFinished = new CountDownLatch(1); + AtomicInteger rpcCalls = new AtomicInteger(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + GetReducerFileGroup request = invocation.getArgument(0); + rpcCalls.incrementAndGet(); + Assert.assertTrue(request.hasPartitionRange()); + if (request.startPartition() == 2) { + Assert.assertEquals(3, request.endPartition()); + Assert.assertFalse(request.omitMapAttempts()); + } else if (request.startPartition() == 0) { + Assert.assertEquals(1, request.endPartition()); + Assert.assertTrue(request.omitMapAttempts()); + slowRpcStarted.countDown(); + releaseSlowRpc.await(10, TimeUnit.SECONDS); + } else { + Assert.assertEquals(1, request.startPartition()); + Assert.assertEquals(2, request.endPartition()); + Assert.assertTrue(request.omitMapAttempts()); + fastRpcFinished.countDown(); + } + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + request.omitMapAttempts() ? new int[0] : mapAttempts, + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + request.startPartition(), + request.endPartition(), + true); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + shuffleClient.updateFileGroup(0, 2); + Assert.assertEquals(1, rpcCalls.get()); + + AtomicReference slowException = new AtomicReference<>(); + AtomicReference fastException = new AtomicReference<>(); + Thread slowThread = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Exception e) { + slowException.set(e); + } + }); + Thread fastThread = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 1); + } catch (Exception e) { + fastException.set(e); + } + }); + + slowThread.start(); + Assert.assertTrue(slowRpcStarted.await(10, TimeUnit.SECONDS)); + fastThread.start(); + Assert.assertTrue(fastRpcFinished.await(10, TimeUnit.SECONDS)); + fastThread.join(10 * 1000); + + Assert.assertFalse(fastThread.isAlive()); + Assert.assertTrue(slowThread.isAlive()); + Assert.assertNull(fastException.get()); + Assert.assertEquals(3, rpcCalls.get()); + + releaseSlowRpc.countDown(); + slowThread.join(10 * 1000); + Assert.assertFalse(slowThread.isAlive()); + Assert.assertNull(slowException.get()); + } + + @Test + public void testUpdateReducerFileGroupRequestsAndCachesOnlyNeededRange() throws Exception { + CelebornConf conf = new CelebornConf(); + AtomicInteger rpcCalls = new AtomicInteger(); + AtomicReference firstRequest = new AtomicReference<>(); + AtomicReference secondRequest = new AtomicReference<>(); + CountDownLatch firstRpcStarted = new CountDownLatch(1); + CountDownLatch releaseFirstRpc = new CountDownLatch(1); + int[] mapAttempts = new int[] {0, 1, 2}; + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + GetReducerFileGroup request = invocation.getArgument(0); + int call = rpcCalls.getAndIncrement(); + if (call == 0) { + firstRequest.set(request); + firstRpcStarted.countDown(); + releaseFirstRpc.await(10, TimeUnit.SECONDS); + } else if (call == 1) { + secondRequest.set(request); + } + Map> locations = new HashMap<>(); + for (int partitionId = request.startPartition(); + partitionId < request.endPartition(); + partitionId++) { + locations.put(partitionId, Collections.emptySet()); + } + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + request.omitMapAttempts() ? new int[0] : mapAttempts, + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + request.startPartition(), + request.endPartition(), + true); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference firstResult = new AtomicReference<>(); + AtomicReference adjacentResult = new AtomicReference<>(); + AtomicReference firstException = new AtomicReference<>(); + AtomicReference adjacentException = new AtomicReference<>(); + Thread firstThread = + new Thread( + () -> { + try { + firstResult.set(shuffleClient.updateFileGroup(7, 10, 20)); + } catch (Exception e) { + firstException.set(e); + } + }); + Thread adjacentThread = + new Thread( + () -> { + try { + adjacentResult.set(shuffleClient.updateFileGroup(7, 20, 21)); + } catch (Exception e) { + adjacentException.set(e); + } + }); + + firstThread.start(); + Assert.assertTrue(firstRpcStarted.await(10, TimeUnit.SECONDS)); + adjacentThread.start(); + assertWaitingOnInFlightLoad(adjacentThread); + + Assert.assertEquals(1, rpcCalls.get()); + releaseFirstRpc.countDown(); + firstThread.join(10 * 1000); + adjacentThread.join(10 * 1000); + + Assert.assertFalse(firstThread.isAlive()); + Assert.assertFalse(adjacentThread.isAlive()); + Assert.assertNull(firstException.get()); + Assert.assertNull(adjacentException.get()); + + ShuffleClientImpl.ReduceFileGroups first = firstResult.get(); + ShuffleClientImpl.ReduceFileGroups adjacent = adjacentResult.get(); + ShuffleClientImpl.ReduceFileGroups nested = shuffleClient.updateFileGroup(7, 11, 12); + + Assert.assertEquals(2, rpcCalls.get()); + Assert.assertEquals(10, first.partitionGroups.size()); + Assert.assertEquals(Collections.singleton(11), nested.partitionGroups.keySet()); + Assert.assertTrue(firstRequest.get().hasPartitionRange()); + Assert.assertEquals(10, firstRequest.get().startPartition()); + Assert.assertEquals(20, firstRequest.get().endPartition()); + Assert.assertFalse(firstRequest.get().omitMapAttempts()); + Assert.assertArrayEquals(mapAttempts, first.mapAttempts); + + Assert.assertEquals(Collections.singleton(20), adjacent.partitionGroups.keySet()); + Assert.assertTrue(secondRequest.get().omitMapAttempts()); + Assert.assertArrayEquals(mapAttempts, adjacent.mapAttempts); + } + + @Test + public void testLegacyFullResponseIsCachedForAllRanges() throws Exception { + CelebornConf conf = new CelebornConf(); + AtomicInteger rpcCalls = new AtomicInteger(); + Map> fullResponse = new HashMap<>(); + fullResponse.put(0, Collections.emptySet()); + fullResponse.put(100, Collections.emptySet()); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + rpcCalls.incrementAndGet(); + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + fullResponse, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + 0, + 0, + false); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + shuffleClient.updateFileGroup(7, 0, 1); + ShuffleClientImpl.ReduceFileGroups second = shuffleClient.updateFileGroup(7, 100, 101); + + Assert.assertEquals(1, rpcCalls.get()); + Assert.assertEquals(Collections.singleton(100), second.partitionGroups.keySet()); + } + + @Test + public void testUpdateReducerFileGroupConcurrentFailuresAreSharedAndNextCallRetries() + throws Exception { + CelebornConf conf = new CelebornConf(); + Map> locations = new HashMap<>(); + CountDownLatch rpcStarted = new CountDownLatch(1); + CountDownLatch releaseFailure = new CountDownLatch(1); + CountDownLatch secondThreadStarted = new CountDownLatch(1); + AtomicInteger rpcCalls = new AtomicInteger(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + if (rpcCalls.incrementAndGet() == 1) { + rpcStarted.countDown(); + releaseFailure.await(10, TimeUnit.SECONDS); + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.STAGE_END_TIMEOUT, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + 0, + 0, + false); + } + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + 0, + 0, + false); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference firstException = new AtomicReference<>(); + AtomicReference secondResult = new AtomicReference<>(); + AtomicReference secondException = new AtomicReference<>(); + + Thread firstThread = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Exception e) { + firstException.set(e); + } + }); + Thread secondThread = + new Thread( + () -> { + secondThreadStarted.countDown(); + try { + secondResult.set(shuffleClient.updateFileGroup(0, 0)); + } catch (Exception e) { + secondException.set(e); + } + }); + + firstThread.start(); + Assert.assertTrue(rpcStarted.await(10, TimeUnit.SECONDS)); + secondThread.start(); + Assert.assertTrue(secondThreadStarted.await(10, TimeUnit.SECONDS)); + assertWaitingOnInFlightLoad(secondThread); + Assert.assertEquals(1, rpcCalls.get()); + + releaseFailure.countDown(); + firstThread.join(10 * 1000); + secondThread.join(10 * 1000); + + Assert.assertFalse(firstThread.isAlive()); + Assert.assertFalse(secondThread.isAlive()); + Assert.assertTrue(firstException.get() instanceof CelebornIOException); + Assert.assertTrue(secondException.get() instanceof CelebornIOException); + Assert.assertNull(secondResult.get()); + Assert.assertEquals(1, rpcCalls.get()); + + Assert.assertNotNull(shuffleClient.updateFileGroup(0, 0)); + Assert.assertEquals(2, rpcCalls.get()); + } + + @Test + public void testCleanupShuffleDoesNotRestoreReducerFileGroupCache() throws Exception { + CelebornConf conf = new CelebornConf(); + Map> locations = new HashMap<>(); + CountDownLatch rpcStarted = new CountDownLatch(1); + CountDownLatch releaseRpc = new CountDownLatch(1); + CountDownLatch cleanupFinished = new CountDownLatch(1); + AtomicInteger rpcCalls = new AtomicInteger(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + if (rpcCalls.incrementAndGet() == 1) { + rpcStarted.countDown(); + releaseRpc.await(10, TimeUnit.SECONDS); + } + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + 0, + 0, + false); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference firstException = new AtomicReference<>(); + Thread firstThread = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Exception e) { + firstException.set(e); + } + }); + Thread cleanupThread = + new Thread( + () -> { + shuffleClient.cleanupShuffle(0); + cleanupFinished.countDown(); + }); + + firstThread.start(); + Assert.assertTrue(rpcStarted.await(10, TimeUnit.SECONDS)); + cleanupThread.start(); + Assert.assertTrue(cleanupFinished.await(10, TimeUnit.SECONDS)); + releaseRpc.countDown(); + firstThread.join(10 * 1000); + cleanupThread.join(10 * 1000); + + Assert.assertFalse(firstThread.isAlive()); + Assert.assertFalse(cleanupThread.isAlive()); + Assert.assertTrue(firstException.get() instanceof CelebornIOException); + Assert.assertFalse(shuffleClient.hasReducerFileGroupRangeCache(0)); + + Assert.assertNotNull(shuffleClient.updateFileGroup(0, 0)); + Assert.assertEquals(2, rpcCalls.get()); + } + + @Test + public void testCleanupShuffleFailsWaitingReducerFileGroupLoads() throws Exception { + CelebornConf conf = new CelebornConf(); + Map> locations = new HashMap<>(); + CountDownLatch rpcStarted = new CountDownLatch(1); + CountDownLatch releaseRpc = new CountDownLatch(1); + CountDownLatch secondThreadStarted = new CountDownLatch(1); + CountDownLatch cleanupFinished = new CountDownLatch(1); + AtomicInteger rpcCalls = new AtomicInteger(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + if (rpcCalls.incrementAndGet() == 1) { + rpcStarted.countDown(); + releaseRpc.await(10, TimeUnit.SECONDS); + } + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + 0, + 0, + false); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference firstException = new AtomicReference<>(); + AtomicReference secondResult = new AtomicReference<>(); + AtomicReference secondException = new AtomicReference<>(); + Thread firstThread = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Exception e) { + firstException.set(e); + } + }); + Thread secondThread = + new Thread( + () -> { + secondThreadStarted.countDown(); + try { + secondResult.set(shuffleClient.updateFileGroup(0, 0)); + } catch (Exception e) { + secondException.set(e); + } + }); + Thread cleanupThread = + new Thread( + () -> { + shuffleClient.cleanupShuffle(0); + cleanupFinished.countDown(); + }); + + firstThread.start(); + Assert.assertTrue(rpcStarted.await(10, TimeUnit.SECONDS)); + secondThread.start(); + Assert.assertTrue(secondThreadStarted.await(10, TimeUnit.SECONDS)); + assertWaitingOnInFlightLoad(secondThread); + cleanupThread.start(); + Assert.assertTrue(cleanupFinished.await(10, TimeUnit.SECONDS)); + + releaseRpc.countDown(); + firstThread.join(10 * 1000); + secondThread.join(10 * 1000); + cleanupThread.join(10 * 1000); + + Assert.assertFalse(firstThread.isAlive()); + Assert.assertFalse(secondThread.isAlive()); + Assert.assertFalse(cleanupThread.isAlive()); + Assert.assertTrue(firstException.get() instanceof CelebornIOException); + Assert.assertTrue(secondException.get() instanceof CelebornIOException); + Assert.assertNull(secondResult.get()); + Assert.assertEquals(1, rpcCalls.get()); + + Assert.assertNotNull(shuffleClient.updateFileGroup(0, 0)); + Assert.assertEquals(2, rpcCalls.get()); + } + + @Test + public void testColdBootstrapWaiterDoesNotRecreateCacheAfterCleanup() throws Exception { + CelebornConf conf = new CelebornConf(); + CountDownLatch firstRpcStarted = new CountDownLatch(1); + CountDownLatch releaseFirstRpc = new CountDownLatch(1); + CountDownLatch publicationStarted = new CountDownLatch(1); + CountDownLatch releasePublication = new CountDownLatch(1); + CountDownLatch cleanupFinished = new CountDownLatch(1); + AtomicInteger rpcCalls = new AtomicInteger(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + GetReducerFileGroup request = invocation.getArgument(0); + if (rpcCalls.incrementAndGet() == 1) { + firstRpcStarted.countDown(); + releaseFirstRpc.await(10, TimeUnit.SECONDS); + } + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + Collections.emptyMap(), + new int[] {0}, + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + request.startPartition(), + request.endPartition(), + true); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + shuffleClient.reduceFileGroupsBeforeInFlightRelease = + () -> { + publicationStarted.countDown(); + awaitLatch(releasePublication); + }; + + AtomicReference ownerException = new AtomicReference<>(); + AtomicReference waiterResult = new AtomicReference<>(); + AtomicReference waiterException = new AtomicReference<>(); + Thread owner = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Exception e) { + ownerException.set(e); + } + }); + Thread waiter = + new Thread( + () -> { + try { + waiterResult.set(shuffleClient.updateFileGroup(0, 1)); + } catch (Exception e) { + waiterException.set(e); + } + }); + Thread cleanupThread = + new Thread( + () -> { + shuffleClient.cleanupShuffle(0); + cleanupFinished.countDown(); + }); + + owner.start(); + Assert.assertTrue(firstRpcStarted.await(10, TimeUnit.SECONDS)); + waiter.start(); + assertWaitingOnInFlightLoad(waiter); + releaseFirstRpc.countDown(); + Assert.assertTrue(publicationStarted.await(10, TimeUnit.SECONDS)); + + cleanupThread.start(); + assertThreadBlocked(cleanupThread); + Assert.assertFalse(shuffleClient.hasReducerFileGroupRangeCache(0)); + releasePublication.countDown(); + + owner.join(10 * 1000); + waiter.join(10 * 1000); + cleanupThread.join(10 * 1000); + + Assert.assertFalse(owner.isAlive()); + Assert.assertFalse(waiter.isAlive()); + Assert.assertFalse(cleanupThread.isAlive()); + Assert.assertTrue(cleanupFinished.await(10, TimeUnit.SECONDS)); + Assert.assertNull(ownerException.get()); + Assert.assertNull(waiterResult.get()); + Assert.assertTrue(waiterException.get() instanceof CelebornIOException); + Assert.assertEquals(1, rpcCalls.get()); + Assert.assertFalse(shuffleClient.hasReducerFileGroupRangeCache(0)); + } + + @Test + public void testUpdateReducerFileGroupRechecksCacheAfterClaimingLoad() throws Exception { + CelebornConf conf = new CelebornConf(); + Map> locations = new HashMap<>(); + CountDownLatch firstRpcStarted = new CountDownLatch(1); + CountDownLatch releaseFirstRpc = new CountDownLatch(1); + CountDownLatch secondCacheMiss = new CountDownLatch(1); + CountDownLatch releaseSecondCacheMiss = new CountDownLatch(1); + AtomicInteger cacheMisses = new AtomicInteger(); + AtomicInteger rpcCalls = new AtomicInteger(); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + if (rpcCalls.incrementAndGet() == 1) { + firstRpcStarted.countDown(); + releaseFirstRpc.await(10, TimeUnit.SECONDS); + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + 0, + 0, + false); + } + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.STAGE_END_TIMEOUT, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + 0, + 0, + false); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + shuffleClient.reduceFileGroupsAfterCacheMiss = + () -> { + if (cacheMisses.incrementAndGet() == 2) { + secondCacheMiss.countDown(); + awaitLatch(releaseSecondCacheMiss); + } + }; + + AtomicReference firstResult = new AtomicReference<>(); + AtomicReference secondResult = new AtomicReference<>(); + AtomicReference firstException = new AtomicReference<>(); + AtomicReference secondException = new AtomicReference<>(); + Thread firstThread = + new Thread( + () -> { + try { + firstResult.set(shuffleClient.updateFileGroup(0, 0)); + } catch (Exception e) { + firstException.set(e); + } + }); + Thread secondThread = + new Thread( + () -> { + try { + secondResult.set(shuffleClient.updateFileGroup(0, 0)); + } catch (Exception e) { + secondException.set(e); + } + }); + + firstThread.start(); + Assert.assertTrue(firstRpcStarted.await(10, TimeUnit.SECONDS)); + secondThread.start(); + Assert.assertTrue(secondCacheMiss.await(10, TimeUnit.SECONDS)); + releaseFirstRpc.countDown(); + firstThread.join(10 * 1000); + Assert.assertFalse(firstThread.isAlive()); + + releaseSecondCacheMiss.countDown(); + secondThread.join(10 * 1000); + + Assert.assertFalse(secondThread.isAlive()); + Assert.assertNull(firstException.get()); + Assert.assertNull(secondException.get()); + Assert.assertNotNull(firstResult.get()); + Assert.assertEquals(firstResult.get().partitionGroups, secondResult.get().partitionGroups); + Assert.assertEquals(1, rpcCalls.get()); + } + + @Test + public void testCleanupShuffleWaitsForReducerFileGroupPublication() throws Exception { + CelebornConf conf = new CelebornConf(); + Map> locations = new HashMap<>(); + CountDownLatch publicationStarted = new CountDownLatch(1); + CountDownLatch releasePublication = new CountDownLatch(1); + CountDownLatch cleanupFinished = new CountDownLatch(1); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> + GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, + locations, + new int[0], + Collections.emptySet(), + Collections.emptyMap(), + new byte[0], + SerdeVersion.V1, + 0, + 0, + false)); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + shuffleClient.reduceFileGroupsBeforeInFlightRelease = + () -> { + publicationStarted.countDown(); + awaitLatch(releasePublication); + }; + + AtomicReference loadException = new AtomicReference<>(); + Thread loadThread = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Exception e) { + loadException.set(e); + } + }); + Thread cleanupThread = + new Thread( + () -> { + shuffleClient.cleanupShuffle(0); + cleanupFinished.countDown(); + }); + + loadThread.start(); + Assert.assertTrue(publicationStarted.await(10, TimeUnit.SECONDS)); + cleanupThread.start(); + assertThreadBlocked(cleanupThread); + Assert.assertFalse(cleanupFinished.await(100, TimeUnit.MILLISECONDS)); + + releasePublication.countDown(); + loadThread.join(10 * 1000); + cleanupThread.join(10 * 1000); + + Assert.assertFalse(loadThread.isAlive()); + Assert.assertFalse(cleanupThread.isAlive()); + Assert.assertNull(loadException.get()); + Assert.assertFalse(shuffleClient.hasReducerFileGroupRangeCache(0)); + } + + @Test + public void testUpdateReducerFileGroupWaitingReadersPreserveFatalErrors() throws Exception { + CelebornConf conf = new CelebornConf(); + CountDownLatch rpcStarted = new CountDownLatch(1); + CountDownLatch releaseRpc = new CountDownLatch(1); + AtomicInteger rpcCalls = new AtomicInteger(); + AssertionError fatalError = new AssertionError("test fatal error"); + + when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any())) + .thenAnswer( + invocation -> { + rpcCalls.incrementAndGet(); + rpcStarted.countDown(); + releaseRpc.await(10, TimeUnit.SECONDS); + throw fatalError; + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference firstError = new AtomicReference<>(); + AtomicReference secondError = new AtomicReference<>(); + Thread firstThread = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Throwable e) { + firstError.set(e); + } + }); + Thread secondThread = + new Thread( + () -> { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (Throwable e) { + secondError.set(e); + } + }); + + firstThread.start(); + Assert.assertTrue(rpcStarted.await(10, TimeUnit.SECONDS)); + secondThread.start(); + assertWaitingOnInFlightLoad(secondThread); + releaseRpc.countDown(); + firstThread.join(10 * 1000); + secondThread.join(10 * 1000); + + Assert.assertFalse(firstThread.isAlive()); + Assert.assertFalse(secondThread.isAlive()); + Assert.assertSame(fatalError, firstError.get()); + Assert.assertSame(fatalError, secondError.get()); + Assert.assertEquals(1, rpcCalls.get()); + } } diff --git a/client/src/test/scala/org/apache/celeborn/client/commit/ReducerFileGroupFilterSuite.scala b/client/src/test/scala/org/apache/celeborn/client/commit/ReducerFileGroupFilterSuite.scala new file mode 100644 index 00000000000..b608f86c8ce --- /dev/null +++ b/client/src/test/scala/org/apache/celeborn/client/commit/ReducerFileGroupFilterSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.commit + +import java.util + +import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.common.protocol.PartitionLocation +import org.apache.celeborn.common.protocol.PartitionLocation.Mode +import org.apache.celeborn.common.write.LocationPushFailedBatches + +class ReducerFileGroupFilterSuite extends CelebornFunSuite { + + private def location(partitionId: Int): PartitionLocation = + new PartitionLocation( + partitionId, + 0, + s"host-$partitionId", + 1000 + partitionId, + 2000 + partitionId, + 3000 + partitionId, + 4000 + partitionId, + Mode.PRIMARY) + + test("filter reducer file groups and failed batches to the requested range") { + val allFileGroups = new util.HashMap[Integer, util.Set[PartitionLocation]]() + val allPushFailedBatches = new util.HashMap[String, LocationPushFailedBatches]() + (0 until 4).foreach { partitionId => + val partitionLocation = location(partitionId) + allFileGroups.put(partitionId, util.Collections.singleton(partitionLocation)) + allPushFailedBatches.put( + partitionLocation.getUniqueId, + new LocationPushFailedBatches()) + } + + val fileGroups = ReducerFileGroupFilter.fileGroupsForRange( + allFileGroups, + startPartition = 1, + endPartition = 3, + hasPartitionRange = true) + val failedBatches = ReducerFileGroupFilter.pushFailedBatchesForFileGroups( + fileGroups, + allPushFailedBatches) + val expectedFailedBatchIds = new util.HashSet[String]() + expectedFailedBatchIds.add(fileGroups.get(1).iterator().next().getUniqueId) + expectedFailedBatchIds.add(fileGroups.get(2).iterator().next().getUniqueId) + + val expectedPartitionIds = new util.HashSet[Integer]() + expectedPartitionIds.add(1) + expectedPartitionIds.add(2) + assert(fileGroups.keySet() == expectedPartitionIds) + assert(failedBatches.keySet() == expectedFailedBatchIds) + } + + test("filter map partition success IDs and preserve the legacy full response") { + val allPartitionIds = new util.HashSet[Integer]() + allPartitionIds.add(0) + allPartitionIds.add(2) + allPartitionIds.add(4) + val ranged = ReducerFileGroupFilter.partitionIdsForRange( + allPartitionIds, + startPartition = 1, + endPartition = 4, + hasPartitionRange = true) + + assert(ranged == util.Collections.singleton(2)) + + val allFileGroups = new util.HashMap[Integer, util.Set[PartitionLocation]]() + allFileGroups.put(0, util.Collections.singleton(location(0))) + val legacy = ReducerFileGroupFilter.fileGroupsForRange( + allFileGroups, + startPartition = 0, + endPartition = 0, + hasPartitionRange = false) + assert(legacy eq allFileGroups) + } +} diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index a813a9e5015..c39aefd9bb9 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -394,6 +394,10 @@ message PbMapperEndResponse { message PbGetReducerFileGroup { int32 shuffleId = 1; bool isSegmentGranularityVisible = 2; + int32 startPartition = 3; + int32 endPartition = 4; + bool hasPartitionRange = 5; + bool omitMapAttempts = 6; } message PbGetReducerFileGroupResponse { @@ -410,6 +414,10 @@ message PbGetReducerFileGroupResponse { map pushFailedBatches = 5; bytes broadcast = 6; + + int32 startPartition = 7; + int32 endPartition = 8; + bool hasPartitionRange = 9; } message PbGetShuffleId { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index ea6b819fcfa..3505507b751 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -5748,6 +5748,8 @@ object CelebornConf extends Logging { .categories("client") .doc( "Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. " + + "This applies only to legacy requests for shuffle-wide reducer metadata; " + + "partition-range requests return their smaller task-specific response directly. " + "If the response size is large and Spark executor number is large, the Spark driver network " + "may be exhausted because each executor will pull the response from the driver. With broadcasting " + "GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple " + @@ -5759,7 +5761,8 @@ object CelebornConf extends Logging { val CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE = buildConf("celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize") .categories("client") - .doc("The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors.") + .doc("The size at which we use Broadcast to send a legacy shuffle-wide " + + "GetReducerFileGroupResponse to the executors.") .version("0.6.0") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("512k") diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index f1e34aa54f8..3e0e3d0382b 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -237,7 +237,11 @@ object ControlMessages extends Logging { case class GetReducerFileGroup( shuffleId: Int, isSegmentGranularityVisible: Boolean, - serdeVersion: SerdeVersion) + serdeVersion: SerdeVersion, + startPartition: Int = 0, + endPartition: Int = 0, + hasPartitionRange: Boolean = false, + omitMapAttempts: Boolean = false) extends MasterMessage // util.Set[String] -> util.Set[Path.toString] @@ -250,7 +254,10 @@ object ControlMessages extends Logging { pushFailedBatches: util.Map[String, LocationPushFailedBatches] = Collections.emptyMap(), broadcast: Array[Byte] = Array.emptyByteArray, - serdeVersion: SerdeVersion = SerdeVersion.V1) + serdeVersion: SerdeVersion = SerdeVersion.V1, + startPartition: Int = 0, + endPartition: Int = 0, + hasPartitionRange: Boolean = false) extends MasterMessage object WorkerExclude { @@ -762,11 +769,23 @@ object ControlMessages extends Logging { .build().toByteArray new TransportMessage(MessageType.MAPPER_END_RESPONSE, payload, serdeVersion) - case GetReducerFileGroup(shuffleId, isSegmentGranularityVisible, serdeVersion) => - val payload = PbGetReducerFileGroup.newBuilder() + case GetReducerFileGroup( + shuffleId, + isSegmentGranularityVisible, + serdeVersion, + startPartition, + endPartition, + hasPartitionRange, + omitMapAttempts) => + val builder = PbGetReducerFileGroup.newBuilder() .setShuffleId(shuffleId) .setIsSegmentGranularityVisible(isSegmentGranularityVisible) - .build().toByteArray + .setHasPartitionRange(hasPartitionRange) + .setOmitMapAttempts(omitMapAttempts) + if (hasPartitionRange) { + builder.setStartPartition(startPartition).setEndPartition(endPartition) + } + val payload = builder.build().toByteArray new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload, serdeVersion) case GetReducerFileGroupResponse( @@ -776,7 +795,10 @@ object ControlMessages extends Logging { partitionIds, failedBatches, broadcast, - serdeVersion) => + serdeVersion, + startPartition, + endPartition, + hasPartitionRange) => val builder = PbGetReducerFileGroupResponse .newBuilder() .setStatus(status.getValue) @@ -795,6 +817,10 @@ object ControlMessages extends Logging { (uniqueId, PbSerDeUtils.toPbLocationPushFailedBatches(pushFailedBatchSet)) }.asJava) builder.setBroadcast(ByteString.copyFrom(broadcast)) + builder.setHasPartitionRange(hasPartitionRange) + if (hasPartitionRange) { + builder.setStartPartition(startPartition).setEndPartition(endPartition) + } val payload = builder.build().toByteArray new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload, serdeVersion) @@ -1274,7 +1300,11 @@ object ControlMessages extends Logging { GetReducerFileGroup( pbGetReducerFileGroup.getShuffleId, pbGetReducerFileGroup.getIsSegmentGranularityVisible, - message.getSerdeVersion) + message.getSerdeVersion, + pbGetReducerFileGroup.getStartPartition, + pbGetReducerFileGroup.getEndPartition, + pbGetReducerFileGroup.getHasPartitionRange, + pbGetReducerFileGroup.getOmitMapAttempts) case GET_REDUCER_FILE_GROUP_RESPONSE_VALUE => val pbGetReducerFileGroupResponse = PbGetReducerFileGroupResponse @@ -1311,7 +1341,11 @@ object ControlMessages extends Logging { attempts, partitionIds, pushFailedBatches, - broadcast) + broadcast, + message.getSerdeVersion, + pbGetReducerFileGroupResponse.getStartPartition, + pbGetReducerFileGroupResponse.getEndPartition, + pbGetReducerFileGroupResponse.getHasPartitionRange) case GET_SHUFFLE_ID_VALUE => message.getParsedPayload() diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala index 5b2dd6a1097..4e26fa87a30 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala @@ -1357,8 +1357,9 @@ object Utils extends Logging { try { TimeUnit.MILLISECONDS.sleep(retryWaitMs) } catch { - case _: InterruptedException => - throw e + case interrupted: InterruptedException => + Thread.currentThread().interrupt() + throw interrupted } } else { throw e diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala index c83a83b95a7..5ef4289c24f 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala @@ -17,8 +17,11 @@ package org.apache.celeborn.common.util +import java.io.IOException import java.util import java.util.Collections +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import org.scalatest.matchers.must.Matchers.contain import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper @@ -30,7 +33,7 @@ import org.apache.celeborn.common.exception.CelebornException import org.apache.celeborn.common.identity.DefaultIdentityProvider import org.apache.celeborn.common.network.protocol.SerdeVersion import org.apache.celeborn.common.protocol.{PartitionLocation, PbReviseLostShuffles, PbReviseLostShufflesResponse, TransportModuleConstants} -import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, MapperEnd, ReviseLostShuffles, ReviseLostShufflesResponse} +import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroup, GetReducerFileGroupResponse, MapperEnd, ReviseLostShuffles, ReviseLostShufflesResponse} import org.apache.celeborn.common.protocol.message.StatusCode class UtilsSuite extends CelebornFunSuite { @@ -248,15 +251,101 @@ class UtilsSuite extends CelebornFunSuite { fileGroup.put(2, partitionLocation(2)) val attempts = Array(0, 0, 1) - val response = GetReducerFileGroupResponse(StatusCode.STAGE_ENDED, fileGroup, attempts) - val responseTrans = Utils.fromTransportMessage(Utils.toTransportMessage(response)).asInstanceOf[ - GetReducerFileGroupResponse] - - assert(response.status == responseTrans.status) - assert(util.Arrays.equals(response.attempts, responseTrans.attempts)) - val set = - (response.fileGroup.values().toArray diff responseTrans.fileGroup.values().toArray).toSet - assert(set.size == 0) + Seq(SerdeVersion.V1, SerdeVersion.V2).foreach { serdeVersion => + val response = GetReducerFileGroupResponse( + StatusCode.STAGE_ENDED, + fileGroup, + attempts, + serdeVersion = serdeVersion, + startPartition = 1, + endPartition = 3, + hasPartitionRange = true) + val responseTrans = + Utils.fromTransportMessage(Utils.toTransportMessage(response)).asInstanceOf[ + GetReducerFileGroupResponse] + + assert(response.status == responseTrans.status) + assert(util.Arrays.equals(response.attempts, responseTrans.attempts)) + assert(responseTrans.serdeVersion == serdeVersion) + assert(responseTrans.startPartition == 1) + assert(responseTrans.endPartition == 3) + assert(responseTrans.hasPartitionRange) + val set = + (response.fileGroup.values().toArray diff responseTrans.fileGroup.values().toArray).toSet + assert(set.size == 0) + } + } + + test("GetReducerFileGroup partition range converts with pb") { + Seq(SerdeVersion.V1, SerdeVersion.V2).foreach { serdeVersion => + val request = GetReducerFileGroup( + 7, + isSegmentGranularityVisible = false, + serdeVersion, + startPartition = 11, + endPartition = 19, + hasPartitionRange = true, + omitMapAttempts = true) + val converted = Utils.fromTransportMessage(Utils.toTransportMessage(request)).asInstanceOf[ + GetReducerFileGroup] + + assert(converted.shuffleId == 7) + assert(converted.serdeVersion == serdeVersion) + assert(converted.startPartition == 11) + assert(converted.endPartition == 19) + assert(converted.hasPartitionRange) + assert(converted.omitMapAttempts) + } + + val legacy = GetReducerFileGroup(7, false, SerdeVersion.V1) + val convertedLegacy = Utils.fromTransportMessage(Utils.toTransportMessage(legacy)).asInstanceOf[ + GetReducerFileGroup] + assert(!convertedLegacy.hasPartitionRange) + assert(!convertedLegacy.omitMapAttempts) + } + + test("retry sleep preserves interruption") { + val retryStarted = new CountDownLatch(1) + val thrown = new AtomicReference[Throwable]() + val interruptRestored = new AtomicBoolean(false) + val retryThread = new Thread(new Runnable { + override def run(): Unit = { + try { + Utils.withRetryOnTimeoutOrIOException(Int.MaxValue, Int.MaxValue.toLong) { + retryStarted.countDown() + throw new IOException("retry") + } + } catch { + case t: Throwable => + thrown.set(t) + interruptRestored.set(Thread.currentThread().isInterrupted) + } + } + }) + + retryThread.start() + try { + assert(retryStarted.await(10, TimeUnit.SECONDS)) + val deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10) + while (retryThread.isAlive && + retryThread.getState != Thread.State.TIMED_WAITING && + System.nanoTime() < deadline) { + Thread.sleep(1) + } + assert(retryThread.getState == Thread.State.TIMED_WAITING) + + retryThread.interrupt() + retryThread.join(TimeUnit.SECONDS.toMillis(10)) + + assert(!retryThread.isAlive) + assert(thrown.get().isInstanceOf[InterruptedException]) + assert(interruptRestored.get()) + } finally { + if (retryThread.isAlive) { + retryThread.interrupt() + retryThread.join(TimeUnit.SECONDS.toMillis(10)) + } + } } test("validate number of client/server netty threads") { diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 5d15c6859fb..1e30ba6e46e 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -137,8 +137,8 @@ license: | | celeborn.client.spark.shuffle.fallback.numPartitionsThreshold | 2147483647 | false | Celeborn will only accept shuffle of partition number lower than this configuration value. This configuration only takes effect when `celeborn.client.spark.shuffle.fallback.policy` is `AUTO`. | 0.5.0 | celeborn.shuffle.forceFallback.numPartitionsThreshold,celeborn.client.spark.shuffle.forceFallback.numPartitionsThreshold | | celeborn.client.spark.shuffle.fallback.policy | AUTO | false | Celeborn supports the following kind of fallback policies. 1. ALWAYS: always use spark built-in shuffle implementation; 2. AUTO: prefer to use celeborn shuffle implementation, and fallback to use spark built-in shuffle implementation based on certain factors, e.g. availability of enough workers and quota, shuffle partition number; 3. NEVER: always use celeborn shuffle implementation, and fail fast when it it is concluded that fallback is required based on factors above. | 0.5.0 | | | celeborn.client.spark.shuffle.forceFallback.enabled | false | false | Always use spark built-in shuffle implementation. This configuration is deprecated, consider configuring `celeborn.client.spark.shuffle.fallback.policy` instead. | 0.3.0 | celeborn.shuffle.forceFallback.enabled | -| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false | false | Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. If the response size is large and Spark executor number is large, the Spark driver network may be exhausted because each executor will pull the response from the driver. With broadcasting GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). | 0.6.0 | | -| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k | false | The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors. | 0.6.0 | | +| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false | false | Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. This applies only to legacy requests for shuffle-wide reducer metadata; partition-range requests return their smaller task-specific response directly. If the response size is large and Spark executor number is large, the Spark driver network may be exhausted because each executor will pull the response from the driver. With broadcasting GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). | 0.6.0 | | +| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k | false | The size at which we use Broadcast to send a legacy shuffle-wide GetReducerFileGroupResponse to the executors. | 0.6.0 | | | celeborn.client.spark.shuffle.writer | HASH | false | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | celeborn.shuffle.writer | | celeborn.client.spark.stageRerun.enabled | true | false | Whether to enable stage rerun. If true, client throws FetchFailedException instead of CelebornIOException. | 0.4.0 | celeborn.client.spark.fetch.throwsFetchFailure | | celeborn.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.6.0 | celeborn.quota.identity.provider | diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala index afc2956acf3..113f30de997 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala @@ -67,7 +67,7 @@ class CelebornHashSuite extends AnyFunSuite celebornSparkSession.stop() } - test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") { + test("celeborn spark integration test - scoped reducer metadata bypasses broadcast") { SparkUtils.getReducerFileGroupResponseBroadcasts.clear() SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0) val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]") @@ -99,7 +99,7 @@ class CelebornHashSuite extends AnyFunSuite assert(repartitionResult.equals(celebornRepartitionResult)) assert(combineResult.equals(celebornCombineResult)) assert(sqlResult.equals(celebornSqlResult)) - assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0) + assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() == 0) celebornSparkSession.stop() SparkUtils.getReducerFileGroupResponseBroadcasts.clear() diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala index e4f6cc93b5d..1b5349b659e 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala @@ -68,7 +68,7 @@ class CelebornSortSuite extends AnyFunSuite celebornSparkSession.stop() } - test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") { + test("celeborn spark integration test - scoped reducer metadata bypasses broadcast") { SparkUtils.getReducerFileGroupResponseBroadcasts.clear() SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0) val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]") @@ -102,7 +102,7 @@ class CelebornSortSuite extends AnyFunSuite assert(repartitionResult.equals(celebornRepartitionResult)) assert(combineResult.equals(celebornCombineResult)) assert(sqlResult.equals(celebornSqlResult)) - assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0) + assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() == 0) celebornSparkSession.stop() SparkUtils.getReducerFileGroupResponseBroadcasts.clear()